Project import generated by Copybara.

GitOrigin-RevId: b137378673f7d66d41bcd46e4fc3a0d9ef254894
This commit is contained in:
MediaPipe Team 2019-10-25 14:12:58 -07:00 committed by jqtang
parent a2a63e3876
commit 259b48e082
94 changed files with 8567 additions and 401 deletions

View File

@ -37,10 +37,15 @@ A web-based visualizer is hosted on [viz.mediapipe.dev](https://viz.mediapipe.de
* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe
## Publications ## Publications
* [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
* [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172) * [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172)
## Events ## Events
[Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA * [ML Conference, Berlin 9-11 Dec 2019](https://mlconference.ai/machine-learning-advanced-development/mediapipe-building-real-time-cross-platform-mobile-web-edge-desktop-video-audio-ml-pipelines/)
* [The 3rd Workshop on YouTube-8M Large Scale Video Understanding Workshop](https://research.google.com/youtube8m/workshop2019/index.html) Seoul, Korea ICCV 2019
* [AI DevWorld 2019](https://aidevworld.com) on Oct 10 in San Jose, California
* [Google Industry Workshop at ICIP 2019](http://2019.ieeeicip.org/?action=page4&id=14#Google) [Presentation](https://docs.google.com/presentation/d/e/2PACX-1vRIBBbO_LO9v2YmvbHHEt1cwyqH6EjDxiILjuT0foXy1E7g6uyh4CesB2DkkEwlRDO9_lWfuKMZx98T/pub?start=false&loop=false&delayms=3000&slide=id.g556cc1a659_0_5) on Sept 24 in Taipei, Taiwan
* [Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17~20 in Long Beach, CA
## Alpha Disclaimer ## Alpha Disclaimer
MediaPipe is currently in alpha for v0.6. We are still making breaking API changes and expect to get to stable API by v1.0. MediaPipe is currently in alpha for v0.6. We are still making breaking API changes and expect to get to stable API by v1.0.

View File

@ -10,7 +10,8 @@ http_archive(
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e", sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
) )
load("@bazel_skylib//lib:versions.bzl", "versions") load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "0.24.1") versions.check(minimum_bazel_version = "0.24.1",
maximum_bazel_version = "0.29.1")
# ABSL cpp library. # ABSL cpp library.
http_archive( http_archive(

View File

@ -26,6 +26,13 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"], deps = ["//mediapipe/framework:calculator_proto"],
) )
proto_library(
name = "dequantize_byte_array_calculator_proto",
srcs = ["dequantize_byte_array_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library( proto_library(
name = "packet_cloner_calculator_proto", name = "packet_cloner_calculator_proto",
srcs = ["packet_cloner_calculator.proto"], srcs = ["packet_cloner_calculator.proto"],
@ -104,6 +111,14 @@ mediapipe_cc_proto_library(
deps = [":concatenate_vector_calculator_proto"], deps = [":concatenate_vector_calculator_proto"],
) )
mediapipe_cc_proto_library(
name = "dequantize_byte_array_calculator_cc_proto",
srcs = ["dequantize_byte_array_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":dequantize_byte_array_calculator_proto"],
)
mediapipe_cc_proto_library( mediapipe_cc_proto_library(
name = "quantize_float_vector_calculator_cc_proto", name = "quantize_float_vector_calculator_cc_proto",
srcs = ["quantize_float_vector_calculator.proto"], srcs = ["quantize_float_vector_calculator.proto"],
@ -387,6 +402,32 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "string_to_int_calculator",
srcs = ["string_to_int_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library(
name = "side_packet_to_stream_calculator",
srcs = ["side_packet_to_stream_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test( cc_test(
name = "immediate_mux_calculator_test", name = "immediate_mux_calculator_test",
srcs = ["immediate_mux_calculator_test.cc"], srcs = ["immediate_mux_calculator_test.cc"],
@ -558,6 +599,32 @@ cc_test(
], ],
) )
cc_library(
name = "dequantize_byte_array_calculator",
srcs = ["dequantize_byte_array_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":dequantize_byte_array_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "dequantize_byte_array_calculator_test",
srcs = ["dequantize_byte_array_calculator_test.cc"],
deps = [
":dequantize_byte_array_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
cc_library( cc_library(
name = "quantize_float_vector_calculator", name = "quantize_float_vector_calculator",
srcs = ["quantize_float_vector_calculator.cc"], srcs = ["quantize_float_vector_calculator.cc"],

View File

@ -0,0 +1,90 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cfloat>
#include "mediapipe/calculators/core/dequantize_byte_array_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
// Dequantizes a byte array to a vector of floats.
//
// Example config:
// node {
// calculator: "DequantizeByteArrayCalculator"
// input_stream: "ENCODED:encoded"
// output_stream: "FLOAT_VECTOR:float_vector"
// options {
// [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
// max_quantized_value: 2
// min_quantized_value: -2
// }
// }
// }
namespace mediapipe {
class DequantizeByteArrayCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("ENCODED").Set<std::string>();
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
const auto options =
cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>();
if (!options.has_max_quantized_value() ||
!options.has_min_quantized_value()) {
return ::mediapipe::InvalidArgumentError(
"Both max_quantized_value and min_quantized_value must be provided "
"in DequantizeByteArrayCalculatorOptions.");
}
float max_quantized_value = options.max_quantized_value();
float min_quantized_value = options.min_quantized_value();
if (max_quantized_value < min_quantized_value + FLT_EPSILON) {
return ::mediapipe::InvalidArgumentError(
"max_quantized_value must be greater than min_quantized_value.");
}
float range = max_quantized_value - min_quantized_value;
scalar_ = range / 255.0;
bias_ = (range / 512.0) + min_quantized_value;
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
const std::string& encoded =
cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
std::vector<float> float_vector;
float_vector.reserve(encoded.length());
for (int i = 0; i < encoded.length(); ++i) {
float_vector.push_back(
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
}
cc->Outputs()
.Tag("FLOAT_VECTOR")
.AddPacket(MakePacket<std::vector<float>>(float_vector)
.At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
private:
float scalar_;
float bias_;
};
REGISTER_CALCULATOR(DequantizeByteArrayCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message DequantizeByteArrayCalculatorOptions {
extend CalculatorOptions {
optional DequantizeByteArrayCalculatorOptions ext = 272316343;
}
optional float max_quantized_value = 1;
optional float min_quantized_value = 2;
}

View File

@ -0,0 +1,137 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 2
}
}
)");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
testing::HasSubstr(
"Both max_quantized_value and min_quantized_value must be provided"));
}
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: -2
min_quantized_value: 2
}
}
)");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
testing::HasSubstr(
"max_quantized_value must be greater than min_quantized_value"));
}
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 1
min_quantized_value: 1
}
}
)");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
testing::HasSubstr(
"max_quantized_value must be greater than min_quantized_value"));
}
TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:encoded"
output_stream: "FLOAT_VECTOR:float_vector"
options {
[mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
max_quantized_value: 2
min_quantized_value: -2
}
}
)");
CalculatorRunner runner(node_config);
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
MakePacket<std::string>(
std::string(reinterpret_cast<char const*>(input), 4))
.At(Timestamp(0)));
auto status = runner.Run();
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs =
runner.Outputs().Tag("FLOAT_VECTOR").packets;
EXPECT_EQ(1, outputs.size());
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
ASSERT_FALSE(result.empty());
EXPECT_EQ(4, result.size());
EXPECT_NEAR(0, result[0], 0.01);
EXPECT_NEAR(2, result[1], 0.01);
EXPECT_NEAR(-2, result[2], 0.01);
EXPECT_NEAR(-1.976, result[3], 0.01);
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}
} // namespace mediapipe

View File

@ -102,6 +102,12 @@ class PreviousLoopbackCalculator : public CalculatorBase {
cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback)); cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback));
} }
} }
if (!main_ts_.empty()) {
cc->Outputs().Get(loop_out_id_).SetNextTimestampBound(main_ts_.front());
}
if (cc->Inputs().Get(main_id_).IsDone() && main_ts_.empty()) {
cc->Outputs().Get(loop_out_id_).Close();
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -107,5 +107,96 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
MP_EXPECT_OK(graph_.WaitUntilDone()); MP_EXPECT_OK(graph_.WaitUntilDone());
} }
// A Calculator that outputs a summary packet in CalculatorBase::Close().
class PacketOnCloseCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
sum_ += cc->Inputs().Index(0).Value().Get<int>();
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) final {
cc->Outputs().Index(0).AddPacket(
MakePacket<int>(sum_).At(Timestamp::Max()));
return ::mediapipe::OkStatus();
}
private:
int sum_ = 0;
};
REGISTER_CALCULATOR(PacketOnCloseCalculator);
// Demonstrates that all ouput and input streams in PreviousLoopbackCalculator
// will close as expected when all graph input streams are closed.
TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
std::vector<Packet> outputs;
CalculatorGraphConfig graph_config_ =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:in'
input_stream: 'LOOP:out'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:previous'
}
# This calculator synchronizes its inputs as normal, so it is used
# to check that both "in" and "previous" are ready.
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
input_stream: 'previous'
output_stream: 'out'
output_stream: 'previous2'
}
node {
calculator: 'PacketOnCloseCalculator'
input_stream: 'out'
output_stream: 'close_out'
}
)");
tool::AddVectorSink("close_out", &graph_config_, &outputs);
CalculatorGraph graph_;
MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
MP_ASSERT_OK(graph_.StartRun({}));
auto send_packet = [&graph_](const std::string& input_name, int n) {
MP_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
send_packet("in", 1);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1}));
send_packet("in", 5);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5}));
send_packet("in", 15);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5, 15}));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs),
(std::vector<int64>{1, 5, 15, Timestamp::Max().Value()}));
MP_EXPECT_OK(graph_.WaitUntilDone());
}
} // anonymous namespace } // anonymous namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,83 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <memory>
#include <set>
#include <string>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
using mediapipe::PacketTypeSet;
using mediapipe::Timestamp;
namespace {
static std::map<std::string, Timestamp>* kTimestampMap = []() {
auto* res = new std::map<std::string, Timestamp>();
res->emplace("AT_PRESTREAM", Timestamp::PreStream());
res->emplace("AT_POSTSTREAM", Timestamp::PostStream());
res->emplace("AT_ZERO", Timestamp(0));
return res;
}();
} // namespace
// Outputs the single input_side_packet at the timestamp specified in the
// output_stream tag. Valid tags are AT_PRESTREAM, AT_POSTSTREAM and AT_ZERO.
class SidePacketToStreamCalculator : public CalculatorBase {
public:
SidePacketToStreamCalculator() = default;
~SidePacketToStreamCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(SidePacketToStreamCalculator);
::mediapipe::Status SidePacketToStreamCalculator::GetContract(
CalculatorContract* cc) {
cc->InputSidePackets().Index(0).SetAny();
std::set<std::string> tags = cc->Outputs().GetTags();
RET_CHECK_EQ(tags.size(), 1);
RET_CHECK_EQ(kTimestampMap->count(*tags.begin()), 1);
cc->Outputs().Tag(*tags.begin()).SetAny();
return ::mediapipe::OkStatus();
}
::mediapipe::Status SidePacketToStreamCalculator::Process(
CalculatorContext* cc) {
return mediapipe::tool::StatusStop();
}
::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
std::set<std::string> tags = cc->Outputs().GetTags();
RET_CHECK_EQ(tags.size(), 1);
const std::string& tag = *tags.begin();
RET_CHECK_EQ(kTimestampMap->count(tag), 1);
cc->Outputs().Tag(tag).AddPacket(
cc->InputSidePackets().Index(0).At(kTimestampMap->at(tag)));
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -34,7 +34,9 @@ namespace mediapipe {
// SplitVectorCalculatorOptions. If the option "element_only" is set to true, // SplitVectorCalculatorOptions. If the option "element_only" is set to true,
// all ranges should be of size 1 and all outputs will be elements of type T. If // all ranges should be of size 1 and all outputs will be elements of type T. If
// "element_only" is false, ranges can be non-zero in size and all outputs will // "element_only" is false, ranges can be non-zero in size and all outputs will
// be of type std::vector<T>. // be of type std::vector<T>. If the option "combine_outputs" is set to true,
// only one output stream can be specified and all ranges of elements will be
// combined into one vector.
// To use this class for a particular type T, register a calculator using // To use this class for a particular type T, register a calculator using
// SplitVectorCalculator<T>. // SplitVectorCalculator<T>.
template <typename T> template <typename T>
@ -49,6 +51,24 @@ class SplitVectorCalculator : public CalculatorBase {
const auto& options = const auto& options =
cc->Options<::mediapipe::SplitVectorCalculatorOptions>(); cc->Options<::mediapipe::SplitVectorCalculatorOptions>();
if (options.combine_outputs()) {
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Index(0).Set<std::vector<T>>();
for (int i = 0; i < options.ranges_size() - 1; ++i) {
for (int j = i + 1; j < options.ranges_size(); ++j) {
const auto& range_0 = options.ranges(i);
const auto& range_1 = options.ranges(j);
if ((range_0.begin() >= range_1.begin() &&
range_0.begin() < range_1.end()) ||
(range_1.begin() >= range_0.begin() &&
range_1.begin() < range_0.end())) {
return ::mediapipe::InvalidArgumentError(
"Ranges must be non-overlapping when using combine_outputs "
"option.");
}
}
}
} else {
if (cc->Outputs().NumEntries() != options.ranges_size()) { if (cc->Outputs().NumEntries() != options.ranges_size()) {
return ::mediapipe::InvalidArgumentError( return ::mediapipe::InvalidArgumentError(
"The number of output streams should match the number of ranges " "The number of output streams should match the number of ranges "
@ -73,6 +93,7 @@ class SplitVectorCalculator : public CalculatorBase {
cc->Outputs().Index(i).Set<std::vector<T>>(); cc->Outputs().Index(i).Set<std::vector<T>>();
} }
} }
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -83,13 +104,15 @@ class SplitVectorCalculator : public CalculatorBase {
const auto& options = const auto& options =
cc->Options<::mediapipe::SplitVectorCalculatorOptions>(); cc->Options<::mediapipe::SplitVectorCalculatorOptions>();
element_only_ = options.element_only();
combine_outputs_ = options.combine_outputs();
for (const auto& range : options.ranges()) { for (const auto& range : options.ranges()) {
ranges_.push_back({range.begin(), range.end()}); ranges_.push_back({range.begin(), range.end()});
max_range_end_ = std::max(max_range_end_, range.end()); max_range_end_ = std::max(max_range_end_, range.end());
total_elements_ += range.end() - range.begin();
} }
element_only_ = options.element_only();
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -97,6 +120,17 @@ class SplitVectorCalculator : public CalculatorBase {
const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>(); const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>();
RET_CHECK_GE(input.size(), max_range_end_); RET_CHECK_GE(input.size(), max_range_end_);
if (combine_outputs_) {
auto output = absl::make_unique<std::vector<T>>();
output->reserve(total_elements_);
for (int i = 0; i < ranges_.size(); ++i) {
auto elements = absl::make_unique<std::vector<T>>(
input.begin() + ranges_[i].first,
input.begin() + ranges_[i].second);
output->insert(output->end(), elements->begin(), elements->end());
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else {
if (element_only_) { if (element_only_) {
for (int i = 0; i < ranges_.size(); ++i) { for (int i = 0; i < ranges_.size(); ++i) {
cc->Outputs().Index(i).AddPacket( cc->Outputs().Index(i).AddPacket(
@ -110,6 +144,7 @@ class SplitVectorCalculator : public CalculatorBase {
cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp()); cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
} }
} }
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -117,7 +152,9 @@ class SplitVectorCalculator : public CalculatorBase {
private: private:
std::vector<std::pair<int32, int32>> ranges_; std::vector<std::pair<int32, int32>> ranges_;
int32 max_range_end_ = -1; int32 max_range_end_ = -1;
int32 total_elements_ = 0;
bool element_only_ = false; bool element_only_ = false;
bool combine_outputs_ = false;
}; };
} // namespace mediapipe } // namespace mediapipe

View File

@ -37,4 +37,7 @@ message SplitVectorCalculatorOptions {
// just element of type T. By default, if a range specifies only one element, // just element of type T. By default, if a range specifies only one element,
// it is outputted as an std::vector<T>. // it is outputted as an std::vector<T>.
optional bool element_only = 2 [default = false]; optional bool element_only = 2 [default = false];
// Combines output elements to one vector.
optional bool combine_outputs = 3 [default = false];
} }

View File

@ -105,6 +105,34 @@ class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test {
} }
} }
void ValidateCombinedVectorOutput(std::vector<Packet>& output_packets,
int expected_elements,
std::vector<int>& input_begin_indices,
std::vector<int>& input_end_indices) {
ASSERT_EQ(1, output_packets.size());
ASSERT_EQ(input_begin_indices.size(), input_end_indices.size());
const std::vector<TfLiteTensor>& output_vec =
output_packets[0].Get<std::vector<TfLiteTensor>>();
ASSERT_EQ(expected_elements, output_vec.size());
const int num_ranges = input_begin_indices.size();
int element_id = 0;
for (int range_id = 0; range_id < num_ranges; ++range_id) {
for (int i = input_begin_indices[range_id];
i < input_end_indices[range_id]; ++i) {
const int expected_value = i;
const TfLiteTensor* result = &output_vec[element_id];
float* result_buffer = result->data.f;
ASSERT_NE(result_buffer, nullptr);
ASSERT_EQ(result_buffer, input_buffers_[i]);
for (int j = 0; j < width * height * channels; ++j) {
ASSERT_EQ(expected_value, result_buffer[j]);
}
element_id++;
}
}
}
void ValidateElementOutput(std::vector<Packet>& output_packets, void ValidateElementOutput(std::vector<Packet>& output_packets,
int input_begin_index) { int input_begin_index) {
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
@ -234,6 +262,65 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) {
ASSERT_FALSE(graph.Initialize(graph_config).ok()); ASSERT_FALSE(graph.Initialize(graph_config).ok());
} }
TEST_F(SplitTfLiteTensorVectorCalculatorTest,
InvalidCombineOutputsMultipleOutputsTest) {
ASSERT_NE(interpreter_, nullptr);
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "tensor_in"
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "tensor_in"
output_stream: "range_0"
output_stream: "range_1"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 2 end: 3 }
combine_outputs: true
}
}
}
)");
// Run the graph.
CalculatorGraph graph;
// The graph should fail running because the number of output streams does not
// match the number of range elements in the options.
ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOverlappingRangesTest) {
ASSERT_NE(interpreter_, nullptr);
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "tensor_in"
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "tensor_in"
output_stream: "range_0"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 3 }
ranges: { begin: 1 end: 4 }
combine_outputs: true
}
}
}
)");
// Run the graph.
CalculatorGraph graph;
// The graph should fail running because there are overlapping ranges.
ASSERT_FALSE(graph.Initialize(graph_config).ok());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) { TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
ASSERT_NE(interpreter_, nullptr); ASSERT_NE(interpreter_, nullptr);
@ -289,6 +376,53 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestCombiningOutputs) {
ASSERT_NE(interpreter_, nullptr);
PrepareTfLiteTensorVector(/*vector_size=*/5);
ASSERT_NE(input_vec_, nullptr);
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
input_stream: "tensor_in"
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "tensor_in"
output_stream: "range_0"
options {
[mediapipe.SplitVectorCalculatorOptions.ext] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 2 end: 3 }
ranges: { begin: 4 end: 5 }
combine_outputs: true
}
}
}
)");
std::vector<Packet> range_0_packets;
tool::AddVectorSink("range_0", &graph_config, &range_0_packets);
// Run the graph.
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
// Wait until the calculator finishes processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
std::vector<int> input_begin_indices = {0, 2, 4};
std::vector<int> input_end_indices = {1, 3, 5};
ValidateCombinedVectorOutput(range_0_packets, /*expected_elements=*/3,
input_begin_indices, input_end_indices);
// Fully close the graph at the end.
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST_F(SplitTfLiteTensorVectorCalculatorTest, TEST_F(SplitTfLiteTensorVectorCalculatorTest,
ElementOnlyDisablesVectorOutputs) { ElementOnlyDisablesVectorOutputs) {
// Prepare a graph to use the SplitTfLiteTensorVectorCalculator. // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.

View File

@ -0,0 +1,79 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/types.h>
#include <memory>
#include <string>
#include "absl/strings/numbers.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// Calculator that converts a std::string into an integer type, or fails if the
// conversion is not possible.
//
// Example config:
// node {
// calculator: "StringToIntCalculator"
// input_side_packet: "string"
// output_side_packet: "index"
// }
template <typename IntType>
class StringToIntCalculatorTemplate : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Index(0).Set<std::string>();
cc->OutputSidePackets().Index(0).Set<IntType>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
IntType number;
if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get<std::string>(),
&number)) {
return ::mediapipe::InvalidArgumentError(
"The std::string could not be parsed as an integer.");
}
cc->OutputSidePackets().Index(0).Set(MakePacket<IntType>(number));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
using StringToIntCalculator = StringToIntCalculatorTemplate<int>;
REGISTER_CALCULATOR(StringToIntCalculator);
using StringToUintCalculator = StringToIntCalculatorTemplate<uint>;
REGISTER_CALCULATOR(StringToUintCalculator);
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
REGISTER_CALCULATOR(StringToInt32Calculator);
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
REGISTER_CALCULATOR(StringToUint32Calculator);
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
REGISTER_CALCULATOR(StringToInt64Calculator);
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
REGISTER_CALCULATOR(StringToUint64Calculator);
} // namespace mediapipe

View File

@ -104,6 +104,17 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"], deps = ["//mediapipe/framework:calculator_proto"],
) )
proto_library(
name = "unpack_media_sequence_calculator_proto",
srcs = ["unpack_media_sequence_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:audio_decoder_proto",
],
)
proto_library( proto_library(
name = "vector_float_to_tensor_calculator_options_proto", name = "vector_float_to_tensor_calculator_options_proto",
srcs = ["vector_float_to_tensor_calculator_options.proto"], srcs = ["vector_float_to_tensor_calculator_options.proto"],
@ -261,6 +272,17 @@ mediapipe_cc_proto_library(
deps = [":unpack_media_sequence_calculator_proto"], deps = [":unpack_media_sequence_calculator_proto"],
) )
mediapipe_cc_proto_library(
name = "vector_int_to_tensor_calculator_options_cc_proto",
srcs = ["vector_int_to_tensor_calculator_options.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
visibility = ["//visibility:public"],
deps = [":vector_int_to_tensor_calculator_options_proto"],
)
mediapipe_cc_proto_library( mediapipe_cc_proto_library(
name = "vector_float_to_tensor_calculator_options_cc_proto", name = "vector_float_to_tensor_calculator_options_cc_proto",
srcs = ["vector_float_to_tensor_calculator_options.proto"], srcs = ["vector_float_to_tensor_calculator_options.proto"],
@ -621,6 +643,22 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "tfrecord_reader_calculator",
srcs = ["tfrecord_reader_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "tensor_to_vector_float_calculator", name = "tensor_to_vector_float_calculator",
srcs = ["tensor_to_vector_float_calculator.cc"], srcs = ["tensor_to_vector_float_calculator.cc"],
@ -662,6 +700,20 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "vector_int_to_tensor_calculator",
srcs = ["vector_int_to_tensor_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":vector_int_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:framework",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "vector_float_to_tensor_calculator", name = "vector_float_to_tensor_calculator",
srcs = ["vector_float_to_tensor_calculator.cc"], srcs = ["vector_float_to_tensor_calculator.cc"],
@ -676,6 +728,20 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "unpack_yt8m_sequence_example_calculator",
srcs = ["unpack_yt8m_sequence_example_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
cc_test( cc_test(
name = "graph_tensors_packet_generator_test", name = "graph_tensors_packet_generator_test",
srcs = ["graph_tensors_packet_generator_test.cc"], srcs = ["graph_tensors_packet_generator_test.cc"],
@ -980,6 +1046,20 @@ cc_test(
], ],
) )
cc_test(
name = "vector_int_to_tensor_calculator_test",
srcs = ["vector_int_to_tensor_calculator_test.cc"],
deps = [
":vector_int_to_tensor_calculator",
":vector_int_to_tensor_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test( cc_test(
name = "vector_float_to_tensor_calculator_test", name = "vector_float_to_tensor_calculator_test",
srcs = ["vector_float_to_tensor_calculator_test.cc"], srcs = ["vector_float_to_tensor_calculator_test.cc"],

View File

@ -29,6 +29,11 @@
namespace mediapipe { namespace mediapipe {
const char kBufferSize[] = "BUFFER_SIZE";
const char kOverlap[] = "OVERLAP";
const char kTimestampOffset[] = "TIMESTAMP_OFFSET";
const char kCalculatorOptions[] = "CALCULATOR_OPTIONS";
namespace tf = tensorflow; namespace tf = tensorflow;
// Given an input stream of tensors, concatenates the tensors over timesteps. // Given an input stream of tensors, concatenates the tensors over timesteps.
@ -72,6 +77,9 @@ class LappedTensorBufferCalculator : public CalculatorBase {
::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor); ::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor);
int steps_until_output_; int steps_until_output_;
int buffer_size_;
int overlap_;
int timestamp_offset_;
std::unique_ptr<CircularBuffer<Timestamp>> timestamp_buffer_; std::unique_ptr<CircularBuffer<Timestamp>> timestamp_buffer_;
std::unique_ptr<CircularBuffer<tf::Tensor>> buffer_; std::unique_ptr<CircularBuffer<tf::Tensor>> buffer_;
LappedTensorBufferCalculatorOptions options_; LappedTensorBufferCalculatorOptions options_;
@ -87,6 +95,21 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
); );
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one output stream is supported."; << "Only one output stream is supported.";
if (cc->InputSidePackets().HasTag(kBufferSize)) {
cc->InputSidePackets().Tag(kBufferSize).Set<int>();
}
if (cc->InputSidePackets().HasTag(kOverlap)) {
cc->InputSidePackets().Tag(kOverlap).Set<int>();
}
if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
cc->InputSidePackets().Tag(kTimestampOffset).Set<int>();
}
if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
cc->InputSidePackets()
.Tag(kCalculatorOptions)
.Set<LappedTensorBufferCalculatorOptions>();
}
cc->Outputs().Index(0).Set<tf::Tensor>( cc->Outputs().Index(0).Set<tf::Tensor>(
// Output tensorflow::Tensor stream with possibly overlapping steps. // Output tensorflow::Tensor stream with possibly overlapping steps.
); );
@ -95,16 +118,33 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { ::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<LappedTensorBufferCalculatorOptions>(); options_ = cc->Options<LappedTensorBufferCalculatorOptions>();
RET_CHECK_LT(options_.overlap(), options_.buffer_size()); if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
RET_CHECK_GE(options_.timestamp_offset(), 0) options_ = cc->InputSidePackets()
.Tag(kCalculatorOptions)
.Get<LappedTensorBufferCalculatorOptions>();
}
buffer_size_ = options_.buffer_size();
if (cc->InputSidePackets().HasTag(kBufferSize)) {
buffer_size_ = cc->InputSidePackets().Tag(kBufferSize).Get<int>();
}
overlap_ = options_.overlap();
if (cc->InputSidePackets().HasTag(kOverlap)) {
overlap_ = cc->InputSidePackets().Tag(kOverlap).Get<int>();
}
timestamp_offset_ = options_.timestamp_offset();
if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
timestamp_offset_ = cc->InputSidePackets().Tag(kTimestampOffset).Get<int>();
}
RET_CHECK_LT(overlap_, buffer_size_);
RET_CHECK_GE(timestamp_offset_, 0)
<< "Negative timestamp_offset is not allowed."; << "Negative timestamp_offset is not allowed.";
RET_CHECK_LT(options_.timestamp_offset(), options_.buffer_size()) RET_CHECK_LT(timestamp_offset_, buffer_size_)
<< "output_frame_num_offset has to be less than buffer_size."; << "output_frame_num_offset has to be less than buffer_size.";
timestamp_buffer_ = timestamp_buffer_ =
absl::make_unique<CircularBuffer<Timestamp>>(options_.buffer_size()); absl::make_unique<CircularBuffer<Timestamp>>(buffer_size_);
buffer_ = buffer_ = absl::make_unique<CircularBuffer<tf::Tensor>>(buffer_size_);
absl::make_unique<CircularBuffer<tf::Tensor>>(options_.buffer_size()); steps_until_output_ = buffer_size_;
steps_until_output_ = options_.buffer_size();
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -128,11 +168,10 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
concatenated.get()); concatenated.get());
RET_CHECK(concat_status.ok()) << concat_status.ToString(); RET_CHECK(concat_status.ok()) << concat_status.ToString();
cc->Outputs().Index(0).Add( cc->Outputs().Index(0).Add(concatenated.release(),
concatenated.release(), timestamp_buffer_->Get(timestamp_offset_));
timestamp_buffer_->Get(options_.timestamp_offset()));
steps_until_output_ = options_.buffer_size() - options_.overlap(); steps_until_output_ = buffer_size_ - overlap_;
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -0,0 +1,126 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
namespace mediapipe {
const char kTFRecordPath[] = "TFRECORD_PATH";
const char kRecordIndex[] = "RECORD_INDEX";
const char kExampleTag[] = "EXAMPLE";
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
// Reads a tensorflow example/sequence example from a tfrecord file.
// If the "RECORD_INDEX" input side packet is provided, the calculator is going
// to fetch the example/sequence example of the tfrecord file at the target
// record index. Otherwise, the reader always reads the first example/sequence
// example of the tfrecord file.
//
// Example config:
// node {
// calculator: "TFRecordReaderCalculator"
// input_side_packet: "TFRECORD_PATH:tfrecord_path"
// input_side_packet: "RECORD_INDEX:record_index"
// output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// }
class TFRecordReaderCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
};
::mediapipe::Status TFRecordReaderCalculator::GetContract(
CalculatorContract* cc) {
cc->InputSidePackets().Tag(kTFRecordPath).Set<std::string>();
if (cc->InputSidePackets().HasTag(kRecordIndex)) {
cc->InputSidePackets().Tag(kRecordIndex).Set<int>();
}
RET_CHECK(cc->OutputSidePackets().HasTag(kExampleTag) ||
cc->OutputSidePackets().HasTag(kSequenceExampleTag))
<< "TFRecordReaderCalculator must output either Tensorflow example or "
"sequence example.";
if (cc->OutputSidePackets().HasTag(kExampleTag)) {
cc->OutputSidePackets().Tag(kExampleTag).Set<tensorflow::Example>();
} else {
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set<tensorflow::SequenceExample>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) {
std::unique_ptr<tensorflow::RandomAccessFile> file;
auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile(
cc->InputSidePackets().Tag(kTFRecordPath).Get<std::string>(), &file);
RET_CHECK(tf_status.ok())
<< "Failed to open tfrecord file: " << tf_status.error_message();
tensorflow::io::RecordReader reader(file.get(),
tensorflow::io::RecordReaderOptions());
tensorflow::uint64 offset = 0;
std::string example_str;
const int target_idx =
cc->InputSidePackets().HasTag(kRecordIndex)
? cc->InputSidePackets().Tag(kRecordIndex).Get<int>()
: 0;
int current_idx = 0;
while (current_idx <= target_idx) {
tf_status = reader.ReadRecord(&offset, &example_str);
RET_CHECK(tf_status.ok())
<< "Failed to read tfrecord: " << tf_status.error_message();
if (current_idx == target_idx) {
if (cc->OutputSidePackets().HasTag(kExampleTag)) {
tensorflow::Example tf_example;
tf_example.ParseFromString(example_str);
cc->OutputSidePackets()
.Tag(kExampleTag)
.Set(MakePacket<tensorflow::Example>(std::move(tf_example)));
} else {
tensorflow::SequenceExample tf_sequence_example;
tf_sequence_example.ParseFromString(example_str);
cc->OutputSidePackets()
.Tag(kSequenceExampleTag)
.Set(MakePacket<tensorflow::SequenceExample>(
std::move(tf_sequence_example)));
}
}
++current_idx;
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
REGISTER_CALCULATOR(TFRecordReaderCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,192 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iterator>
#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
namespace mediapipe {
namespace {
const char kId[] = "id";
const char kRgb[] = "rgb";
const char kAudio[] = "audio";
const char kDesiredSegmentSize[] = "DESIRED_SEGMENT_SIZE";
const char kYt8mId[] = "YT8M_ID";
const char kYt8mSequenceExample[] = "YT8M_SEQUENCE_EXAMPLE";
const char kQuantizedRgbFeature[] = "QUANTIZED_RGB_FEATURE";
const char kQuantizedAudioFeature[] = "QUANTIZED_AUDIO_FEATURE";
const char kSegmentSize[] = "SEGMENT_SIZE";
const char kLappedTensorBufferCalculatorOptions[] =
"LAPPED_TENSOR_BUFFER_CALCULATOR_OPTIONS";
std::string GetQuantizedFeature(
const tensorflow::SequenceExample& sequence_example, const std::string& key,
int index) {
const auto& bytes_list = sequence_example.feature_lists()
.feature_list()
.at(key)
.feature()
.Get(index)
.bytes_list()
.value();
CHECK_EQ(1, bytes_list.size());
return bytes_list.Get(0);
}
} // namespace
// Unpacks YT8M Sequence Example. Note that the audio feature and rgb feature
// output are quantized. DequantizeByteArrayCalculator can do the dequantization
// for you.
//
// Example config:
// node {
// calculator: "UnpackYt8mSequenceExampleCalculator"
// input_side_packet: "YT8M_SEQUENCE_EXAMPLE:yt8m_sequence_example"
// output_stream: "QUANTIZED_RGB_FEATURE:quantized_rgb_feature"
// output_stream: "QUANTIZED_AUDIO_FEATURE:quantized_audio_feature"
// }
class UnpackYt8mSequenceExampleCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Set<tensorflow::SequenceExample>();
if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
cc->InputSidePackets().Tag(kDesiredSegmentSize).Set<int>();
}
cc->Outputs().Tag(kQuantizedRgbFeature).Set<std::string>();
cc->Outputs().Tag(kQuantizedAudioFeature).Set<std::string>();
if (cc->OutputSidePackets().HasTag(kYt8mId)) {
cc->OutputSidePackets().Tag(kYt8mId).Set<std::string>();
}
if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions)) {
cc->OutputSidePackets()
.Tag(kLappedTensorBufferCalculatorOptions)
.Set<::mediapipe::LappedTensorBufferCalculatorOptions>();
}
if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
cc->OutputSidePackets().Tag(kSegmentSize).Set<int>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
const tensorflow::SequenceExample& sequence_example =
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Get<tensorflow::SequenceExample>();
const std::string& yt8m_id =
sequence_example.context().feature().at(kId).bytes_list().value().Get(
0);
if (cc->OutputSidePackets().HasTag(kYt8mId)) {
cc->OutputSidePackets().Tag(kYt8mId).Set(
MakePacket<std::string>(yt8m_id));
}
int rgb_feature_list_length =
sequence_example.feature_lists().feature_list().at(kRgb).feature_size();
int audio_feature_list_length = sequence_example.feature_lists()
.feature_list()
.at(kAudio)
.feature_size();
if (rgb_feature_list_length != audio_feature_list_length) {
return ::mediapipe::FailedPreconditionError(absl::StrCat(
"Data corruption: the length of audio features and rgb features are "
"not equal. Please check the sequence example that contains yt8m "
"id: ",
yt8m_id));
}
feature_list_length_ = rgb_feature_list_length;
if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions) ||
cc->OutputSidePackets().HasTag(kSegmentSize)) {
// If the desired segment size is specified, take the min of the length of
// the feature list and the desired size to be the output segment size.
int segment_size = feature_list_length_;
if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
int desired_segment_size =
cc->InputSidePackets().Tag(kDesiredSegmentSize).Get<int>();
RET_CHECK(desired_segment_size > 0)
<< "The desired segment size must be greater than zero.";
segment_size = std::min(
feature_list_length_,
cc->InputSidePackets().Tag(kDesiredSegmentSize).Get<int>());
}
if (cc->OutputSidePackets().HasTag(
kLappedTensorBufferCalculatorOptions)) {
auto lapped_tensor_buffer_calculator_options = absl::make_unique<
::mediapipe::LappedTensorBufferCalculatorOptions>();
lapped_tensor_buffer_calculator_options->set_add_batch_dim_to_tensors(
true);
lapped_tensor_buffer_calculator_options->set_buffer_size(segment_size);
lapped_tensor_buffer_calculator_options->set_overlap(segment_size - 1);
lapped_tensor_buffer_calculator_options->set_timestamp_offset(
segment_size - 1);
cc->OutputSidePackets()
.Tag(kLappedTensorBufferCalculatorOptions)
.Set(Adopt(lapped_tensor_buffer_calculator_options.release()));
}
if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
cc->OutputSidePackets()
.Tag(kSegmentSize)
.Set(MakePacket<int>(segment_size));
}
}
LOG(INFO) << "Reading the sequence example that contains yt8m id: "
<< yt8m_id << ". Feature list length: " << feature_list_length_;
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (current_index_ >= feature_list_length_) {
return ::mediapipe::tool::StatusStop();
}
const tensorflow::SequenceExample& sequence_example =
cc->InputSidePackets()
.Tag(kYt8mSequenceExample)
.Get<tensorflow::SequenceExample>();
// Uses microsecond as the unit of time. In the YT8M dataset, each feature
// represents a second.
const Timestamp timestamp = Timestamp(current_index_ * 1000000);
cc->Outputs()
.Tag(kQuantizedRgbFeature)
.AddPacket(
MakePacket<std::string>(
GetQuantizedFeature(sequence_example, kRgb, current_index_))
.At(timestamp));
cc->Outputs()
.Tag(kQuantizedAudioFeature)
.AddPacket(
MakePacket<std::string>(
GetQuantizedFeature(sequence_example, kAudio, current_index_))
.At(timestamp));
++current_index_;
return ::mediapipe::OkStatus();
}
private:
int current_index_ = 0;
int feature_list_length_ = 0;
};
REGISTER_CALCULATOR(UnpackYt8mSequenceExampleCalculator);
} // namespace mediapipe

View File

@ -23,10 +23,12 @@
namespace mediapipe { namespace mediapipe {
namespace tf = ::tensorflow; namespace {
auto& INPUT_1D = VectorFloatToTensorCalculatorOptions::INPUT_1D; auto& INPUT_1D = VectorFloatToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorFloatToTensorCalculatorOptions::INPUT_2D; auto& INPUT_2D = VectorFloatToTensorCalculatorOptions::INPUT_2D;
} // namespace
namespace tf = ::tensorflow;
// The calculator expects one input (a packet containing a vector<float> or // The calculator expects one input (a packet containing a vector<float> or
// vector<vector<float>>) and generates one output (a packet containing a // vector<vector<float>>) and generates one output (a packet containing a

View File

@ -0,0 +1,203 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D)
// tf::Tensor.
#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
const char kVectorInt[] = "VECTOR_INT";
const char kSingleInt[] = "SINGLE_INT";
const char kTensorOut[] = "TENSOR_OUT";
namespace {
auto& INPUT_1D = VectorIntToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorIntToTensorCalculatorOptions::INPUT_2D;
} // namespace
namespace tf = ::tensorflow;
template <typename TensorType>
void AssignMatrixValue(int r, int c, int value, tf::Tensor* output_tensor) {
output_tensor->tensor<TensorType, 2>()(r, c) = value;
}
// The calculator expects one input (a packet containing a single int or
// vector<int> or vector<vector<int>>) and generates one output (a packet
// containing a tf::Tensor containing the same data). The output tensor will be
// either 1D or 2D with dimensions corresponding to the input vector int. It
// will hold DT_INT32 or DT_UINT8 or DT_INT64 values.
//
// Example config:
// node {
// calculator: "VectorIntToTensorCalculator"
// input_stream: "SINGLE_INT:segment_size_int_stream"
// output_stream: "TENSOR_OUT:segment_size_tensor"
// }
//
// or
//
// node {
// calculator: "VectorIntToTensorCalculator"
// input_stream: "VECTOR_INT:vector_int_features"
// output_stream: "TENSOR_OUT:tensor_features"
// }
class VectorIntToTensorCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
VectorIntToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(VectorIntToTensorCalculator);
::mediapipe::Status VectorIntToTensorCalculator::GetContract(
CalculatorContract* cc) {
const auto& options = cc->Options<VectorIntToTensorCalculatorOptions>();
// Start with only one input packet.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
if (options.input_size() == INPUT_2D) {
cc->Inputs().Tag(kVectorInt).Set<std::vector<std::vector<int>>>();
} else if (options.input_size() == INPUT_1D) {
if (cc->Inputs().HasTag(kSingleInt)) {
cc->Inputs().Tag(kSingleInt).Set<int>();
} else {
cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>();
}
} else {
LOG(FATAL) << "input size not supported";
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
cc->Outputs().Tag(kTensorOut).Set<tf::Tensor>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<VectorIntToTensorCalculatorOptions>();
RET_CHECK(options_.tensor_data_type() == tf::DT_UINT8 ||
options_.tensor_data_type() == tf::DT_INT32 ||
options_.tensor_data_type() == tf::DT_INT64)
<< "Output tensor data type is not supported.";
return ::mediapipe::OkStatus();
}
::mediapipe::Status VectorIntToTensorCalculator::Process(
CalculatorContext* cc) {
tf::TensorShape tensor_shape;
if (options_.input_size() == INPUT_2D) {
const std::vector<std::vector<int>>& input =
cc->Inputs()
.Tag(kVectorInt)
.Value()
.Get<std::vector<std::vector<int>>>();
const int32 rows = input.size();
CHECK_GE(rows, 1);
const int32 cols = input[0].size();
CHECK_GE(cols, 1);
for (int i = 1; i < rows; ++i) {
CHECK_EQ(input[i].size(), cols);
}
if (options_.transpose()) {
tensor_shape = tf::TensorShape({cols, rows});
} else {
tensor_shape = tf::TensorShape({rows, cols});
}
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
tensor_shape);
if (options_.transpose()) {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
AssignMatrixValue<tf::int64>(c, r, input[r][c], output.get());
break;
case tf::DT_UINT8:
AssignMatrixValue<uint8>(c, r, input[r][c], output.get());
break;
case tf::DT_INT32:
AssignMatrixValue<int>(c, r, input[r][c], output.get());
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
}
} else {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
AssignMatrixValue<tf::int64>(r, c, input[r][c], output.get());
break;
case tf::DT_UINT8:
AssignMatrixValue<uint8>(r, c, input[r][c], output.get());
break;
case tf::DT_INT32:
AssignMatrixValue<int>(r, c, input[r][c], output.get());
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
}
}
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else if (options_.input_size() == INPUT_1D) {
std::vector<int> input;
if (cc->Inputs().HasTag(kSingleInt)) {
input.push_back(cc->Inputs().Tag(kSingleInt).Get<int>());
} else {
input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>();
}
CHECK_GE(input.size(), 1);
const int32 length = input.size();
tensor_shape = tf::TensorShape({length});
auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
tensor_shape);
for (int i = 0; i < length; ++i) {
switch (options_.tensor_data_type()) {
case tf::DT_INT64:
output->tensor<tf::int64, 1>()(i) = input.at(i);
break;
case tf::DT_UINT8:
output->tensor<uint8, 1>()(i) = input.at(i);
break;
case tf::DT_INT32:
output->tensor<int, 1>()(i) = input.at(i);
break;
default:
LOG(FATAL) << "tensor data type is not supported.";
}
}
cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
} else {
LOG(FATAL) << "input size not supported";
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,43 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "tensorflow/core/framework/types.proto";
message VectorIntToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional VectorIntToTensorCalculatorOptions ext = 275364184;
}
enum InputSize {
UNKNOWN = 0;
INPUT_1D = 1;
INPUT_2D = 2;
}
// If input_size is INPUT_2D, unpack a vector<vector<int>> to a
// 2d tensor (matrix). If INPUT_1D, convert a single int or vector<int>
// into a 1d tensor (vector).
optional InputSize input_size = 1 [default = INPUT_1D];
// If true, the output tensor is transposed.
// Otherwise, the output tensor is not transposed.
// It will be ignored if tensor_is_2d is INPUT_1D.
optional bool transpose = 2 [default = false];
optional tensorflow.DataType tensor_data_type = 3 [default = DT_INT32];
}

View File

@ -0,0 +1,202 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class VectorIntToTensorCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner(
const VectorIntToTensorCalculatorOptions::InputSize input_size,
const tensorflow::DataType tensor_data_type, const bool transpose,
const bool single_value) {
CalculatorGraphConfig::Node config;
config.set_calculator("VectorIntToTensorCalculator");
if (single_value) {
config.add_input_stream("SINGLE_INT:input_int");
} else {
config.add_input_stream("VECTOR_INT:input_int");
}
config.add_output_stream("TENSOR_OUT:output_tensor");
auto options = config.mutable_options()->MutableExtension(
VectorIntToTensorCalculatorOptions::ext);
options->set_input_size(input_size);
options->set_transpose(transpose);
options->set_tensor_data_type(tensor_data_type);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}
void TestConvertFromVectoVectorInt(const bool transpose) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_2D,
tensorflow::DT_INT32, transpose, false);
auto input = ::absl::make_unique<std::vector<std::vector<int>>>(
2, std::vector<int>(2));
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
input->at(i).at(j) = i * 2 + j;
}
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(2, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto matrix = output_tensor.matrix<int>();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
if (!transpose) {
EXPECT_EQ(i * 2 + j, matrix(i, j));
} else {
EXPECT_EQ(j * 2 + i, matrix(i, j));
}
}
}
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT32, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto vec = output_tensor.vec<int32>();
EXPECT_EQ(1, vec(0));
}
TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT32, false, false);
auto input = ::absl::make_unique<std::vector<int>>(5);
for (int i = 0; i < 5; ++i) {
input->at(i) = i;
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
const auto vec = output_tensor.vec<int32>();
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, vec(i));
}
}
TEST_F(VectorIntToTensorCalculatorTest, TestTwoDims) {
for (bool transpose : {false, true}) {
TestConvertFromVectoVectorInt(transpose);
}
}
TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_INT64, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
.packets.push_back(MakePacket<int>(2 ^ 31).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_INT64, output_tensor.dtype());
const auto vec = output_tensor.vec<tf::int64>();
EXPECT_EQ(2 ^ 31, vec(0));
}
TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
tensorflow::DT_UINT8, false, false);
auto input = ::absl::make_unique<std::vector<int>>(5);
for (int i = 0; i < 5; ++i) {
input->at(i) = i;
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
EXPECT_EQ(1, output_tensor.dims());
EXPECT_EQ(tf::DT_UINT8, output_tensor.dtype());
const auto vec = output_tensor.vec<uint8>();
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, vec(i));
}
}
} // namespace
} // namespace mediapipe

View File

@ -25,7 +25,8 @@
#include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
@ -45,7 +46,8 @@
#include "tensorflow/lite/delegates/gpu/metal_delegate.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS #endif // iOS
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor; typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor; typedef id<MTLBuffer> GpuTensor;
@ -67,7 +69,8 @@ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
namespace mediapipe { namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader; using ::tflite::gpu::gl::GlShader;
@ -146,7 +149,8 @@ class TfLiteConverterCalculator : public CalculatorBase {
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr; std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_out_; std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -181,7 +185,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>(); if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>(); if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag("IMAGE_GPU")) { if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true; use_gpu |= true;
@ -190,7 +194,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Outputs().HasTag("TENSORS")) if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag("TENSORS_GPU")) { if (cc->Outputs().HasTag("TENSORS_GPU")) {
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true; use_gpu |= true;
@ -198,7 +202,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
if (use_gpu) { if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@ -218,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Inputs().HasTag("IMAGE_GPU") || if (cc->Inputs().HasTag("IMAGE_GPU") ||
cc->Outputs().HasTag("IMAGE_OUT_GPU")) { cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
use_gpu_ = true; use_gpu_ = true;
#else #else
RET_CHECK_FAIL() << "GPU processing not enabled."; RET_CHECK_FAIL() << "GPU processing not enabled.";
@ -231,7 +236,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs().HasTag("TENSORS_GPU")); cc->Outputs().HasTag("TENSORS_GPU"));
// Cannot use quantization. // Cannot use quantization.
use_quantized_tensors_ = false; use_quantized_tensors_ = false;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
@ -264,7 +270,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
} }
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif #endif
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -383,7 +390,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU( ::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// GpuBuffer to tflite::gpu::GlBuffer conversion. // GpuBuffer to tflite::gpu::GlBuffer conversion.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>(); const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
@ -468,7 +476,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
} }
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
// Get input image sizes. // Get input image sizes.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>(); const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
mediapipe::ImageFormat::Format format = mediapipe::ImageFormat::Format format =
@ -485,7 +493,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK_FAIL() << "Num input channels is less than desired output."; RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status { [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
// Device memory. // Device memory.

View File

@ -27,7 +27,8 @@
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/shape.h"
@ -52,7 +53,8 @@
namespace { namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor; typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor; typedef id<MTLBuffer> GpuTensor;
@ -68,13 +70,14 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
// * Aux // * Aux
namespace mediapipe { namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer; using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlBuffer;
#endif #endif
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
struct GPUData { struct GPUData {
int elements = 1; int elements = 1;
GpuTensor buffer; GpuTensor buffer;
@ -147,7 +150,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
std::unique_ptr<tflite::FlatBufferModel> model_; std::unique_ptr<tflite::FlatBufferModel> model_;
TfLiteDelegate* delegate_ = nullptr; TfLiteDelegate* delegate_ = nullptr;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_in_; std::unique_ptr<GPUData> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_; std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
@ -179,7 +183,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
if (cc->Inputs().HasTag("TENSORS")) if (cc->Inputs().HasTag("TENSORS"))
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true; use_gpu |= true;
@ -188,7 +192,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
if (cc->Outputs().HasTag("TENSORS")) if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Outputs().HasTag("TENSORS_GPU")) { if (cc->Outputs().HasTag("TENSORS_GPU")) {
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true; use_gpu |= true;
@ -206,7 +210,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
use_gpu |= options.use_gpu(); use_gpu |= options.use_gpu();
if (use_gpu) { if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@ -225,7 +230,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
gpu_input_ = true; gpu_input_ = true;
gpu_inference_ = true; // Inference must be on GPU also. gpu_inference_ = true; // Inference must be on GPU also.
#else #else
@ -235,7 +240,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
} }
if (cc->Outputs().HasTag("TENSORS_GPU")) { if (cc->Outputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
gpu_output_ = true; gpu_output_ = true;
RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU")) RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
<< "GPU output must also have GPU Input."; << "GPU output must also have GPU Input.";
@ -248,13 +253,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(LoadModel(cc));
if (gpu_inference_) { if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_); RET_CHECK(gpu_helper_);
#endif #endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); })); [this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); }));
#else #else
@ -262,6 +269,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#endif #endif
} }
#if defined(__EMSCRIPTEN__)
MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif // __EMSCRIPTEN__
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -269,7 +280,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 1. Receive pre-processed tensor inputs. // 1. Receive pre-processed tensor inputs.
if (gpu_input_) { if (gpu_input_) {
// Read GPU input into SSBO. // Read GPU input into SSBO.
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK_EQ(input_tensors.size(), 1);
@ -315,7 +327,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 2. Run inference. // 2. Run inference.
if (gpu_inference_) { if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
@ -330,7 +343,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 3. Output processed tensors. // 3. Output processed tensors.
if (gpu_output_) { if (gpu_output_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// Output result tensors (GPU). // Output result tensors (GPU).
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>(); auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
@ -392,7 +406,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
if (delegate_) { if (delegate_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
TfLiteGpuDelegateDelete(delegate_); TfLiteGpuDelegateDelete(delegate_);
gpu_data_in_.reset(); gpu_data_in_.reset();
@ -456,6 +471,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
RET_CHECK(interpreter_); RET_CHECK(interpreter_);
#if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1);
#endif // __EMSCRIPTEN__
if (gpu_output_) { if (gpu_output_) {
use_quantized_tensors_ = false; use_quantized_tensors_ = false;
} else { } else {
@ -471,7 +490,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( ::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// Configure and create the delegate. // Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1; options.compile_options.precision_loss_allowed = 1;

View File

@ -24,7 +24,8 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if defined(__EMSCRIPTEN__) || defined(__ANDROID__) || \
(defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#else #else
@ -66,8 +67,8 @@ class TfLiteTensorsToClassificationCalculator : public CalculatorBase {
::mediapipe::Status Close(CalculatorContext* cc) override; ::mediapipe::Status Close(CalculatorContext* cc) override;
private: private:
::mediapipe::TfLiteTensorsToClassificationCalculatorOptions options_;
int top_k_ = 0; int top_k_ = 0;
double min_score_threshold_ = 0;
std::unordered_map<int, std::string> label_map_; std::unordered_map<int, std::string> label_map_;
bool label_map_loaded_ = false; bool label_map_loaded_ = false;
}; };
@ -93,15 +94,14 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
CalculatorContext* cc) { CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
auto options = cc->Options< options_ = cc->Options<
::mediapipe::TfLiteTensorsToClassificationCalculatorOptions>(); ::mediapipe::TfLiteTensorsToClassificationCalculatorOptions>();
top_k_ = options.top_k(); top_k_ = options_.top_k();
min_score_threshold_ = options.min_score_threshold(); if (options_.has_label_map_path()) {
if (options.has_label_map_path()) {
std::string string_path; std::string string_path;
ASSIGN_OR_RETURN(string_path, ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options.label_map_path())); PathToResourceAsFile(options_.label_map_path()));
std::string label_map_string; std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
@ -125,9 +125,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK_EQ(input_tensors.size(), 1);
const TfLiteTensor* raw_score_tensor = &input_tensors[0]; const TfLiteTensor* raw_score_tensor = &input_tensors[0];
RET_CHECK_EQ(raw_score_tensor->dims->size, 2); int num_classes = 1;
RET_CHECK_EQ(raw_score_tensor->dims->data[0], 1); for (int i = 0; i < raw_score_tensor->dims->size; ++i) {
int num_classes = raw_score_tensor->dims->data[1]; num_classes *= raw_score_tensor->dims->data[i];
}
if (label_map_loaded_) { if (label_map_loaded_) {
RET_CHECK_EQ(num_classes, label_map_.size()); RET_CHECK_EQ(num_classes, label_map_.size());
} }
@ -135,7 +137,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
auto classification_list = absl::make_unique<ClassificationList>(); auto classification_list = absl::make_unique<ClassificationList>();
for (int i = 0; i < num_classes; ++i) { for (int i = 0; i < num_classes; ++i) {
if (raw_scores[i] < min_score_threshold_) { if (options_.has_min_score_threshold() &&
raw_scores[i] < options_.min_score_threshold()) {
continue; continue;
} }
Classification* classification = classification_list->add_classification(); Classification* classification = classification_list->add_classification();
@ -148,6 +151,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
// Note that partial_sort will raise error when top_k_ > // Note that partial_sort will raise error when top_k_ >
// classification_list->classification_size(). // classification_list->classification_size().
CHECK_GE(classification_list->classification_size(), top_k_);
auto raw_classification_list = classification_list->mutable_classification(); auto raw_classification_list = classification_list->mutable_classification();
if (top_k_ > 0 && classification_list->classification_size() >= top_k_) { if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
std::partial_sort(raw_classification_list->begin(), std::partial_sort(raw_classification_list->begin(),

View File

@ -27,7 +27,8 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
@ -55,12 +56,14 @@ constexpr int kNumCoordsPerBox = 4;
namespace mediapipe { namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlShader; using ::tflite::gpu::gl::GlShader;
#endif #endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor; typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
typedef ::tflite::gpu::gl::GlProgram GpuProgram; typedef ::tflite::gpu::gl::GlProgram GpuProgram;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -70,7 +73,7 @@ typedef id<MTLComputePipelineState> GpuProgram;
namespace { namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
struct GPUData { struct GPUData {
GpuProgram decode_program; GpuProgram decode_program;
GpuProgram score_program; GpuProgram score_program;
@ -169,18 +172,21 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
const int* detection_classes, std::vector<Detection>* output_detections); const int* detection_classes, std::vector<Detection>* output_detections);
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
float box_xmax, float score, int class_id, float box_xmax, float score, int class_id,
bool flip_vertically); int detection_id, bool flip_vertically);
int num_classes_ = 0; int num_classes_ = 0;
int num_boxes_ = 0; int num_boxes_ = 0;
int num_coords_ = 0; int num_coords_ = 0;
// Unique detection ID per new detection.
static int next_detection_id_;
std::set<int> ignore_classes_; std::set<int> ignore_classes_;
::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_; ::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_;
std::vector<Anchor> anchors_; std::vector<Anchor> anchors_;
bool side_packet_anchors_{}; bool side_packet_anchors_{};
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_; std::unique_ptr<GPUData> gpu_data_;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -193,6 +199,10 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
}; };
REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
// Initialization of non-const static member should happen outside class
// definition.
int TfLiteTensorsToDetectionsCalculator::next_detection_id_ = 0;
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Inputs().GetTags().empty());
@ -204,7 +214,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
} }
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true; use_gpu |= true;
@ -222,7 +232,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
if (use_gpu) { if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
@ -238,7 +249,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
gpu_input_ = true; gpu_input_ = true;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
@ -400,7 +412,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) { CalculatorContext* cc, std::vector<Detection>* output_detections) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_GE(input_tensors.size(), 2); RET_CHECK_GE(input_tensors.size(), 2);
@ -562,7 +575,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); }); gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_data_.reset(); gpu_data_.reset();
@ -672,7 +686,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
Detection detection = ConvertToDetection( Detection detection = ConvertToDetection(
detection_boxes[box_offset + 0], detection_boxes[box_offset + 1], detection_boxes[box_offset + 0], detection_boxes[box_offset + 1],
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
detection_scores[i], detection_classes[i], options_.flip_vertically()); detection_scores[i], detection_classes[i], next_detection_id_,
options_.flip_vertically());
// Increment to get next unique detection ID.
++next_detection_id_;
// Add keypoints. // Add keypoints.
if (options_.num_keypoints() > 0) { if (options_.num_keypoints() > 0) {
auto* location_data = detection.mutable_location_data(); auto* location_data = detection.mutable_location_data();
@ -695,10 +712,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
int class_id, bool flip_vertically) { int class_id, int detection_id, bool flip_vertically) {
Detection detection; Detection detection;
detection.add_score(score); detection.add_score(score);
detection.add_label_id(class_id); detection.add_label_id(class_id);
detection.set_detection_id(detection_id);
LocationData* location_data = detection.mutable_location_data(); LocationData* location_data = detection.mutable_location_data();
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
@ -715,7 +733,8 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
-> ::mediapipe::Status { -> ::mediapipe::Status {
gpu_data_ = absl::make_unique<GPUData>(); gpu_data_ = absl::make_unique<GPUData>();

View File

@ -21,7 +21,8 @@
namespace mediapipe { namespace mediapipe {
// A calculator for converting TFLite tensors from regression models into // A calculator for converting TFLite tensors from regression models into
// landmarks. // landmarks. Note that if the landmarks in the tensor has more than 3
// dimensions, only the first 3 dimensions will be converted to x,y,z.
// //
// Input: // Input:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first // TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first
@ -122,9 +123,6 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
num_values *= raw_tensor->dims->data[i]; num_values *= raw_tensor->dims->data[i];
} }
const int num_dimensions = num_values / num_landmarks_; const int num_dimensions = num_values / num_landmarks_;
// Landmarks must have less than 3 dimensions. Otherwise please consider
// using matrix.
CHECK_LE(num_dimensions, 3);
CHECK_GT(num_dimensions, 0); CHECK_GT(num_dimensions, 0);
const float* raw_landmarks = raw_tensor->data.f; const float* raw_landmarks = raw_tensor->data.f;

View File

@ -28,7 +28,8 @@
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
@ -53,7 +54,8 @@ float Clamp(float val, float min, float max) {
namespace mediapipe { namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer; using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture; using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
@ -129,7 +131,8 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase {
int tensor_channels_ = 0; int tensor_channels_ = 0;
bool use_gpu_ = false; bool use_gpu_ = false;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GlProgram> mask_program_with_prev_; std::unique_ptr<GlProgram> mask_program_with_prev_;
std::unique_ptr<GlProgram> mask_program_no_prev_; std::unique_ptr<GlProgram> mask_program_no_prev_;
@ -159,7 +162,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
} }
// Inputs GPU. // Inputs GPU.
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>(); cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
use_gpu |= true; use_gpu |= true;
@ -178,7 +182,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Outputs().HasTag("MASK")) { if (cc->Outputs().HasTag("MASK")) {
cc->Outputs().Tag("MASK").Set<ImageFrame>(); cc->Outputs().Tag("MASK").Set<ImageFrame>();
} }
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
if (cc->Outputs().HasTag("MASK_GPU")) { if (cc->Outputs().HasTag("MASK_GPU")) {
cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>(); cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true; use_gpu |= true;
@ -186,7 +191,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
if (use_gpu) { if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} }
@ -199,7 +205,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
use_gpu_ = true; use_gpu_ = true;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} }
@ -207,7 +214,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
if (use_gpu_) { if (use_gpu_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
MP_RETURN_IF_ERROR(InitGpu(cc)); MP_RETURN_IF_ERROR(InitGpu(cc));
@ -224,7 +232,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process(
CalculatorContext* cc) { CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
MP_RETURN_IF_ERROR(ProcessGpu(cc)); MP_RETURN_IF_ERROR(ProcessGpu(cc));
@ -240,7 +249,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (upsample_program_) glDeleteProgram(upsample_program_); if (upsample_program_) glDeleteProgram(upsample_program_);
upsample_program_ = 0; upsample_program_ = 0;
@ -367,7 +377,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) { if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
// Get input streams. // Get input streams.
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
@ -453,7 +464,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
} }
void TfLiteTensorsToSegmentationCalculator::GlRender() { void TfLiteTensorsToSegmentationCalculator::GlRender() {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -525,7 +537,8 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() {
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
!defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
-> ::mediapipe::Status { -> ::mediapipe::Status {
// A shader to process a segmentation tensor into an output mask, // A shader to process a segmentation tensor into an output mask,

View File

@ -14,7 +14,7 @@
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"]) package(default_visibility = ["//visibility:public"])
exports_files(["LICENSE"]) exports_files(["LICENSE"])
@ -234,6 +234,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"//mediapipe/util:annotation_renderer", "//mediapipe/util:annotation_renderer",
"//mediapipe/util:render_data_cc_proto",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -360,6 +361,16 @@ mediapipe_cc_proto_library(
deps = [":landmark_projection_calculator_proto"], deps = [":landmark_projection_calculator_proto"],
) )
mediapipe_cc_proto_library(
name = "landmarks_to_floats_calculator_cc_proto",
srcs = ["landmarks_to_floats_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":landmarks_to_floats_calculator_proto"],
)
mediapipe_cc_proto_library( mediapipe_cc_proto_library(
name = "rect_transformation_calculator_cc_proto", name = "rect_transformation_calculator_cc_proto",
srcs = ["rect_transformation_calculator.proto"], srcs = ["rect_transformation_calculator.proto"],
@ -372,7 +383,12 @@ mediapipe_cc_proto_library(
cc_library( cc_library(
name = "detections_to_rects_calculator", name = "detections_to_rects_calculator",
srcs = ["detections_to_rects_calculator.cc"], srcs = [
"detections_to_rects_calculator.cc",
],
hdrs = [
"detections_to_rects_calculator.h",
],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":detections_to_rects_calculator_cc_proto", ":detections_to_rects_calculator_cc_proto",
@ -454,6 +470,17 @@ proto_library(
], ],
) )
proto_library(
name = "labels_to_render_data_calculator_proto",
srcs = ["labels_to_render_data_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
"//mediapipe/util:render_data_proto",
],
)
proto_library( proto_library(
name = "thresholding_calculator_proto", name = "thresholding_calculator_proto",
srcs = ["thresholding_calculator.proto"], srcs = ["thresholding_calculator.proto"],
@ -483,6 +510,15 @@ proto_library(
], ],
) )
proto_library(
name = "landmarks_to_floats_calculator_proto",
srcs = ["landmarks_to_floats_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
proto_library( proto_library(
name = "rect_transformation_calculator_proto", name = "rect_transformation_calculator_proto",
srcs = ["rect_transformation_calculator.proto"], srcs = ["rect_transformation_calculator.proto"],
@ -577,6 +613,26 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "labels_to_render_data_calculator",
srcs = ["labels_to_render_data_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":labels_to_render_data_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "rect_to_render_data_calculator", name = "rect_to_render_data_calculator",
srcs = ["rect_to_render_data_calculator.cc"], srcs = ["rect_to_render_data_calculator.cc"],
@ -658,6 +714,22 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "landmarks_to_floats_calculator",
srcs = ["landmarks_to_floats_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":landmarks_to_floats_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@eigen_archive//:eigen",
],
alwayslink = 1,
)
cc_test( cc_test(
name = "detection_letterbox_removal_calculator_test", name = "detection_letterbox_removal_calculator_test",
srcs = ["detection_letterbox_removal_calculator_test.cc"], srcs = ["detection_letterbox_removal_calculator_test.cc"],
@ -714,6 +786,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":top_k_scores_calculator_cc_proto", ":top_k_scores_calculator_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
@ -750,3 +823,27 @@ cc_test(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
], ],
) )
mediapipe_cc_proto_library(
name = "labels_to_render_data_calculator_cc_proto",
srcs = ["labels_to_render_data_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/util:color_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":labels_to_render_data_calculator_proto"],
)
cc_library(
name = "local_file_contents_calculator",
srcs = ["local_file_contents_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)

View File

@ -26,6 +26,7 @@
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
#include "mediapipe/util/annotation_renderer.h" #include "mediapipe/util/annotation_renderer.h"
#include "mediapipe/util/color.pb.h" #include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
@ -41,6 +42,8 @@ namespace {
constexpr char kInputFrameTag[] = "INPUT_FRAME"; constexpr char kInputFrameTag[] = "INPUT_FRAME";
constexpr char kOutputFrameTag[] = "OUTPUT_FRAME"; constexpr char kOutputFrameTag[] = "OUTPUT_FRAME";
constexpr char kInputVectorTag[] = "VECTOR";
constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU"; constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU";
constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU"; constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU";
@ -65,6 +68,9 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// 2. RenderData proto on variable number of input streams. All the RenderData // 2. RenderData proto on variable number of input streams. All the RenderData
// at a particular timestamp is drawn on the image in the order of their // at a particular timestamp is drawn on the image in the order of their
// input streams. No tags required. // input streams. No tags required.
// 3. std::vector<RenderData> on variable number of input streams. RenderData
// objects at a particular timestamp are drawn on the image in order of the
// input vector items. These input streams are tagged with "VECTOR".
// //
// Output: // Output:
// 1. OUTPUT_FRAME or OUTPUT_FRAME_GPU: A rendered ImageFrame (or GpuBuffer). // 1. OUTPUT_FRAME or OUTPUT_FRAME_GPU: A rendered ImageFrame (or GpuBuffer).
@ -85,6 +91,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// input_stream: "render_data_1" // input_stream: "render_data_1"
// input_stream: "render_data_2" // input_stream: "render_data_2"
// input_stream: "render_data_3" // input_stream: "render_data_3"
// input_stream: "VECTOR:0:render_data_vec_0"
// input_stream: "VECTOR:1:render_data_vec_1"
// output_stream: "OUTPUT_FRAME:decorated_frames" // output_stream: "OUTPUT_FRAME:decorated_frames"
// options { // options {
// [mediapipe.AnnotationOverlayCalculatorOptions.ext] { // [mediapipe.AnnotationOverlayCalculatorOptions.ext] {
@ -99,6 +107,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// input_stream: "render_data_1" // input_stream: "render_data_1"
// input_stream: "render_data_2" // input_stream: "render_data_2"
// input_stream: "render_data_3" // input_stream: "render_data_3"
// input_stream: "VECTOR:0:render_data_vec_0"
// input_stream: "VECTOR:1:render_data_vec_1"
// output_stream: "OUTPUT_FRAME_GPU:decorated_frames" // output_stream: "OUTPUT_FRAME_GPU:decorated_frames"
// options { // options {
// [mediapipe.AnnotationOverlayCalculatorOptions.ext] { // [mediapipe.AnnotationOverlayCalculatorOptions.ext] {
@ -188,8 +198,16 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
} }
// Data streams to render. // Data streams to render.
for (int i = 0; i < num_render_streams; ++i) { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
cc->Inputs().Index(i).Set<RenderData>(); ++id) {
auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
std::string tag = tag_and_index.first;
if (tag == kInputVectorTag) {
cc->Inputs().Get(id).Set<std::vector<RenderData>>();
} else if (tag.empty()) {
// Empty tag defaults to accepting a single object of RenderData type.
cc->Inputs().Get(id).Set<RenderData>();
}
} }
// Rendered image. // Rendered image.
@ -285,12 +303,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
renderer_->AdoptImage(image_mat.get()); renderer_->AdoptImage(image_mat.get());
// Render streams onto render target. // Render streams onto render target.
for (int i = 0; i < num_render_streams_; ++i) { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
if (cc->Inputs().Index(i).IsEmpty()) { ++id) {
auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
std::string tag = tag_and_index.first;
if (!tag.empty() && tag != kInputVectorTag) {
continue; continue;
} }
const RenderData& render_data = cc->Inputs().Index(i).Get<RenderData>(); if (cc->Inputs().Get(id).IsEmpty()) {
continue;
}
if (tag.empty()) {
// Empty tag defaults to accepting a single object of RenderData type.
const RenderData& render_data = cc->Inputs().Get(id).Get<RenderData>();
renderer_->RenderDataOnImage(render_data); renderer_->RenderDataOnImage(render_data);
} else {
RET_CHECK_EQ(kInputVectorTag, tag);
const std::vector<RenderData>& render_data_vec =
cc->Inputs().Get(id).Get<std::vector<RenderData>>();
for (const RenderData& render_data : render_data_vec) {
renderer_->RenderDataOnImage(render_data);
}
}
} }
if (use_gpu_) { if (use_gpu_) {

View File

@ -19,8 +19,8 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \ #if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
(defined(__APPLE__) && !TARGET_OS_OSX) defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#else #else

View File

@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mediapipe/calculators/util/detections_to_rects_calculator.h"
#include <cmath> #include <cmath>
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" #include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
@ -24,8 +26,6 @@
namespace mediapipe { namespace mediapipe {
using mediapipe::DetectionsToRectsCalculatorOptions;
namespace { namespace {
constexpr char kDetectionTag[] = "DETECTION"; constexpr char kDetectionTag[] = "DETECTION";
@ -36,7 +36,10 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRectsTag[] = "RECTS"; constexpr char kRectsTag[] = "RECTS";
constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr char kNormRectsTag[] = "NORM_RECTS";
::mediapipe::Status DetectionToRect(const Detection& detection, Rect* rect) { } // namespace
::mediapipe::Status DetectionsToRectsCalculator::DetectionToRect(
const Detection& detection, Rect* rect) {
const LocationData location_data = detection.location_data(); const LocationData location_data = detection.location_data();
RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX) RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX)
<< "Only Detection with formats of BOUNDING_BOX can be converted to Rect"; << "Only Detection with formats of BOUNDING_BOX can be converted to Rect";
@ -48,8 +51,8 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status DetectionToNormalizedRect(const Detection& detection, ::mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect(
NormalizedRect* rect) { const Detection& detection, NormalizedRect* rect) {
const LocationData location_data = detection.location_data(); const LocationData location_data = detection.location_data();
RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX) RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX)
<< "Only Detection with formats of RELATIVE_BOUNDING_BOX can be " << "Only Detection with formats of RELATIVE_BOUNDING_BOX can be "
@ -63,79 +66,6 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
// Wraps around an angle in radians to within -M_PI and M_PI.
inline float NormalizeRadians(float angle) {
return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}
} // namespace
// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output can
// be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to
// be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
// One of the following:
// DETECTION: A Detection proto.
// DETECTIONS: An std::vector<Detection>.
//
// IMAGE_SIZE (optional): A std::pair<int, int> represention image width and
// height. This is required only when rotation needs to be computed (see
// calculator options).
//
// Output:
// One of the following:
// RECT: A Rect proto.
// NORM_RECT: A NormalizedRect proto.
// RECTS: An std::vector<Rect>.
// NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
// calculator: "DetectionsToRectsCalculator"
// input_stream: "DETECTIONS:detections"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "NORM_RECT:rect"
// options: {
// [mediapipe.DetectionsToRectCalculatorOptions.ext] {
// rotation_vector_start_keypoint_index: 0
// rotation_vector_end_keypoint_index: 2
// rotation_vector_target_angle_degrees: 90
// output_zero_rect_for_empty_detections: true
// }
// }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
float ComputeRotation(const Detection& detection,
const std::pair<int, int> image_size);
DetectionsToRectsCalculatorOptions options_;
int start_keypoint_index_;
int end_keypoint_index_;
float target_angle_; // In radians.
bool rotate_;
bool output_zero_rect_for_empty_detections_;
};
REGISTER_CALCULATOR(DetectionsToRectsCalculator);
::mediapipe::Status DetectionsToRectsCalculator::GetContract( ::mediapipe::Status DetectionsToRectsCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^ RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^
@ -232,6 +162,13 @@ REGISTER_CALCULATOR(DetectionsToRectsCalculator);
.Tag(kNormRectTag) .Tag(kNormRectTag)
.AddPacket(MakePacket<NormalizedRect>().At(cc->InputTimestamp())); .AddPacket(MakePacket<NormalizedRect>().At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag(kNormRectsTag)) {
auto rect_vector = absl::make_unique<std::vector<NormalizedRect>>();
rect_vector->emplace_back(NormalizedRect());
cc->Outputs()
.Tag(kNormRectsTag)
.Add(rect_vector.release(), cc->InputTimestamp());
}
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -312,4 +249,6 @@ float DetectionsToRectsCalculator::ComputeRotation(
return NormalizeRadians(rotation); return NormalizeRadians(rotation);
} }
REGISTER_CALCULATOR(DetectionsToRectsCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,105 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
#include <cmath>
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output can
// be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to
// be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
// One of the following:
// DETECTION: A Detection proto.
// DETECTIONS: An std::vector<Detection>.
//
// IMAGE_SIZE (optional): A std::pair<int, int> represention image width and
// height. This is required only when rotation needs to be computed (see
// calculator options).
//
// Output:
// One of the following:
// RECT: A Rect proto.
// NORM_RECT: A NormalizedRect proto.
// RECTS: An std::vector<Rect>.
// NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
// calculator: "DetectionsToRectsCalculator"
// input_stream: "DETECTIONS:detections"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "NORM_RECT:rect"
// options: {
// [mediapipe.DetectionsToRectCalculatorOptions.ext] {
// rotation_vector_start_keypoint_index: 0
// rotation_vector_end_keypoint_index: 2
// rotation_vector_target_angle_degrees: 90
// output_zero_rect_for_empty_detections: true
// }
// }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
protected:
virtual float ComputeRotation(const ::mediapipe::Detection& detection,
const std::pair<int, int> image_size);
virtual ::mediapipe::Status DetectionToRect(
const ::mediapipe::Detection& detection, ::mediapipe::Rect* rect);
virtual ::mediapipe::Status DetectionToNormalizedRect(
const ::mediapipe::Detection& detection,
::mediapipe::NormalizedRect* rect);
static inline float NormalizeRadians(float angle) {
return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}
::mediapipe::DetectionsToRectsCalculatorOptions options_;
int start_keypoint_index_;
int end_keypoint_index_;
float target_angle_ = 0.0f; // In radians.
bool rotate_;
bool output_zero_rect_for_empty_detections_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_

View File

@ -0,0 +1,181 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
constexpr float kFontHeightScale = 1.25f;
// A calculator takes in pairs of labels and scores or classifications, outputs
// generates render data. Either both "LABELS" and "SCORES" or "CLASSIFICATIONS"
// must be present.
//
// Usage example:
// node {
// calculator: "LabelsToRenderDataCalculator"
// input_stream: "LABELS:labels"
// input_stream: "SCORES:scores"
// output_stream: "VIDEO_PRESTREAM:video_header"
// options {
// [LabelsToRenderDataCalculatorOptions.ext] {
// color { r: 255 g: 0 b: 0 }
// color { r: 0 g: 255 b: 0 }
// color { r: 0 g: 0 b: 255 }
// thickness: 2.0
// font_height_px: 20
// max_num_labels: 3
// font_face: 1
// location: TOP_LEFT
// }
// }
// }
class LabelsToRenderDataCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
LabelsToRenderDataCalculatorOptions options_;
int num_colors_ = 0;
int video_width_ = 0;
int video_height_ = 0;
int label_height_px_ = 0;
int label_left_px_ = 0;
};
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
::mediapipe::Status LabelsToRenderDataCalculator::GetContract(
CalculatorContract* cc) {
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
} else {
RET_CHECK(cc->Inputs().HasTag("LABELS"))
<< "Must provide input stream \"LABELS\"";
cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
if (cc->Inputs().HasTag("SCORES")) {
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
}
}
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
}
cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<LabelsToRenderDataCalculatorOptions>();
num_colors_ = options_.color_size();
label_height_px_ = std::ceil(options_.font_height_px() * kFontHeightScale);
return ::mediapipe::OkStatus();
}
::mediapipe::Status LabelsToRenderDataCalculator::Process(
CalculatorContext* cc) {
if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
cc->InputTimestamp() == Timestamp::PreStream()) {
const VideoHeader& video_header =
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
video_width_ = video_header.width;
video_height_ = video_header.height;
return ::mediapipe::OkStatus();
} else {
CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT)
<< "Only TOP_LEFT is supported without VIDEO_PRESTREAM.";
}
std::vector<std::string> labels;
std::vector<float> scores;
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
const ClassificationList& classifications =
cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
labels.resize(classifications.classification_size());
scores.resize(classifications.classification_size());
for (int i = 0; i < classifications.classification_size(); ++i) {
labels[i] = classifications.classification(i).label();
scores[i] = classifications.classification(i).score();
}
} else {
const std::vector<std::string>& label_vector =
cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
std::vector<float> score_vector;
if (cc->Inputs().HasTag("SCORES")) {
score_vector = cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
}
CHECK_EQ(label_vector.size(), score_vector.size());
labels.resize(label_vector.size());
scores.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) {
labels[i] = label_vector[i];
scores[i] = score_vector[i];
}
}
RenderData render_data;
int num_label = std::min((int)labels.size(), options_.max_num_labels());
int label_baseline_px = options_.vertical_offset_px();
if (options_.location() == LabelsToRenderDataCalculatorOptions::TOP_LEFT) {
label_baseline_px += label_height_px_;
} else if (options_.location() ==
LabelsToRenderDataCalculatorOptions::BOTTOM_LEFT) {
label_baseline_px += video_height_ - label_height_px_ * (num_label - 1);
}
label_left_px_ = options_.horizontal_offset_px();
for (int i = 0; i < num_label; ++i) {
auto* label_annotation = render_data.add_render_annotations();
label_annotation->set_thickness(options_.thickness());
if (num_colors_ > 0) {
*(label_annotation->mutable_color()) = options_.color(i % num_colors_);
} else {
label_annotation->mutable_color()->set_r(255);
label_annotation->mutable_color()->set_g(0);
label_annotation->mutable_color()->set_b(0);
}
auto* text = label_annotation->mutable_text();
std::string display_text = labels[i];
if (cc->Inputs().HasTag("SCORES")) {
absl::StrAppend(&display_text, ":", scores[i]);
}
text->set_display_text(display_text);
text->set_font_height(options_.font_height_px());
text->set_left(label_left_px_);
text->set_baseline(label_baseline_px + i * label_height_px_);
text->set_font_face(options_.font_face());
}
cc->Outputs()
.Tag("RENDER_DATA")
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,62 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message LabelsToRenderDataCalculatorOptions {
extend CalculatorOptions {
optional LabelsToRenderDataCalculatorOptions ext = 271660364;
}
// Colors for drawing the label(s).
repeated Color color = 1;
// Thickness for drawing the label(s).
optional double thickness = 2 [default = 2];
// The font height in absolute pixels.
optional int32 font_height_px = 3 [default = 50];
// The offset of the starting text in horizontal direction in absolute pixels.
optional int32 horizontal_offset_px = 7 [default = 0];
// The offset of the starting text in vertical direction in absolute pixels.
optional int32 vertical_offset_px = 8 [default = 0];
// The maximum number of labels to display.
optional int32 max_num_labels = 4 [default = 1];
// Specifies the font for the text. Font must be one of the following from
// OpenCV:
// cv::FONT_HERSHEY_SIMPLEX (0)
// cv::FONT_HERSHEY_PLAIN (1)
// cv::FONT_HERSHEY_DUPLEX (2)
// cv::FONT_HERSHEY_COMPLEX (3)
// cv::FONT_HERSHEY_TRIPLEX (4)
// cv::FONT_HERSHEY_COMPLEX_SMALL (5)
// cv::FONT_HERSHEY_SCRIPT_SIMPLEX (6)
// cv::FONT_HERSHEY_SCRIPT_COMPLEX (7)
optional int32 font_face = 5 [default = 0];
// Label location.
enum Location {
TOP_LEFT = 0;
BOTTOM_LEFT = 1;
}
optional Location location = 6 [default = TOP_LEFT];
}

View File

@ -0,0 +1,138 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/calculators/util/landmarks_to_floats_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace {
constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kFloatsTag[] = "FLOATS";
constexpr char kMatrixTag[] = "MATRIX";
} // namespace
// Converts a vector of landmarks to a vector of floats or a matrix.
// Input:
// NORM_LANDMARKS: An std::vector<NormalizedLandmark>.
//
// Output:
// FLOATS(optional): A vector of floats from flattened landmarks.
// MATRIX(optional): A matrix of floats of the landmarks.
//
// Usage example:
// node {
// calculator: "LandmarksToFloatsCalculator"
// input_stream: "NORM_LANDMARKS:landmarks"
// output_stream: "MATRIX:landmark_matrix"
// }
class LandmarksToFloatsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag(kLandmarksTag).Set<std::vector<NormalizedLandmark>>();
RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
cc->Outputs().HasTag(kMatrixTag));
if (cc->Outputs().HasTag(kFloatsTag)) {
cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
}
if (cc->Outputs().HasTag(kMatrixTag)) {
cc->Outputs().Tag(kMatrixTag).Set<Matrix>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
const auto& options =
cc->Options<::mediapipe::LandmarksToFloatsCalculatorOptions>();
num_dimensions_ = options.num_dimensions();
// Currently number of dimensions must be within [1, 3].
RET_CHECK_GE(num_dimensions_, 1);
RET_CHECK_LE(num_dimensions_, 3);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
// Only process if there's input landmarks.
if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
const auto& input_landmarks =
cc->Inputs().Tag(kLandmarksTag).Get<std::vector<NormalizedLandmark>>();
if (cc->Outputs().HasTag(kFloatsTag)) {
auto output_floats = absl::make_unique<std::vector<float>>();
for (const auto& landmark : input_landmarks) {
output_floats->emplace_back(landmark.x());
if (num_dimensions_ > 1) {
output_floats->emplace_back(landmark.y());
}
if (num_dimensions_ > 2) {
output_floats->emplace_back(landmark.z());
}
}
cc->Outputs()
.Tag(kFloatsTag)
.Add(output_floats.release(), cc->InputTimestamp());
} else {
auto output_matrix = absl::make_unique<Matrix>();
output_matrix->setZero(num_dimensions_, input_landmarks.size());
for (int i = 0; i < input_landmarks.size(); ++i) {
(*output_matrix)(0, i) = input_landmarks[i].x();
if (num_dimensions_ > 1) {
(*output_matrix)(1, i) = input_landmarks[i].y();
}
if (num_dimensions_ > 2) {
(*output_matrix)(2, i) = input_landmarks[i].z();
}
}
cc->Outputs()
.Tag(kMatrixTag)
.Add(output_matrix.release(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
}
private:
int num_dimensions_ = 0;
};
REGISTER_CALCULATOR(LandmarksToFloatsCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message LandmarksToFloatsCalculatorOptions {
extend CalculatorOptions {
optional LandmarksToFloatsCalculatorOptions ext = 274035660;
}
// Number of dimensions to convert. Must within [1, 3].
optional int32 num_dimensions = 1 [default = 2];
}

View File

@ -0,0 +1,57 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// The calculator takes the path to the local file as an input side packet and
// outputs the contents of that file.
//
// Example config:
// node {
// calculator: "LocalFileContentsCalculator"
// input_side_packet: "FILE_PATH:file_path"
// output_side_packet: "CONTENTS:contents"
// }
class LocalFileContentsCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("FILE_PATH").Set<std::string>();
cc->OutputSidePackets().Tag("CONTENTS").Set<std::string>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
std::string contents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
cc->InputSidePackets().Tag("FILE_PATH").Get<std::string>(), &contents));
cc->OutputSidePackets()
.Tag("CONTENTS")
.Set(MakePacket<std::string>(std::move(contents)));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(LocalFileContentsCalculator);
} // namespace mediapipe

View File

@ -23,13 +23,14 @@
#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h" #include "mediapipe/calculators/util/top_k_scores_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \ #if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
(defined(__APPLE__) && !TARGET_OS_OSX) defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#else #else
@ -37,8 +38,10 @@
#endif #endif
namespace mediapipe { namespace mediapipe {
// A calculator that takes a vector of scores and returns the indexes, scores, // A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements. // labels of the top k elements, classification protos, and summary std::string
// (in csv format).
// //
// Usage example: // Usage example:
// node { // node {
@ -47,6 +50,8 @@ namespace mediapipe {
// output_stream: "TOP_K_INDEXES:top_k_indexes" // output_stream: "TOP_K_INDEXES:top_k_indexes"
// output_stream: "TOP_K_SCORES:top_k_scores" // output_stream: "TOP_K_SCORES:top_k_scores"
// output_stream: "TOP_K_LABELS:top_k_labels" // output_stream: "TOP_K_LABELS:top_k_labels"
// output_stream: "TOP_K_CLASSIFICATIONS:top_k_classes"
// output_stream: "SUMMARY:summary"
// options: { // options: {
// [mediapipe.TopKScoresCalculatorOptions.ext] { // [mediapipe.TopKScoresCalculatorOptions.ext] {
// top_k: 5 // top_k: 5
@ -69,6 +74,7 @@ class TopKScoresCalculator : public CalculatorBase {
int top_k_ = -1; int top_k_ = -1;
float threshold_ = 0.0; float threshold_ = 0.0;
std::unordered_map<int, std::string> label_map_; std::unordered_map<int, std::string> label_map_;
bool label_map_loaded_ = false;
}; };
REGISTER_CALCULATOR(TopKScoresCalculator); REGISTER_CALCULATOR(TopKScoresCalculator);
@ -84,6 +90,12 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
if (cc->Outputs().HasTag("TOP_K_LABELS")) { if (cc->Outputs().HasTag("TOP_K_LABELS")) {
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>(); cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
} }
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
}
if (cc->Outputs().HasTag("SUMMARY")) {
cc->Outputs().Tag("SUMMARY").Set<std::string>();
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -149,7 +161,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
reverse(top_k_indexes.begin(), top_k_indexes.end()); reverse(top_k_indexes.begin(), top_k_indexes.end());
reverse(top_k_scores.begin(), top_k_scores.end()); reverse(top_k_scores.begin(), top_k_scores.end());
if (cc->Outputs().HasTag("TOP_K_LABELS")) { if (label_map_loaded_) {
for (int index : top_k_indexes) { for (int index : top_k_indexes) {
top_k_labels.push_back(label_map_[index]); top_k_labels.push_back(label_map_[index]);
} }
@ -172,6 +184,35 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels) .AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag("SUMMARY")) {
std::vector<std::string> results;
for (int index = 0; index < top_k_indexes.size(); ++index) {
if (label_map_loaded_) {
results.push_back(
absl::StrCat(top_k_labels[index], ":", top_k_scores[index]));
} else {
results.push_back(
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
}
}
cc->Outputs().Tag("SUMMARY").AddPacket(
MakePacket<std::string>(absl::StrJoin(results, ","))
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) {
auto classification_list = absl::make_unique<ClassificationList>();
for (int index = 0; index < top_k_indexes.size(); ++index) {
Classification* classification =
classification_list->add_classification();
classification->set_index(top_k_indexes[index]);
classification->set_score(top_k_scores[index]);
if (label_map_loaded_) {
classification->set_label(top_k_labels[index]);
}
}
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -188,6 +229,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
label_map_[i++] = line; label_map_[i++] = line;
} }
label_map_loaded_ = true;
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -0,0 +1,130 @@
## MediaPipe Android Archive Library
***Experimental Only***
The MediaPipe Android archive library is a convenient way to use MediaPipe with
Android Studio and Gradle. MediaPipe doesn't publish a general AAR that can be
used by all projects. Instead, developers need to add a mediapipe_aar() target
to generate a custom AAR file for their own projects. This is necessary in order
to include specific resources such as MediaPipe calculators needed for each
project.
### Steps to build a MediaPipe AAR
1. Create a mediapipe_aar() target.
In the MediaPipe directory, create a new mediapipe_aar() target in a BUILD
file. You need to figure out what calculators are used in the graph and
provide the calculator dependencies to the mediapipe_aar(). For example, to
build an AAR for [face detection gpu](./face_detection_mobile_gpu.md), you
can put the following code into
mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/BUILD.
```
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
mediapipe_aar(
name = "mp_face_detection_aar",
calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
)
```
2. Run the Bazel build command to generate the AAR.
```bash
bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a //path/to/the/aar/build/file:aar_name
```
For the face detection AAR target we made in the step 1, run:
```bash
bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a \
//mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar
# It should print:
# Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date:
# bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
```
3. (Optional) Save the AAR to your preferred location.
```bash
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
/absolute/path/to/your/preferred/location
```
### Steps to use a MediaPipe AAR in Android Studio with Gradle
1. Start Android Studio and go to your project.
2. Copy the AAR into app/libs.
```bash
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
/path/to/your/app/libs/
```
![Screenshot](images/mobile/aar_location.png)
3. Make app/src/main/assets and copy assets (graph, model, and etc) into
app/src/main/assets.
Build the MediaPipe binary graph and copy the assets into
app/src/main/assets, e.g., for the face detection graph, you need to build
and copy
[the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41),
[the tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite),
and
[the label map](https://github.com/google/mediapipe/blob/master/mediapipe/models/face_detection_front_labelmap.txt).
```bash
bazel build -c opt mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu:binary_graph
cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/facedetectiongpu.binarypb /path/to/your/app/src/main/assets/
cp mediapipe/models/face_detection_front.tflite /path/to/your/app/src/main/assets/
cp mediapipe/models/face_detection_front_labelmap.txt /path/to/your/app/src/main/assets/
```
![Screenshot](images/mobile/assets_location.png)
4. Make app/src/main/jniLibs and copy OpenCV JNI libraries into
app/src/main/jniLibs.
MediaPipe depends on OpenCV, you will need to copy the precompiled OpenCV so
files into app/src/main/jniLibs. You can download the official OpenCV
Android SDK from
[here](https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip)
and run:
```bash
cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
```
![Screenshot](images/mobile/android_studio_opencv_location.png)
5. Modify app/build.gradle to add MediaPipe dependencies and MediaPipe AAR.
```
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar'])
implementation 'androidx.appcompat:appcompat:1.0.2'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.0'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.3.1'
implementation 'com.google.flogger:flogger-system-backend:0.3.1'
implementation 'com.google.code.findbugs:jsr305:3.0.2'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.guava:guava:27.0.1-android'
// CameraX core library
def camerax_version = "1.0.0-alpha06"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
}
```
6. Follow our Android app examples to use MediaPipe in Android Studio for your
use case. If you are looking for an example, a working face detection
example can be found
[here](https://github.com/jiuqiant/mediapipe_aar_example).

View File

@ -96,8 +96,9 @@ using the MediaPipe C++ APIs.
### Feature Extration for YouTube-8M Challenge ### Feature Extration for YouTube-8M Challenge
[Feature Extration for YouTube-8M Challenge](./youtube_8m.md) shows how to use [Feature Extration and Model Inference for YouTube-8M Challenge](./youtube_8m.md)
MediaPipe to prepare training data for the YouTube-8M Challenge. shows how to use MediaPipe to prepare training data for the YouTube-8M Challenge
and do the model inference with the baseline model.
### Preparing Data Sets with MediaSequence ### Preparing Data Sets with MediaSequence

View File

@ -36,10 +36,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# INFO: 711 processes: 710 linux-sandbox, 1 local. # INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions # INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on # This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible # Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt
``` ```
@ -60,11 +59,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
# INFO: 711 processes: 710 linux-sandbox, 1 local. # INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions # INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on # This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible, # Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly. # or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt
``` ```

View File

@ -35,11 +35,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
#INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142 #INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142
#INFO: Build completed successfully, 12210 total actions #INFO: Build completed successfully, 12210 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on # This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible, # Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly. # or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
--calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt --calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt
``` ```

View File

@ -35,10 +35,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
#INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307 #INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307
#INFO: Build completed successfully, 12517 total actions #INFO: Build completed successfully, 12517 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on # This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible # Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt
``` ```
@ -59,11 +58,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
#INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b #INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b
#INFO: Build completed successfully, 22455 total actions #INFO: Build completed successfully, 22455 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on # This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible, # Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly. # or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
``` ```

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

View File

@ -24,7 +24,8 @@ Choose your operating system:
To build and run Android apps: To build and run Android apps:
- [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk) - [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk)
- [Setting up Android Studio with MediaPipe](#setting-up-android-studio-with-mediapipe) - [Using MediaPipe with Gradle](#using-mediapipe-with-gradle)
- [Using MediaPipe with Bazel](#using-mediapipe-with-bazel)
To build and run iOS apps: To build and run iOS apps:
@ -41,19 +42,11 @@ To build and run iOS apps:
$ cd mediapipe $ cd mediapipe
``` ```
2. Install Bazel (0.24.1 and above required). 2. Install Bazel (version between 0.24.1 and 0.29.1).
Option 1. Use package manager tool to install the latest version of Bazel. Follow the official
```bash
$ sudo apt-get install bazel
# Run 'bazel version' to check version of bazel installed
```
Option 2. Follow Bazel's
[documentation](https://docs.bazel.build/versions/master/install-ubuntu.html) [documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
to install any version of Bazel manually. to install Bazel manually. Note that MediaPipe doesn't support Bazel 1.0.0+ yet.
3. Install OpenCV and FFmpeg. 3. Install OpenCV and FFmpeg.
@ -75,10 +68,10 @@ To build and run iOS apps:
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
to manually build OpenCV from source code. to manually build OpenCV from source code.
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
in "/usr/local/", you need to update the "linux_opencv" new_local_repository in "/usr/local/", you need to update the "linux_opencv" new_local_repository
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
like the following: like the following:
```bash ```bash
@ -159,11 +152,11 @@ To build and run iOS apps:
$ cd mediapipe $ cd mediapipe
``` ```
2. Install Bazel (0.24.1 and above required). 2. Install Bazel (version between 0.24.1 and 0.29.1).
Follow Bazel's Follow the official
[documentation](https://docs.bazel.build/versions/master/install-redhat.html) [documentation](https://docs.bazel.build/versions/master/install-redhat.html)
to install Bazel manually. to install Bazel manually. Note that MediaPipe doesn't support Bazel 1.0.0+ yet.
3. Install OpenCV. 3. Install OpenCV.
@ -178,10 +171,10 @@ To build and run iOS apps:
Option 2. Build OpenCV from source code. Option 2. Build OpenCV from source code.
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
in "/usr/local/", you need to update the "linux_opencv" new_local_repository in "/usr/local/", you need to update the "linux_opencv" new_local_repository
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
like the following: like the following:
```bash ```bash
@ -237,7 +230,7 @@ To build and run iOS apps:
* Install [Homebrew](https://brew.sh). * Install [Homebrew](https://brew.sh).
* Install [Xcode](https://developer.apple.com/xcode/) and its Command Line * Install [Xcode](https://developer.apple.com/xcode/) and its Command Line
Tools. Tools by `xcode-select install`.
2. Checkout MediaPipe repository. 2. Checkout MediaPipe repository.
@ -247,19 +240,24 @@ To build and run iOS apps:
$ cd mediapipe $ cd mediapipe
``` ```
3. Install Bazel (0.24.1 and above required). 3. Install Bazel (version between 0.24.1 and 0.29.1).
Option 1. Use package manager tool to install the latest version of Bazel. Option 1. Use package manager tool to install Bazel 0.29.1
```bash ```bash
$ brew install bazel # If Bazel 1.0.0+ was installed.
$ brew uninstall bazel
# Install Bazel 0.29.1
$ brew install https://raw.githubusercontent.com/bazelbuild/homebrew-tap/223ffb570c21c0a2af251afc6df9dec0214c6e74/Formula/bazel.rb
$ brew link bazel
# Run 'bazel version' to check version of bazel installed # Run 'bazel version' to check version of bazel installed
``` ```
Option 2. Follow Bazel's Option 2. Follow the official
[documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x) [documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
to install any version of Bazel manually. to install Bazel manually. Note that MediaPipe doesn't support Bazel 1.0.0+ yet.
4. Install OpenCV and FFmpeg. 4. Install OpenCV and FFmpeg.
@ -281,7 +279,7 @@ To build and run iOS apps:
$ port install opencv $ port install opencv
``` ```
Note: when using MacPorts, please edit the [`WORKSAPCE`], Note: when using MacPorts, please edit the [`WORKSPACE`],
[`opencv_macos.BUILD`], and [`ffmpeg_macos.BUILD`] files like the following: [`opencv_macos.BUILD`], and [`ffmpeg_macos.BUILD`] files like the following:
```bash ```bash
@ -419,10 +417,10 @@ To build and run iOS apps:
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html) [documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
to manually build OpenCV from source code. to manually build OpenCV from source code.
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
in "/usr/local/", you need to update the "linux_opencv" new_local_repository in "/usr/local/", you need to update the "linux_opencv" new_local_repository
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`] rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
like the following: like the following:
```bash ```bash
@ -589,10 +587,20 @@ Please verify all the necessary packages are installed.
* Android SDK Tools 26.1.1 * Android SDK Tools 26.1.1
* Android NDK 17c or above * Android NDK 17c or above
### Setting up Android Studio with MediaPipe ### Using MediaPipe with Gradle
The steps below use Android Studio 3.5 to build and install a MediaPipe example MediaPipe can be used within an existing project, such as a Gradle project,
app. using the MediaPipe AAR target defined in mediapipe_aar.bzl. Please see the
separate [MediaPipe Android Archive Library](./android_archive_library.md)
documentation.
### Using MediaPipe with Bazel
The MediaPipe project can be imported to Android Studio using the Bazel plugins.
This allows the MediaPipe examples and demos to be built and modified in Android
Studio. To incorporate MediaPipe into an existing Android Studio project, see:
"Using MediaPipe with Gradle". The steps below use Android Studio 3.5 to build
and install a MediaPipe example app.
1. Install and launch Android Studio 3.5. 1. Install and launch Android Studio 3.5.
@ -682,7 +690,7 @@ app.
* Press the `[+]` button to add the new configuration. * Press the `[+]` button to add the new configuration.
* Select `Run` to run the example app on the connected Android device. * Select `Run` to run the example app on the connected Android device.
[`WORKSAPCE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE [`WORKSPACE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
[`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD [`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD
[`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD [`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD
[`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD [`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD

View File

@ -35,10 +35,9 @@ $ bazel build -c opt \
# INFO: 2675 processes: 2673 linux-sandbox, 2 local. # INFO: 2675 processes: 2673 linux-sandbox, 2 local.
# INFO: Build completed successfully, 2807 total actions # INFO: Build completed successfully, 2807 total actions
$ export GLOG_logtostderr=1
# Replace <input video path> and <output video path>. # Replace <input video path> and <output video path>.
# You can find a test video in mediapipe/examples/desktop/object_detection. # You can find a test video in mediapipe/examples/desktop/object_detection.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \ --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path> --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
``` ```
@ -200,10 +199,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# INFO: 711 processes: 710 linux-sandbox, 1 local. # INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions # INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# Replace <input video path> and <output video path>. # Replace <input video path> and <output video path>.
# You can find a test video in mediapipe/examples/desktop/object_detection. # You can find a test video in mediapipe/examples/desktop/object_detection.
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \ --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path> --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
``` ```
@ -224,10 +222,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
#INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b #INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b
#INFO: Build completed successfully, 12154 total actions #INFO: Build completed successfully, 12154 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on # This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible # Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \ $ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt
``` ```

View File

@ -1,9 +1,11 @@
## Extracting Video Features for YouTube-8M Challenge # Feature Extration and Model Inference for YouTube-8M Challenge
MediaPipe is a useful and general framework for media processing that can assist MediaPipe is a useful and general framework for media processing that can assist
with research, development, and deployment of ML models. This example focuses on with research, development, and deployment of ML models. This example focuses on
model development by demonstrating how to prepare training data for the model development by demonstrating how to prepare training data and do model
YouTube-8M Challenge. inference for the YouTube-8M Challenge.
## Extracting Video Features for YouTube-8M Challenge
[Youtube-8M Challenge](https://www.kaggle.com/c/youtube8m-2019) is an annual [Youtube-8M Challenge](https://www.kaggle.com/c/youtube8m-2019) is an annual
video classification challenge hosted by Google. Over the last two years, the video classification challenge hosted by Google. Over the last two years, the
@ -29,14 +31,14 @@ videos.
### Steps to run the YouTube-8M feature extraction graph ### Steps to run the YouTube-8M feature extraction graph
1. Checkout the mediapipe repository 1. Checkout the mediapipe repository.
```bash ```bash
git clone https://github.com/google/mediapipe.git git clone https://github.com/google/mediapipe.git
cd mediapipe cd mediapipe
``` ```
2. Download the PCA and model data 2. Download the PCA and model data.
```bash ```bash
mkdir /tmp/mediapipe mkdir /tmp/mediapipe
@ -49,7 +51,7 @@ videos.
tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz
``` ```
3. Get the VGGish frozen graph 3. Get the VGGish frozen graph.
Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed
with the TensorFlow 1.14+ package installed. with the TensorFlow 1.14+ package installed.
@ -60,24 +62,103 @@ videos.
python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph
``` ```
4. Generate a MediaSequence metadata from the input video 4. Generate a MediaSequence metadata from the input video.
Note: the output file is /tmp/mediapipe/metadata.tfrecord Note: the output file is /tmp/mediapipe/metadata.tfrecord
```bash ```bash
# change clip_end_time_sec to match the length of your video.
python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \ python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
--path_to_input_video=/absolute/path/to/the/local/video/file --path_to_input_video=/absolute/path/to/the/local/video/file \
--clip_end_time_sec=120
``` ```
5. Run the MediaPipe binary to extract the features 5. Run the MediaPipe binary to extract the features.
```bash ```bash
bazel build -c opt \ bazel build -c opt \
--define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \ --define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \
mediapipe/examples/desktop/youtube8m:extract_yt8m_features mediapipe/examples/desktop/youtube8m:extract_yt8m_features
./bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \ --calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
--input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \ --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \
--output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord --output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord
``` ```
## Model Inference for YouTube-8M Challenge
MediaPipe can help you do model inference for YouTube-8M Challenge with both
local videos and the YouTube-8M dataset. To visualize
[the graph for local videos](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt)
and
[the graph for the YouTube-8M dataset](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt),
copy the text specification of the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). We use the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
in our example. But, the model inference pipeline is highly customizable. You
are welcome to add new calculators or use your own machine learning models to do
the inference for both local videos and the dataset
### Steps to run the YouTube-8M model inference graph with Web Interface
1. Copy the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
2. Build the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
```
3. Run the python web server.
Note: pip install absl-py
```bash
python mediapipe/examples/desktop/youtube8m/viewer/server.py --root `pwd`
```
Navigate to localhost:8008 in a web browser.
[Here](https://drive.google.com/file/d/19GSvdAAuAlACpBhHOaqMWZ_9p8bLUYKh/view?usp=sharing)
is a demo video showing the steps to use this web application. Also please
read
[youtube8m/README.md](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m/README.md)
if you prefer to run the underlying model_inference binary in command line.
### Steps to run the YouTube-8M model inference graph with a local video
1. Make sure you have the output tfrecord from the feature extraction pipeline.
2. Copy the baseline model
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
3. Build and run the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
# segment_size is the number of seconds window of frames.
# overlap is the number of seconds adjacent segments share.
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt \
--input_side_packets=input_sequence_example_path=/tmp/mediapipe/output.tfrecord,input_video_path=/absolute/path/to/the/local/video/file,output_video_path=/tmp/mediapipe/annotated_video.mp4,segment_size=5,overlap=4
```
4. View the annotated video.

View File

@ -27,7 +27,9 @@ cc_library(
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:map_util", "//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -13,14 +13,23 @@
// limitations under the License. // limitations under the License.
// //
// A simple main function to run a MediaPipe graph. // A simple main function to run a MediaPipe graph.
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
DEFINE_string( DEFINE_string(
calculator_graph_config_file, "", calculator_graph_config_file, "",
@ -31,14 +40,72 @@ DEFINE_string(input_side_packets, "",
"for the CalculatorGraph. All values will be treated as the " "for the CalculatorGraph. All values will be treated as the "
"string type even if they represent doubles, floats, etc."); "string type even if they represent doubles, floats, etc.");
// Local file output flags.
// Output stream
DEFINE_string(output_stream, "",
"The output stream to output to the local file in csv format.");
DEFINE_string(output_stream_file, "",
"The name of the local file to output all packets sent to "
"the stream specified with --output_stream. ");
DEFINE_bool(strip_timestamps, false,
"If true, only the packet contents (without timestamps) will be "
"written into the local file.");
// Output side packets
DEFINE_string(output_side_packets, "",
"A CSV of output side packets to output to local file.");
DEFINE_string(output_side_packets_file, "",
"The name of the local file to output all side packets specified "
"with --output_side_packets. ");
::mediapipe::Status OutputStreamToLocalFile(
::mediapipe::OutputStreamPoller& poller) {
std::ofstream file;
file.open(FLAGS_output_stream_file);
::mediapipe::Packet packet;
while (poller.Next(&packet)) {
std::string output_data;
if (!FLAGS_strip_timestamps) {
absl::StrAppend(&output_data, packet.Timestamp().Value(), ",");
}
absl::StrAppend(&output_data, packet.Get<std::string>(), "\n");
file << output_data;
}
file.close();
return ::mediapipe::OkStatus();
}
::mediapipe::Status OutputSidePacketsToLocalFile(
::mediapipe::CalculatorGraph& graph) {
if (!FLAGS_output_side_packets.empty() &&
!FLAGS_output_side_packets_file.empty()) {
std::ofstream file;
file.open(FLAGS_output_side_packets_file);
std::vector<std::string> side_packet_names =
absl::StrSplit(FLAGS_output_side_packets, ',');
for (const std::string& side_packet_name : side_packet_names) {
ASSIGN_OR_RETURN(auto status_or_packet,
graph.GetOutputSidePacket(side_packet_name));
file << absl::StrCat(side_packet_name, ":",
status_or_packet.Get<std::string>(), "\n");
}
file.close();
} else {
RET_CHECK(FLAGS_output_side_packets.empty() &&
FLAGS_output_side_packets_file.empty())
<< "--output_side_packets and --output_side_packets_file should be "
"specified in pair.";
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status RunMPPGraph() { ::mediapipe::Status RunMPPGraph() {
std::string calculator_graph_config_contents; std::string calculator_graph_config_contents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents( MP_RETURN_IF_ERROR(::mediapipe::file::GetContents(
FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); FLAGS_calculator_graph_config_file, &calculator_graph_config_contents));
LOG(INFO) << "Get calculator graph config contents: " LOG(INFO) << "Get calculator graph config contents: "
<< calculator_graph_config_contents; << calculator_graph_config_contents;
mediapipe::CalculatorGraphConfig config = ::mediapipe::CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>( ::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>(
calculator_graph_config_contents); calculator_graph_config_contents);
std::map<std::string, ::mediapipe::Packet> input_side_packets; std::map<std::string, ::mediapipe::Packet> input_side_packets;
std::vector<std::string> kv_pairs = std::vector<std::string> kv_pairs =
@ -51,10 +118,23 @@ DEFINE_string(input_side_packets, "",
::mediapipe::MakePacket<std::string>(name_and_value[1]); ::mediapipe::MakePacket<std::string>(name_and_value[1]);
} }
LOG(INFO) << "Initialize the calculator graph."; LOG(INFO) << "Initialize the calculator graph.";
mediapipe::CalculatorGraph graph; ::mediapipe::CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets));
if (!FLAGS_output_stream.empty() && !FLAGS_output_stream_file.empty()) {
ASSIGN_OR_RETURN(auto poller,
graph.AddOutputStreamPoller(FLAGS_output_stream));
LOG(INFO) << "Start running the calculator graph."; LOG(INFO) << "Start running the calculator graph.";
return graph.Run(); MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(OutputStreamToLocalFile(poller));
} else {
RET_CHECK(FLAGS_output_stream.empty() && FLAGS_output_stream_file.empty())
<< "--output_stream and --output_stream_file should be specified in "
"pair.";
LOG(INFO) << "Start running the calculator graph.";
MP_RETURN_IF_ERROR(graph.StartRun({}));
}
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return OutputSidePacketsToLocalFile(graph);
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {

View File

@ -33,3 +33,14 @@ cc_binary(
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",
], ],
) )
cc_binary(
name = "model_inference",
deps = [
"//mediapipe/examples/desktop:simple_run_graph_main",
"//mediapipe/graphs/youtube8m:yt8m_inference_calculators_deps",
# TODO: Figure out the minimum set of the kernels needed by this example.
"@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session",
],
)

View File

@ -1,13 +1,13 @@
### Steps to run the YouTube-8M feature extraction graph ### Steps to run the YouTube-8M feature extraction graph
1. Checkout the mediapipe repository 1. Checkout the mediapipe repository.
```bash ```bash
git clone https://github.com/google/mediapipe.git git clone https://github.com/google/mediapipe.git
cd mediapipe cd mediapipe
``` ```
2. Download the PCA and model data 2. Download the PCA and model data.
```bash ```bash
mkdir /tmp/mediapipe mkdir /tmp/mediapipe
@ -20,7 +20,7 @@
tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz
``` ```
3. Get the VGGish frozen graph 3. Get the VGGish frozen graph.
Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed
with the TensorFlow 1.14+ package installed. with the TensorFlow 1.14+ package installed.
@ -31,26 +31,114 @@
python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph
``` ```
4. Generate a MediaSequence metadata from the input video 4. Generate a MediaSequence metadata from the input video.
Note: the output file is /tmp/mediapipe/metadata.tfrecord Note: the output file is /tmp/mediapipe/metadata.tfrecord
```bash ```bash
# change clip_end_time_sec to match the length of your video.
python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \ python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
--path_to_input_video=/absolute/path/to/the/local/video/file \ --path_to_input_video=/absolute/path/to/the/local/video/file \
--clip_start_time_sec=0 \ --clip_end_time_sec=120
--clip_end_time_sec=10
``` ```
5. Run the MediaPipe binary to extract the features 5. Run the MediaPipe binary to extract the features.
```bash ```bash
bazel build -c opt \ bazel build -c opt \
--define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \ --define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \
mediapipe/examples/desktop/youtube8m:extract_yt8m_features mediapipe/examples/desktop/youtube8m:extract_yt8m_features
./bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \ --calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
--input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \ --input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \
--output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord --output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord
``` ```
### Steps to run the YouTube-8M inference graph with the YT8M dataset
1. Download the YT8M dataset
For example, download one shard of the training data:
```bash
curl http://us.data.yt8m.org/2/frame/train/trainpj.tfrecord --output /tmp/mediapipe/trainpj.tfrecord
```
2. Copy the baseline model [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view) to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
3. Build and run the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt \
--input_side_packets=tfrecord_path=/tmp/mediapipe/trainpj.tfrecord,record_index=0,desired_segment_size=5 \
--output_stream=annotation_summary \
--output_stream_file=/tmp/summary \
--output_side_packets=yt8m_id \
--output_side_packets_file=/tmp/yt8m_id
```
### Steps to run the YouTube-8M model inference graph with Web Interface
1. Copy the baseline model [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view) to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
2. Build the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
```
3. Run the python web server.
Note: pip install absl-py
```bash
python mediapipe/examples/desktop/youtube8m/viewer/server.py --root `pwd`
```
Navigate to localhost:8008 in a web browser.
### Steps to run the YouTube-8M model inference graph with a local video
1. Make sure you have the output tfrecord from the feature extraction pipeline.
2. Copy the baseline model [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view) to local.
```bash
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
```
3. Build and run the inference binary.
```bash
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
mediapipe/examples/desktop/youtube8m:model_inference
# segment_size is the number of seconds window of frames.
# overlap is the number of seconds adjacent segments share.
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
--calculator_graph_config_file=mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt \
--input_side_packets=input_sequence_example_path=/tmp/mediapipe/output.tfrecord,input_video_path=/absolute/path/to/the/local/video/file,output_video_path=/tmp/mediapipe/annotated_video.mp4,segment_size=5,overlap=4
```
4. View the annotated video.

View File

@ -0,0 +1,262 @@
"""Server for YouTube8M Model Inference Demo.
Serves up both the static files for the website and provides a service that
fetches the video id and timestamp based labels for a video analyzed in a
tfrecord files.
"""
from __future__ import print_function
import json
import os
import re
import socket
import subprocess
import sys
from absl import app
from absl import flags
import http.client
import http.server
from six.moves.urllib import parse
FLAGS = flags.FLAGS
flags.DEFINE_bool("show_label_at_center", False,
"Show labels at the center of the segment.")
flags.DEFINE_integer("port", 8008, "Port that the API is served over.")
flags.DEFINE_string("tmp_dir", "/tmp/mediapipe",
"Temporary asset storage location.")
flags.DEFINE_string("root", "", "MediaPipe root directory.")
# binary, pbtxt, label_map paths are relative to 'root' path
flags.DEFINE_string(
"binary",
"bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference",
"Inference binary location.")
flags.DEFINE_string(
"pbtxt",
"mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt",
"Default pbtxt graph file.")
flags.DEFINE_string("label_map", "mediapipe/graphs/youtube8m/label_map.txt",
"Default label map text file.")
class HTTPServerV6(http.server.HTTPServer):
address_family = socket.AF_INET6
class Youtube8MRequestHandler(http.server.SimpleHTTPRequestHandler):
"""Static file server with /healthz support."""
def do_GET(self):
if self.path.startswith("/healthz"):
self.send_response(200)
self.send_header("Content-type", "text/plain")
self.send_header("Content-length", 2)
self.end_headers()
self.wfile.write("ok")
if self.path.startswith("/video"):
parsed_params = parse.urlparse(self.path)
url_params = parse.parse_qs(parsed_params.query)
tfrecord_path = ""
segment_size = 5
print(url_params)
if "file" in url_params:
tfrecord_path = url_params["file"][0]
if "segments" in url_params:
segment_size = int(url_params["segments"][0])
self.fetch(tfrecord_path, segment_size)
else:
if self.path == "/":
self.path = "/index.html"
# Default to serve up a local file
self.path = "/static" + self.path
http.server.SimpleHTTPRequestHandler.do_GET(self)
def report_error(self, msg):
"""Simplifies sending out a string as a 500 http response."""
self.send_response(500)
self.send_header("Content-type", "text/plain")
self.end_headers()
if sys.version_info[0] < 3:
self.wfile.write(str(msg).encode("utf-8"))
else:
self.wfile.write(bytes(msg, "utf-8"))
def report_missing_files(self, files):
"""Sends out 500 response with missing files."""
accumulate = ""
for file_path in files:
if not os.path.exists(file_path):
accumulate = "%s '%s'" % (accumulate, file_path)
if accumulate:
self.report_error("Could not find:%s" % accumulate)
return True
return False
def fetch(self, path, segment_size):
"""Returns the video id and labels for a tfrecord at a provided index."""
print("Received request. File=", path, "Segment Size =", segment_size)
if (self.report_missing_files([
"%s/%s" % (FLAGS.root, FLAGS.pbtxt),
"%s/%s" % (FLAGS.root, FLAGS.binary),
"%s/%s" % (FLAGS.root, FLAGS.label_map)
])):
return
# Parse the youtube video id off the end of the link or as a standalone id.
filename_match = re.match(
"(?:.*youtube.*v=)?([a-zA-Z-0-9_]{2})([a-zA-Z-0-9_]+)", path)
tfrecord_url = filename_match.expand(r"data.yt8m.org/2/j/r/\1/\1\2.js")
print("Trying to get tfrecord via", tfrecord_url)
connection = http.client.HTTPConnection("data.yt8m.org")
connection.request("GET", tfrecord_url)
response = connection.getresponse()
response_object = json.loads(response.read())
filename = response_object["filename_raw"]
index = response_object["index"]
print("TFRecord discovered: ", filename, ", index", index)
output_file = r"%s/%s" % (FLAGS.tmp_dir, filename)
tfrecord_url = r"http://us.data.yt8m.org/2/frame/train/%s" % filename
connection = http.client.HTTPConnection("us.data.yt8m.org")
connection.request("HEAD",
filename_match.expand(r"/2/frame/train/%s" % filename))
response = connection.getresponse()
if response.getheader("Content-Type") != "application/octet-stream":
self.report_error("Filename '%s' is invalid." % path)
print(output_file, "exists on yt8m.org. Did we fetch this before?")
if not os.path.exists(output_file):
print(output_file, "doesn't exist locally, download it now.")
return_code = subprocess.call(
["curl", "--output", output_file, tfrecord_url],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if return_code:
self.report_error("Could not retrieve contents from %s" % tfrecord_url)
return
else:
print(output_file, "exist locally, reuse it.")
print("Run the graph...")
process = subprocess.Popen([
"%s/%s" % (FLAGS.root, FLAGS.binary),
"--calculator_graph_config_file=%s/%s" % (FLAGS.root, FLAGS.pbtxt),
"--input_side_packets=tfrecord_path=%s" % output_file +
",record_index=%d" % index + ",desired_segment_size=%d" % segment_size,
"--output_stream=annotation_summary",
"--output_stream_file=%s/labels" % FLAGS.tmp_dir,
"--output_side_packets=yt8m_id",
"--output_side_packets_file=%s/yt8m_id" % FLAGS.tmp_dir
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout_str, stderr_str = process.communicate()
process.wait()
if stderr_str and "success" not in str(stderr_str).lower():
self.report_error("Error executing server binary: \n%s" % stderr_str)
return
f = open("%s/yt8m_id" % FLAGS.tmp_dir, "r")
contents = f.read()
print("yt8m_id is", contents[-5:-1])
curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (contents[-5:-3],
contents[-5:-1])
print("Grab labels from", curl_arg)
process = subprocess.Popen(["curl", curl_arg],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout = process.communicate()
process.wait()
stdout_str = stdout[0].decode("utf-8")
match = re.match(""".+"([^"]+)"[^"]+""", stdout_str)
final_results = {
"video_id": match.group(1),
"link": "https://www.youtube.com/watch?v=%s" % match.group(1),
"entries": []
}
f = open("%s/labels" % FLAGS.tmp_dir, "r")
lines = f.readlines()
show_at_center = FLAGS.show_label_at_center
print("%s/labels" % FLAGS.tmp_dir, "holds", len(lines), "entries")
for line in lines:
entry = {"labels": []}
final_results["entries"].append(entry)
first = True
for column in line.split(","):
if first:
subtract = segment_size / 2.0 if show_at_center else 0.0
entry["time"] = float(int(column)) / 1000000.0 - subtract
first = False
else:
label_score = re.match("(.+):([0-9.]+).*", column)
if label_score:
score = float(label_score.group(2))
entry["labels"].append({
"label": label_score.group(1),
"score": score
})
else:
print("empty score")
response_json = json.dumps(final_results, indent=2, separators=(",", ": "))
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
if sys.version_info[0] < 3:
self.wfile.write(str(response_json).encode("utf-8"))
else:
self.wfile.write(bytes(response_json, "utf-8"))
def update_pbtxt():
"""Update graph.pbtxt to use full path to label_map.txt."""
edited_line = ""
lines = []
with open("%s/%s" % (FLAGS.root, FLAGS.pbtxt), "r") as f:
lines = f.readlines()
for line in lines:
if "label_map_path" in line:
kv = line.split(":")
edited_line = kv[0] + (": \"%s/%s\"\n" % (FLAGS.root, FLAGS.label_map))
with open("%s/%s" % (FLAGS.root, FLAGS.pbtxt), "w") as f:
for line in lines:
if "label_map_path" in line:
f.write(edited_line)
else:
f.write(line)
def main(unused_args):
dname = os.path.dirname(os.path.abspath(__file__))
os.chdir(dname)
if not FLAGS.root:
print("Must specify MediaPipe root directory: --root `pwd`")
return
update_pbtxt()
port = FLAGS.port
print("Listening on port %s" % port) # pylint: disable=superfluous-parens
server = HTTPServerV6(("::", int(port)), Youtube8MRequestHandler)
server.serve_forever()
if __name__ == "__main__":
app.run(main)

View File

@ -0,0 +1,96 @@
<!doctype html>
<html lang="en">
<head>
<title>MediaPipe: YouTube8M Model Inference Demo</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<base href="/">
<script src="main.js"></script>
<link rel="stylesheet"
href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css"
integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T"
crossorigin="anonymous">
</head>
<body>
<div class="container-fluid">
<h2>
MediaPipe: YouTube8M Model Inference Demo
</h2>
<form id="form">
<div class="row">
<div class="card m-2" style="width: 640px;">
<div>
<div style="position:relative;">
<iframe id="ytplayer" style="display:none;" type="text/html" width="640" height="320"
src="https://www.youtube.com/embed/M7lc1UVf-VE?enablejsapi=1" frameborder="0"
enablejsapi="1"></iframe>
<div id="cover" class="bg-warning"
style="width:640px; height:320px;">
</div>
<div id="spinner" class="bg-warning"
style="display: none; width:640px; height:320px;">
<div class="spinner-border" role="status"
style="position:relative; left:300px; top:130px;">
<span class="sr-only">Loading...</span>
</div>
</div>
</div>
</div>
<div class="card-body shadow">
<div class="row mb-2">
<ul class="nav">
<li class="nav-item">
<a class="nav-link"
href="https://research.google.com/youtube8m/explore.html"
target="_">Explore Videos</a>
</li>
</ul>
</div>
<div class="form-group">
<label for="file">YouTube video ID</label>
<input type="text" class="form-control" name="file" id="file"
placeholder="Enter a YouTube link or a YouTube ID">
<small class="form-text text-muted">
e.g., Both "https://youtube.com/watch?v=huGVGe3Afng" or "huGVGe3Afng" will work.
</small>
</div>
<div class="form-group">
<label id="segments_label" for="segments">Segment Size</label>
<input type="range" min="1" max="300" step="1" value="5"
class="form-control-range" name="segments" id="segments">
</div>
<button type="submit" class="btn btn-primary">Submit</button>
<div id="error_msg" style="visibility:hidden;" class="alert alert-danger mt-2"
role="alert"></div>
</div>
</div>
<div class="card m-2 shadow">
<div class="card-body">
<div class="form-group">
<label id="threshold_label" for="threshold">Score Threshold</label>
<input type="range" min="0" max="0.99" step="0.01" value="0.2"
class="form-control-range" name="threshold" id="threshold">
</div>
<h5>
Labels
</h5>
<textarea id="feedback" style="height:320px; width:500px;"></textarea>
</div>
</div>
</div>
</form>
</div>
</body>
<script async src="https://www.youtube.com/iframe_api"></script>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.7/umd/popper.min.js"
integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1"
crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/js/bootstrap.min.js"
integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM"
crossorigin="anonymous"></script>
</html>

View File

@ -0,0 +1,217 @@
/**
* @license
* Copyright 2019 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
const STATE_PLAYER=0;
const STATE_COVER=1;
const STATE_SPINNER=2;
/**
* Looks up the value of a url parameter.
*
* @param {string} param The name of the parameter.
* @return {?string} The parameter value or null if there is no such parameter.
*/
var getUrlParameter = function(param) {
const url = decodeURIComponent(window.location.search.substring(1));
const url_parts = url.split('&');
for (var i = 0; i < url_parts.length; i++) {
const param_name = url_parts[i].split(/=(.*)/);
if (param_name[0] === param) {
return param_name[1] === undefined ? null : param_name[1];
}
}
};
/**
* Sets the fields in the form to match the values of the URL parameters.
*/
const updateFormFromURL = function() {
const form_elements = document.getElementById('form').elements;
const url = decodeURIComponent(window.location.search.substring(1));
const url_parts = url.split('&');
for (var i = 0; i < url_parts.length; i++) {
const p = url_parts[i].split(/=(.*)/);
if (p.length >= 2) {
if (form_elements[p[0]]) {
form_elements[p[0]].value = decodeURIComponent(p[1]);
}
}
}
};
let player = null;
let intervalID = undefined;
let entries = [];
/**
* Constructs the embedded YouTube player.
*/
window.onYouTubeIframeAPIReady = () => {
player = new YT.Player('ytplayer', {
events: {
'onReady': onPlayerReady,
'onStateChange': onStateChange
}
});
};
/**
* Listens for YouTube video events. When video is playing, periodically checks
* the time signature and updates the feedback with labels. When video stops,
* shuts off interval timer to save cycles.
* @param {!Event} event YouTube API Event.
*/
function onStateChange(event) {
if (event.data === 1) {
// Youtube switched to playing.
intervalID = setInterval(function(){
const currentTime = player.getCurrentTime();
let winner = undefined;
let first = undefined;
for (entry of entries) {
if (!first) {
first = entry.labels;
}
if (entry.time < currentTime) {
winner = entry.labels;
} else {
break;
}
}
if (!winner) {
winner = first;
}
const threshold =
document.getElementById('form').elements['threshold'].value;
let message = "";
for (var label of winner) {
if (label.score >= threshold) {
message = `${message}${label.label} (score: ${label.score})\n`;
}
}
$("textarea#feedback").val(message);
});
} else {
if (intervalID) {
clearInterval(intervalID);
}
}
}
/**
* Turns elements of the player on and off to reflect the state of the "app".
* @param {number} state One of STATE_COVER | STATE_SPINNER | STATE_PLAYER.
*/
function showState(state) {
switch(state) {
case STATE_COVER:
$('#cover').show();
$('#spinner').hide();
$('#ytplayer').hide();
break;
case STATE_SPINNER:
$('#cover').hide();
$('#spinner').show();
$('#ytplayer').hide();
break;
case STATE_PLAYER:
default:
$('#cover').hide();
$('#spinner').hide();
$('#ytplayer').show();
break;
}
}
/**
* Hide error field and clear its message.
*/
function hideError() {
$('#error_msg').css("visibility", "hidden").text('');
}
/**
* Set the error to visible and set its message.
* @param {string} msg Error message as a string.
*/
function showError(msg) {
$('#error_msg').css("visibility", "visible").text(msg);
}
/**
* Privides numeric feedback for the slider.
*/
function connectSlider() {
$('#threshold_label').text(
`Score Threshold (${$('#threshold')[0].value})`);
$('#threshold').on('input', () => {
$('#threshold_label').text(
`Score Threshold (${$('#threshold')[0].value})`);
});
$('#segments_label').text(
`Segment Size (${$('#segments')[0].value})`);
$('#segments').on('input', () => {
$('#segments_label').text(
`Segment Size (${$('#segments')[0].value})`);
});
}
/**
* Retrieve video information from backend.
* @param {string} filePath name of a tfrecord file.
* @param {number} segments desired number of segments (1-300)
*/
function fetchVideo(filePath, segments) {
const url = "/video?file=" + filePath + "&segments=" + segments;
$.ajax({
url: url,
success: function(result) {
const videoId = result["video_id"];
player.loadVideoById(videoId);
entries = result['entries'];
showState(STATE_PLAYER);
},
error: (err) => {
showState(STATE_COVER);
console.log(err);
showError(err.responseText);
},
datatype: "json"
});
}
/**
* Called when the embedded YouTube player has finished loading. It loads the
* requested video into the player and calls the golden6_viewer API to retrieve
* the frame-level data for that video.
*/
function onPlayerReady() {
const filePath = getUrlParameter('file') || "";
const segments = parseInt(getUrlParameter('segments')) || 0;
updateFormFromURL();
hideError();
connectSlider();
if (!filePath) {
return;
}
showState(STATE_SPINNER);
fetchVideo(filePath, segments);
}

View File

@ -688,6 +688,12 @@ cc_library(
cc_library( cc_library(
name = "demangle", name = "demangle",
hdrs = ["demangle.h"], hdrs = ["demangle.h"],
defines = select({
"//mediapipe/framework/profiler:android_release": [
"MEDIAPIPE_HAS_CXA_DEMANGLE=0",
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
@ -1713,3 +1719,10 @@ cc_test(
"//mediapipe/framework/tool/testdata:dub_quad_test_subgraph", "//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
], ],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -756,7 +756,7 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
// Shows that when fixed-size-input-stream-hanlder drops packets, // Shows that when fixed-size-input-stream-handler drops packets,
// no timetamp bounds are announced. // no timetamp bounds are announced.
TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) { TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) {
// LambdaCalculator with FixedSizeInputStreamHandler will drop packets // LambdaCalculator with FixedSizeInputStreamHandler will drop packets
@ -876,5 +876,93 @@ TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) {
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
// A Calculator that outputs only the last packet from its input stream.
class LastPacketCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetAny();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
cc->Outputs().Index(0).SetNextTimestampBound(cc->InputTimestamp());
last_packet_ = cc->Inputs().Index(0).Value();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) final {
cc->Outputs().Index(0).AddPacket(last_packet_);
return ::mediapipe::OkStatus();
}
private:
Packet last_packet_;
};
REGISTER_CALCULATOR(LastPacketCalculator);
// Shows that the last packet in an input stream can be detected.
TEST(CalculatorGraphBoundsTest, LastPacketCheck) {
// LastPacketCalculator emits only the last input stream packet.
// It emits a timestamp bound after the arrival of a successor input stream
// packet or input stream close. The output "last_output" shows the
// last packet, and "output" shows the timestamp bounds.
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
output_stream: 'output'
output_stream: 'last_output'
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
output_stream: 'input_2'
}
node {
calculator: 'LastPacketCalculator'
input_stream: 'input_2'
output_stream: 'last_packet'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'last_packet'
output_stream: 'output'
output_stream: 'last_output'
}
)");
CalculatorGraph graph;
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) {
output_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
std::vector<Packet> last_output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream("last_output", [&](const Packet& p) {
last_output_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add four packets into the graph.
constexpr int kNumInputs = 4;
for (int i = 0; i < kNumInputs; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(i));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(i, output_packets.size());
EXPECT_EQ(0, last_output_packets.size());
}
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(kNumInputs, output_packets.size());
EXPECT_EQ(1, last_output_packets.size());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -743,5 +743,66 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
} }
} }
typedef std::string HugeModel;
// Generates an output-side-packet once for each calculator-graph.
class OutputSidePacketCachedCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->OutputSidePackets().Index(0).Set<HugeModel>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(MakePacket<HugeModel>(
R"(An expensive side-packet created only once per graph)"));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
LOG(FATAL) << "Not reached.";
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(OutputSidePacketCachedCalculator);
// Returns true if two packets hold the same data.
bool Equals(Packet p1, Packet p2) {
return packet_internal::GetHolder(p1) == packet_internal::GetHolder(p2);
}
TEST(CalculatorGraph, OutputSidePacketCached) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
calculator: "OutputSidePacketCachedCalculator"
output_side_packet: "model"
}
node {
calculator: "SidePacketToStreamPacketCalculator"
input_side_packet: "model"
output_stream: "output"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return ::mediapipe::OkStatus();
}));
// Run the graph three times.
for (int run = 0; run < 3; ++run) {
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilDone());
}
ASSERT_EQ(3, output_packets.size());
for (int run = 0; run < output_packets.size(); ++run) {
EXPECT_TRUE(Equals(output_packets[0], output_packets[run]));
}
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -391,6 +391,38 @@ void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
namespace {
// Returns the Packet sent to an OutputSidePacket, or an empty packet
// if none available.
const Packet GetPacket(const OutputSidePacket& out) {
auto impl = dynamic_cast<const OutputSidePacketImpl*>(&out);
return (impl == nullptr) ? Packet() : impl->GetPacket();
}
// Resends the output-side-packets from the previous graph run.
::mediapipe::Status ResendSidePackets(CalculatorContext* cc) {
auto& outs = cc->OutputSidePackets();
for (CollectionItemId id = outs.BeginId(); id < outs.EndId(); ++id) {
Packet packet = GetPacket(outs.Get(id));
if (!packet.IsEmpty()) {
// OutputSidePacket::Set re-announces the side-packet to its mirrors.
outs.Get(id).Set(packet);
}
}
return ::mediapipe::OkStatus();
}
} // namespace
bool CalculatorNode::OutputsAreConstant(CalculatorContext* cc) {
if (cc->Inputs().NumEntries() > 0 || cc->Outputs().NumEntries() > 0) {
return false;
}
if (input_side_packet_handler_.InputSidePacketsChanged()) {
return false;
}
return true;
}
::mediapipe::Status CalculatorNode::OpenNode() { ::mediapipe::Status CalculatorNode::OpenNode() {
VLOG(2) << "CalculatorNode::OpenNode() for " << DebugName(); VLOG(2) << "CalculatorNode::OpenNode() for " << DebugName();
@ -407,8 +439,9 @@ void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) {
default_context, Timestamp::Unstarted()); default_context, Timestamp::Unstarted());
::mediapipe::Status result; ::mediapipe::Status result;
if (OutputsAreConstant(default_context)) {
{ result = ResendSidePackets(default_context);
} else {
MEDIAPIPE_PROFILING(OPEN, default_context); MEDIAPIPE_PROFILING(OPEN, default_context);
LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context); LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context);
result = calculator_->Open(default_context); result = calculator_->Open(default_context);
@ -494,7 +527,10 @@ void CalculatorNode::CloseOutputStreams(OutputStreamShardSet* outputs) {
::mediapipe::Status result; ::mediapipe::Status result;
{ if (OutputsAreConstant(default_context)) {
// Do nothing.
result = ::mediapipe::OkStatus();
} else {
MEDIAPIPE_PROFILING(CLOSE, default_context); MEDIAPIPE_PROFILING(CLOSE, default_context);
LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context); LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context);
result = calculator_->Close(default_context); result = calculator_->Close(default_context);
@ -770,7 +806,10 @@ std::string CalculatorNode::DebugName() const {
VLOG(2) << "Calling Calculator::Process() for node: " << DebugName(); VLOG(2) << "Calling Calculator::Process() for node: " << DebugName();
{ if (OutputsAreConstant(calculator_context)) {
// Do nothing.
result = ::mediapipe::OkStatus();
} else {
MEDIAPIPE_PROFILING(PROCESS, calculator_context); MEDIAPIPE_PROFILING(PROCESS, calculator_context);
LegacyCalculatorSupport::Scoped<CalculatorContext> s( LegacyCalculatorSupport::Scoped<CalculatorContext> s(
calculator_context); calculator_context);

View File

@ -280,6 +280,9 @@ class CalculatorNode {
// Get a std::string describing the input streams. // Get a std::string describing the input streams.
std::string DebugInputStreamNames() const; std::string DebugInputStreamNames() const;
// Returns true if all outputs will be identical to the previous graph run.
bool OutputsAreConstant(CalculatorContext* cc);
// The calculator. // The calculator.
std::unique_ptr<CalculatorBase> calculator_; std::unique_ptr<CalculatorBase> calculator_;
// Keeps data which a Calculator subclass needs access to. // Keeps data which a Calculator subclass needs access to.

View File

@ -240,6 +240,22 @@ class Collection {
return tag_map_->EndId(tag); return tag_map_->EndId(tag);
} }
// Equal Collections contain equal mappings and equal elements.
bool operator==(const Collection<T>& other) const {
if (tag_map_->Mapping() != other.TagMap()->Mapping()) {
return false;
}
for (CollectionItemId id = BeginId(); id < EndId(); ++id) {
if (Get(id) != other.Get(id)) {
return false;
}
}
return true;
}
bool operator!=(const Collection<T>& other) const {
return !(*this == other);
}
private: private:
// An iterator which is identical to ItType** except that the // An iterator which is identical to ItType** except that the
// dereference operator (operator*) does a double dereference and // dereference operator (operator*) does a double dereference and

View File

@ -15,23 +15,25 @@
#ifndef MEDIAPIPE_FRAMEWORK_DEMANGLE_H_ #ifndef MEDIAPIPE_FRAMEWORK_DEMANGLE_H_
#define MEDIAPIPE_FRAMEWORK_DEMANGLE_H_ #define MEDIAPIPE_FRAMEWORK_DEMANGLE_H_
#ifndef MEDIAPIPE_HAS_CXA_DEMANGLE
// We only support some compilers that support __cxa_demangle. // We only support some compilers that support __cxa_demangle.
// TODO: Checks if Android NDK has fixed this issue or not. // TODO: Checks if Android NDK has fixed this issue or not.
#if defined(__ANDROID__) && (defined(__i386__) || defined(__x86_64__)) #if defined(__ANDROID__) && (defined(__i386__) || defined(__x86_64__))
#define HAS_CXA_DEMANGLE 0 #define MEDIAPIPE_HAS_CXA_DEMANGLE 0
#elif (__GNUC__ >= 4 || (__GNUC__ >= 3 && __GNUC_MINOR__ >= 4)) && \ #elif (__GNUC__ >= 4 || (__GNUC__ >= 3 && __GNUC_MINOR__ >= 4)) && \
!defined(__mips__) !defined(__mips__)
#define HAS_CXA_DEMANGLE 1 #define MEDIAPIPE_HAS_CXA_DEMANGLE 1
#elif defined(__clang__) && !defined(_MSC_VER) #elif defined(__clang__) && !defined(_MSC_VER)
#define HAS_CXA_DEMANGLE 1 #define MEDIAPIPE_HAS_CXA_DEMANGLE 1
#else #else
#define HAS_CXA_DEMANGLE 0 #define MEDIAPIPE_HAS_CXA_DEMANGLE 0
#endif
#endif #endif
#include <stdlib.h> #include <stdlib.h>
#include <string> #include <string>
#if HAS_CXA_DEMANGLE #if MEDIAPIPE_HAS_CXA_DEMANGLE
#include <cxxabi.h> #include <cxxabi.h>
#endif #endif
@ -65,7 +67,7 @@ namespace mediapipe {
inline std::string Demangle(const char* mangled) { inline std::string Demangle(const char* mangled) {
int status = 0; int status = 0;
char* demangled = nullptr; char* demangled = nullptr;
#if HAS_CXA_DEMANGLE #if MEDIAPIPE_HAS_CXA_DEMANGLE
demangled = abi::__cxa_demangle(mangled, nullptr, nullptr, &status); demangled = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
#endif #endif
std::string out; std::string out;

View File

@ -15,10 +15,9 @@
# Description: # Description:
# The dependencies of mediapipe. # The dependencies of mediapipe.
licenses(["notice"]) # Apache 2.0
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_py_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"]) package(default_visibility = ["//visibility:private"])

View File

@ -66,5 +66,9 @@ message ImageFormat {
// LAB, interleaved: one byte for L, then one byte for a, then one // LAB, interleaved: one byte for L, then one byte for a, then one
// byte for b for each pixel. // byte for b for each pixel.
LAB8 = 10; LAB8 = 10;
// sBGRA, interleaved: one byte for B, one byte for G, one byte for R,
// one byte for alpha or unused. This is the N32 format for Skia.
SBGRA = 11;
} }
} }

View File

@ -279,6 +279,8 @@ int ImageFrame::NumberOfChannelsForFormat(ImageFormat::Format format) {
return 1; return 1;
case ImageFormat::LAB8: case ImageFormat::LAB8:
return 3; return 3;
case ImageFormat::SBGRA:
return 4;
default: default:
LOG(FATAL) << InvalidFormatString(format); LOG(FATAL) << InvalidFormatString(format);
} }
@ -304,6 +306,8 @@ int ImageFrame::ChannelSizeForFormat(ImageFormat::Format format) {
return sizeof(float); return sizeof(float);
case ImageFormat::LAB8: case ImageFormat::LAB8:
return sizeof(uint8); return sizeof(uint8);
case ImageFormat::SBGRA:
return sizeof(uint8);
default: default:
LOG(FATAL) << InvalidFormatString(format); LOG(FATAL) << InvalidFormatString(format);
} }
@ -329,6 +333,8 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) {
return 4; return 4;
case ImageFormat::LAB8: case ImageFormat::LAB8:
return 1; return 1;
case ImageFormat::SBGRA:
return 1;
default: default:
LOG(FATAL) << InvalidFormatString(format); LOG(FATAL) << InvalidFormatString(format);
} }

View File

@ -59,6 +59,9 @@ int GetMatType(const mediapipe::ImageFormat::Format format) {
case mediapipe::ImageFormat::LAB8: case mediapipe::ImageFormat::LAB8:
type = CV_8U; type = CV_8U;
break; break;
case mediapipe::ImageFormat::SBGRA:
type = CV_8U;
break;
default: default:
// Invalid or unknown; Default to uchar. // Invalid or unknown; Default to uchar.
type = CV_8U; type = CV_8U;

View File

@ -32,3 +32,8 @@ message NormalizedLandmark {
optional float y = 2; optional float y = 2;
optional float z = 3; optional float z = 3;
} }
// Group of NormalizedLandmark protos.
message NormalizedLandmarkList {
repeated NormalizedLandmark landmark = 1;
}

View File

@ -27,6 +27,7 @@ namespace mediapipe {
std::function<void()> input_side_packets_ready_callback, std::function<void()> input_side_packets_ready_callback,
std::function<void(::mediapipe::Status)> error_callback) { std::function<void(::mediapipe::Status)> error_callback) {
int missing_input_side_packet_count; int missing_input_side_packet_count;
prev_input_side_packets_ = std::move(input_side_packets_);
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
input_side_packets_, input_side_packets_,
tool::FillPacketSet(*input_side_packet_types, all_side_packets, tool::FillPacketSet(*input_side_packet_types, all_side_packets,
@ -41,6 +42,12 @@ namespace mediapipe {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
bool InputSidePacketHandler::InputSidePacketsChanged() {
return prev_input_side_packets_ == nullptr ||
input_side_packets_ == nullptr ||
*input_side_packets_ != *prev_input_side_packets_;
}
void InputSidePacketHandler::Set(CollectionItemId id, const Packet& packet) { void InputSidePacketHandler::Set(CollectionItemId id, const Packet& packet) {
::mediapipe::Status status = SetInternal(id, packet); ::mediapipe::Status status = SetInternal(id, packet);
if (!status.ok()) { if (!status.ok()) {

View File

@ -52,6 +52,10 @@ class InputSidePacketHandler {
const PacketSet& InputSidePackets() const { return *input_side_packets_; } const PacketSet& InputSidePackets() const { return *input_side_packets_; }
// Returns true if the set of input-side-packets has changed since the
// previous run.
bool InputSidePacketsChanged();
// Returns the number of missing input side packets. // Returns the number of missing input side packets.
int MissingInputSidePacketCount() const { int MissingInputSidePacketCount() const {
return missing_input_side_packet_count_.load(std::memory_order_relaxed); return missing_input_side_packet_count_.load(std::memory_order_relaxed);
@ -68,6 +72,7 @@ class InputSidePacketHandler {
const PacketTypeSet* input_side_packet_types_; const PacketTypeSet* input_side_packet_types_;
std::unique_ptr<PacketSet> input_side_packets_; std::unique_ptr<PacketSet> input_side_packets_;
std::unique_ptr<PacketSet> prev_input_side_packets_;
std::atomic<int> missing_input_side_packet_count_{0}; std::atomic<int> missing_input_side_packet_count_{0};

View File

@ -30,7 +30,7 @@ namespace mediapipe {
void OutputSidePacketImpl::PrepareForRun( void OutputSidePacketImpl::PrepareForRun(
std::function<void(::mediapipe::Status)> error_callback) { std::function<void(::mediapipe::Status)> error_callback) {
error_callback_ = std::move(error_callback); error_callback_ = std::move(error_callback);
packet_ = Packet(); initialized_ = false;
} }
void OutputSidePacketImpl::Set(const Packet& packet) { void OutputSidePacketImpl::Set(const Packet& packet) {
@ -47,7 +47,7 @@ void OutputSidePacketImpl::AddMirror(
} }
::mediapipe::Status OutputSidePacketImpl::SetInternal(const Packet& packet) { ::mediapipe::Status OutputSidePacketImpl::SetInternal(const Packet& packet) {
if (!packet_.IsEmpty()) { if (initialized_) {
return ::mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC) return ::mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC)
<< "Output side packet \"" << name_ << "\" was already set."; << "Output side packet \"" << name_ << "\" was already set.";
} }
@ -72,6 +72,7 @@ void OutputSidePacketImpl::AddMirror(
} }
packet_ = packet; packet_ = packet;
initialized_ = true;
for (const auto& mirror : mirrors_) { for (const auto& mirror : mirrors_) {
mirror.input_side_packet_handler->Set(mirror.id, packet_); mirror.input_side_packet_handler->Set(mirror.id, packet_);
} }

View File

@ -80,6 +80,7 @@ class OutputSidePacketImpl : public OutputSidePacket {
const PacketType* packet_type_; const PacketType* packet_type_;
std::function<void(::mediapipe::Status)> error_callback_; std::function<void(::mediapipe::Status)> error_callback_;
Packet packet_; Packet packet_;
bool initialized_ = false;
std::vector<Mirror> mirrors_; std::vector<Mirror> mirrors_;
}; };

View File

@ -653,6 +653,14 @@ Packet PointToForeign(const T* ptr) {
return packet_internal::Create(new packet_internal::ForeignHolder<T>(ptr)); return packet_internal::Create(new packet_internal::ForeignHolder<T>(ptr));
} }
// Equal Packets refer to the same memory contents, like equal pointers.
inline bool operator==(const Packet& p1, const Packet& p2) {
return packet_internal::GetHolder(p1) == packet_internal::GetHolder(p2);
}
inline bool operator!=(const Packet& p1, const Packet& p2) {
return !(p1 == p2);
}
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_PACKET_H_ #endif // MEDIAPIPE_FRAMEWORK_PACKET_H_

View File

@ -28,4 +28,22 @@
#define MEDIAPIPE_MOBILE #define MEDIAPIPE_MOBILE
#endif #endif
#if !defined(MEDIAPIPE_ANDROID) && defined(__ANDROID__)
#define MEDIAPIPE_ANDROID
#endif
#if defined(__APPLE__)
#include "TargetConditionals.h" // for TARGET_OS_*
#if !defined(MEDIAPIPE_IOS) && !TARGET_OS_OSX
#define MEDIAPIPE_IOS
#endif
#endif
// These platforms do not support OpenGL ES Compute Shaders (v3.1 and up),
// but can still run OpenGL ES 3.0 and below.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \
(defined(__APPLE__) || defined(__EMSCRIPTEN__))
#define MEDIAPIPE_DISABLE_GL_COMPUTE
#endif
#endif // MEDIAPIPE_FRAMEWORK_PORT_H_ #endif // MEDIAPIPE_FRAMEWORK_PORT_H_

View File

@ -247,25 +247,45 @@ TEST_F(GraphProfilerTestPeer, InitializeConfig) {
// Checks histogram_interval_size_usec and num_histogram_intervals. // Checks histogram_interval_size_usec and num_histogram_intervals.
CalculatorProfile actual = CalculatorProfile actual =
GetCalculatorProfilesMap()->find(kDummyTestCalculatorName)->second; GetCalculatorProfilesMap()->find(kDummyTestCalculatorName)->second;
ASSERT_EQ(actual.name(), kDummyTestCalculatorName); EXPECT_THAT(actual, EqualsProto(R"(
ASSERT_FALSE(actual.has_open_runtime()); name: "DummyTestCalculator"
ASSERT_FALSE(actual.has_close_runtime()); process_runtime {
total: 0
ASSERT_EQ(actual.process_runtime().interval_size_usec(), 1000); interval_size_usec: 1000
ASSERT_EQ(actual.process_runtime().num_intervals(), 3); num_intervals: 3
count: 0
ASSERT_EQ(actual.process_input_latency().interval_size_usec(), 1000); count: 0
ASSERT_EQ(actual.process_input_latency().num_intervals(), 3); count: 0
}
ASSERT_EQ(actual.process_output_latency().interval_size_usec(), 1000); process_input_latency {
ASSERT_EQ(actual.process_output_latency().num_intervals(), 3); total: 0
interval_size_usec: 1000
ASSERT_EQ(actual.input_stream_profiles().size(), 1); num_intervals: 3
ASSERT_EQ(actual.input_stream_profiles(0).name(), "input_stream"); count: 0
ASSERT_FALSE(actual.input_stream_profiles(0).back_edge()); count: 0
ASSERT_EQ(actual.input_stream_profiles(0).latency().interval_size_usec(), count: 0
1000); }
ASSERT_EQ(actual.input_stream_profiles(0).latency().num_intervals(), 3); process_output_latency {
total: 0
interval_size_usec: 1000
num_intervals: 3
count: 0
count: 0
count: 0
}
input_stream_profiles {
name: "input_stream"
back_edge: false
latency {
total: 0
interval_size_usec: 1000
num_intervals: 3
count: 0
count: 0
count: 0
}
}
)"));
} }
// Tests that Initialize() uses the ProfilerConfig in the graph definition. // Tests that Initialize() uses the ProfilerConfig in the graph definition.
@ -291,16 +311,17 @@ TEST_F(GraphProfilerTestPeer, InitializeConfigWithoutStreamLatency) {
// Checks histogram_interval_size_usec and num_histogram_intervals. // Checks histogram_interval_size_usec and num_histogram_intervals.
CalculatorProfile actual = CalculatorProfile actual =
GetCalculatorProfilesMap()->find(kDummyTestCalculatorName)->second; GetCalculatorProfilesMap()->find(kDummyTestCalculatorName)->second;
ASSERT_EQ(actual.name(), kDummyTestCalculatorName); EXPECT_THAT(actual, EqualsProto(R"(
ASSERT_FALSE(actual.has_open_runtime()); name: "DummyTestCalculator"
ASSERT_FALSE(actual.has_close_runtime()); process_runtime {
total: 0
ASSERT_EQ(actual.process_runtime().interval_size_usec(), 1000); interval_size_usec: 1000
ASSERT_EQ(actual.process_runtime().num_intervals(), 3); num_intervals: 3
count: 0
ASSERT_FALSE(actual.has_process_input_latency()); count: 0
ASSERT_FALSE(actual.has_process_output_latency()); count: 0
ASSERT_EQ(actual.input_stream_profiles().size(), 0); }
)"));
} }
// Tests that Initialize() reads all the configs defined in the graph // Tests that Initialize() reads all the configs defined in the graph
@ -633,10 +654,11 @@ TEST_F(GraphProfilerTestPeer, SetOpenRuntime) {
simulation_clock->ThreadFinish(); simulation_clock->ThreadFinish();
ASSERT_EQ(profiles.size(), 1); ASSERT_EQ(profiles.size(), 1);
ASSERT_EQ(profiles[0].open_runtime(), 100); EXPECT_THAT(profiles[0], Partially(EqualsProto(R"(
ASSERT_FALSE(profiles[0].has_close_runtime()); name: "DummyTestCalculator"
ASSERT_THAT(profiles[0].process_runtime(), open_runtime: 100
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); process_runtime { total: 0 }
)")));
// Checks packets_info_ map hasn't changed. // Checks packets_info_ map hasn't changed.
ASSERT_EQ(GetPacketsInfoMap()->size(), 0); ASSERT_EQ(GetPacketsInfoMap()->size(), 0);
} }
@ -688,14 +710,29 @@ TEST_F(GraphProfilerTestPeer, SetOpenRuntimeWithStreamLatency) {
ASSERT_EQ(profiles.size(), 2); ASSERT_EQ(profiles.size(), 2);
CalculatorProfile source_profile = CalculatorProfile source_profile =
GetProfileWithName(profiles, "source_calc"); GetProfileWithName(profiles, "source_calc");
ASSERT_EQ(source_profile.open_runtime(), 150);
ASSERT_FALSE(source_profile.has_close_runtime()); EXPECT_THAT(source_profile, EqualsProto(R"(
ASSERT_THAT(source_profile.process_runtime(), name: "source_calc"
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); open_runtime: 150
ASSERT_THAT(source_profile.process_input_latency(), process_runtime {
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); total: 0
ASSERT_THAT(source_profile.process_output_latency(), interval_size_usec: 1000000
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); num_intervals: 1
count: 0
}
process_input_latency {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
process_output_latency {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
)"));
// Check packets_info_ map has been updated. // Check packets_info_ map has been updated.
ASSERT_EQ(GetPacketsInfoMap()->size(), 1); ASSERT_EQ(GetPacketsInfoMap()->size(), 1);
@ -736,11 +773,16 @@ TEST_F(GraphProfilerTestPeer, SetCloseRuntime) {
std::vector<CalculatorProfile> profiles = Profiles(); std::vector<CalculatorProfile> profiles = Profiles();
simulation_clock->ThreadFinish(); simulation_clock->ThreadFinish();
ASSERT_EQ(profiles.size(), 1); EXPECT_THAT(profiles[0], EqualsProto(R"(
ASSERT_FALSE(profiles[0].open_runtime()); name: "DummyTestCalculator"
ASSERT_EQ(profiles[0].close_runtime(), 100); close_runtime: 100
ASSERT_THAT(profiles[0].process_runtime(), process_runtime {
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
)"));
} }
// Tests that SetCloseRuntime() updates |close_runtime| and doesn't affect other // Tests that SetCloseRuntime() updates |close_runtime| and doesn't affect other
@ -789,11 +831,39 @@ TEST_F(GraphProfilerTestPeer, SetCloseRuntimeWithStreamLatency) {
ASSERT_EQ(profiles.size(), 2); ASSERT_EQ(profiles.size(), 2);
CalculatorProfile source_profile = CalculatorProfile source_profile =
GetProfileWithName(profiles, "source_calc"); GetProfileWithName(profiles, "source_calc");
ASSERT_FALSE(source_profile.open_runtime());
ASSERT_EQ(source_profile.close_runtime(), 100); EXPECT_THAT(source_profile, EqualsProto(R"(
ASSERT_THAT(source_profile.process_runtime(), name: "source_calc"
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); close_runtime: 100
ASSERT_EQ(GetPacketsInfoMap()->size(), 1); process_runtime {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
process_input_latency {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
process_output_latency {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
input_stream_profiles {
name: "input_stream"
back_edge: false
latency {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 0
}
}
)"));
PacketInfo expected_packet_info = {0, PacketInfo expected_packet_info = {0,
/*production_time_usec=*/1000 + 100, /*production_time_usec=*/1000 + 100,
/*source_process_start_usec=*/1000 + 0}; /*source_process_start_usec=*/1000 + 0};
@ -933,10 +1003,15 @@ TEST_F(GraphProfilerTestPeer, AddProcessSample) {
simulation_clock->ThreadFinish(); simulation_clock->ThreadFinish();
ASSERT_EQ(profiles.size(), 1); ASSERT_EQ(profiles.size(), 1);
ASSERT_THAT(profiles[0].process_runtime(), EXPECT_THAT(profiles[0], EqualsProto(R"(
Partially(EqualsProto(CreateTimeHistogram(/*total=*/150, {1})))); name: "DummyTestCalculator"
ASSERT_FALSE(profiles[0].has_open_runtime()); process_runtime {
ASSERT_FALSE(profiles[0].has_close_runtime()); total: 150
interval_size_usec: 1000000
num_intervals: 1
count: 1
}
)"));
// Checks packets_info_ map hasn't changed. // Checks packets_info_ map hasn't changed.
ASSERT_EQ(GetPacketsInfoMap()->size(), 0); ASSERT_EQ(GetPacketsInfoMap()->size(), 0);
} }
@ -985,12 +1060,27 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) {
ASSERT_EQ(profiles.size(), 2); ASSERT_EQ(profiles.size(), 2);
CalculatorProfile source_profile = CalculatorProfile source_profile =
GetProfileWithName(profiles, "source_calc"); GetProfileWithName(profiles, "source_calc");
ASSERT_THAT(source_profile.process_runtime(),
Partially(EqualsProto(CreateTimeHistogram(/*total=*/150, {1})))); EXPECT_THAT(profiles[0], Partially(EqualsProto(R"(
ASSERT_THAT(source_profile.process_input_latency(), process_runtime {
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {1})))); total: 150
ASSERT_THAT(source_profile.process_output_latency(), interval_size_usec: 1000000
Partially(EqualsProto(CreateTimeHistogram(/*total=*/150, {1})))); num_intervals: 1
count: 1
}
process_input_latency {
total: 0
interval_size_usec: 1000000
num_intervals: 1
count: 1
}
process_output_latency {
total: 150
interval_size_usec: 1000000
num_intervals: 1
count: 1
}
)")));
// Check packets_info_ map has been updated. // Check packets_info_ map has been updated.
ASSERT_EQ(GetPacketsInfoMap()->size(), 1); ASSERT_EQ(GetPacketsInfoMap()->size(), 1);
@ -1019,22 +1109,24 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) {
CalculatorProfile consumer_profile = CalculatorProfile consumer_profile =
GetProfileWithName(profiles, "consumer_calc"); GetProfileWithName(profiles, "consumer_calc");
ASSERT_THAT(consumer_profile.process_runtime(),
Partially(EqualsProto(CreateTimeHistogram(/*total=*/250, {1})))); // process input latency total = 2000 (end) - 1000 (when source started) =
ASSERT_THAT(consumer_profile.process_input_latency(), // 1000 process output latency total = 2000 (end) + 250 - 1000 (when source
Partially(EqualsProto(CreateTimeHistogram( // started) = 1250 For "stream_0" should have not changed since it was empty.
/*total=*/2000 - when_source_started, {1})))); // For "stream_1" = 2000 (end) - 1250 (when source finished) = 850
ASSERT_THAT(consumer_profile.process_output_latency(), EXPECT_THAT(consumer_profile, Partially(EqualsProto(R"(
Partially(EqualsProto(CreateTimeHistogram( name: "consumer_calc"
/*total=*/2000 + 250 - when_source_started, {1})))); process_input_latency { total: 1000 }
ASSERT_EQ(consumer_profile.input_stream_profiles().size(), 2); process_output_latency { total: 1250 }
// For "stream_0" should have not changed since it was empty. input_stream_profiles {
ASSERT_THAT(consumer_profile.input_stream_profiles(0).latency(), name: "stream_0"
Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0})))); latency { total: 0 }
// For "stream_1" }
ASSERT_THAT(consumer_profile.input_stream_profiles(1).latency(), input_stream_profiles {
Partially(EqualsProto(CreateTimeHistogram( name: "stream_1"
/*total=*/2000 - when_source_finished, {1})))); latency { total: 850 }
}
)")));
// Check packets_info_ map for PacketId({"stream_1", 100}) should not yet be // Check packets_info_ map for PacketId({"stream_1", 100}) should not yet be
// garbage collected. // garbage collected.

View File

@ -39,9 +39,20 @@ inline const void* GetPacketDataId(const HolderBase* holder) {
struct TraceEvent { struct TraceEvent {
using EventType = GraphTrace::EventType; using EventType = GraphTrace::EventType;
// GraphTrace::EventType constants, repeated here to match GraphProfilerStub. // GraphTrace::EventType constants, repeated here to match GraphProfilerStub.
static const EventType UNKNOWN, OPEN, PROCESS, CLOSE, NOT_READY, static constexpr EventType UNKNOWN = GraphTrace::UNKNOWN;
READY_FOR_PROCESS, READY_FOR_CLOSE, THROTTLED, UNTHROTTLED, CPU_TASK_USER, static constexpr EventType OPEN = GraphTrace::OPEN;
CPU_TASK_SYSTEM, GPU_TASK, DSP_TASK, TPU_TASK; static constexpr EventType PROCESS = GraphTrace::PROCESS;
static constexpr EventType CLOSE = GraphTrace::CLOSE;
static constexpr EventType NOT_READY = GraphTrace::NOT_READY;
static constexpr EventType READY_FOR_PROCESS = GraphTrace::READY_FOR_PROCESS;
static constexpr EventType READY_FOR_CLOSE = GraphTrace::READY_FOR_CLOSE;
static constexpr EventType THROTTLED = GraphTrace::THROTTLED;
static constexpr EventType UNTHROTTLED = GraphTrace::UNTHROTTLED;
static constexpr EventType CPU_TASK_USER = GraphTrace::CPU_TASK_USER;
static constexpr EventType CPU_TASK_SYSTEM = GraphTrace::CPU_TASK_SYSTEM;
static constexpr EventType GPU_TASK = GraphTrace::GPU_TASK;
static constexpr EventType DSP_TASK = GraphTrace::DSP_TASK;
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
absl::Time event_time; absl::Time event_time;
EventType event_type = UNKNOWN; EventType event_type = UNKNOWN;
bool is_finish = false; bool is_finish = false;

View File

@ -385,21 +385,21 @@ void TraceBuilder::CreateLog(const TraceBuffer& buffer, absl::Time begin_time,
} }
void TraceBuilder::Clear() { impl_->Clear(); } void TraceBuilder::Clear() { impl_->Clear(); }
// Defined here since inline constants fail to link in android builds. // Defined here since constexpr requires out-of-class definition until C++17.
const TraceEvent::EventType // const TraceEvent::EventType //
TraceEvent::UNKNOWN = GraphTrace::UNKNOWN, TraceEvent::UNKNOWN, //
TraceEvent::OPEN = GraphTrace::OPEN, TraceEvent::OPEN, //
TraceEvent::PROCESS = GraphTrace::PROCESS, TraceEvent::PROCESS, //
TraceEvent::CLOSE = GraphTrace::CLOSE, TraceEvent::CLOSE, //
TraceEvent::NOT_READY = GraphTrace::NOT_READY, TraceEvent::NOT_READY, //
TraceEvent::READY_FOR_PROCESS = GraphTrace::READY_FOR_PROCESS, TraceEvent::READY_FOR_PROCESS, //
TraceEvent::READY_FOR_CLOSE = GraphTrace::READY_FOR_CLOSE, TraceEvent::READY_FOR_CLOSE, //
TraceEvent::THROTTLED = GraphTrace::THROTTLED, TraceEvent::THROTTLED, //
TraceEvent::UNTHROTTLED = GraphTrace::UNTHROTTLED, TraceEvent::UNTHROTTLED, //
TraceEvent::CPU_TASK_USER = GraphTrace::CPU_TASK_USER, TraceEvent::CPU_TASK_USER, //
TraceEvent::CPU_TASK_SYSTEM = GraphTrace::CPU_TASK_SYSTEM, TraceEvent::CPU_TASK_SYSTEM, //
TraceEvent::GPU_TASK = GraphTrace::GPU_TASK, TraceEvent::GPU_TASK, //
TraceEvent::DSP_TASK = GraphTrace::DSP_TASK, TraceEvent::DSP_TASK, //
TraceEvent::TPU_TASK = GraphTrace::TPU_TASK; TraceEvent::TPU_TASK;
} // namespace mediapipe } // namespace mediapipe

View File

@ -127,6 +127,11 @@ class TagMap {
std::vector<std::string> names_; std::vector<std::string> names_;
}; };
// Equal TagData structs define equal id ranges.
inline bool operator==(const TagMap::TagData& d1, const TagMap::TagData& d2) {
return d1.id == d2.id && d1.count == d2.count;
}
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe

View File

@ -567,6 +567,10 @@ class TemplateExpanderImpl {
result = AsDict(args); result = AsDict(args);
} else if (expr.op() == "list") { } else if (expr.op() == "list") {
result = AsList(args); result = AsList(args);
} else if (expr.op() == "size") {
return AsArgument(static_cast<double>(
args[0].has_dict() ? args[0].mutable_dict()->arg_size()
: args[0].mutable_element()->size()));
} }
return result; return result;
} }

View File

@ -1318,8 +1318,8 @@ bool IsInfixOperator(const std::string& token) {
// A function-style operator, including a for or if expression. // A function-style operator, including a for or if expression.
bool IsFunctionOperator(const std::string& token) { bool IsFunctionOperator(const std::string& token) {
static auto kTokens = new std::set<std::string>{ static auto kTokens = new std::set<std::string>{
"min", "max", "for", "if", "!", "min", "max", "for", "if", "!", "concat",
"concat", "lowercase", "uppercase", "dict", "list", "lowercase", "uppercase", "size", "dict", "list",
}; };
return kTokens->count(token) > 0; return kTokens->count(token) > 0;
} }

View File

@ -101,6 +101,10 @@ static const GLfloat kBasicTextureVertices[] = {
1.0f, 1.0f, // top right 1.0f, 1.0f, // top right
}; };
// Places a texture on kBasicSquareVertices, flipped horizontally.
static const GLfloat kBasicTextureVerticesFlipX[] = {
V4(kBasicTextureVertices, 1, 0, 3, 2)};
// Places a texture on kBasicSquareVertices, flipped vertically. // Places a texture on kBasicSquareVertices, flipped vertically.
static const GLfloat kBasicTextureVerticesFlipY[] = { static const GLfloat kBasicTextureVerticesFlipY[] = {
V4(kBasicTextureVertices, 2, 3, 0, 1)}; V4(kBasicTextureVertices, 2, 3, 0, 1)};

View File

@ -44,3 +44,30 @@ cc_library(
"//mediapipe/calculators/video:opencv_video_decoder_calculator", "//mediapipe/calculators/video:opencv_video_decoder_calculator",
], ],
) )
cc_library(
name = "yt8m_inference_calculators_deps",
deps = [
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:dequantize_byte_array_calculator",
"//mediapipe/calculators/core:packet_cloner_calculator",
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
"//mediapipe/calculators/core:string_to_int_calculator",
"//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator",
"//mediapipe/calculators/tensorflow:string_to_sequence_example_calculator",
"//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator",
"//mediapipe/calculators/tensorflow:tensorflow_inference_calculator",
"//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_calculator",
"//mediapipe/calculators/tensorflow:tfrecord_reader_calculator",
"//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator",
"//mediapipe/calculators/tensorflow:unpack_yt8m_sequence_example_calculator",
"//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator",
"//mediapipe/calculators/tensorflow:vector_int_to_tensor_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:labels_to_render_data_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/calculators/util:top_k_scores_calculator",
"//mediapipe/calculators/video:opencv_video_decoder_calculator",
"//mediapipe/calculators/video:opencv_video_encoder_calculator",
],
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,178 @@
input_side_packet: "input_sequence_example_path"
input_side_packet: "input_video_path"
input_side_packet: "output_video_path"
input_side_packet: "segment_size"
input_side_packet: "overlap"
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:input_sequence_example_path"
output_side_packet: "CONTENTS:input_sequence_example"
}
node {
calculator: "StringToSequenceExampleCalculator"
input_side_packet: "STRING:input_sequence_example"
output_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example"
}
node {
calculator: "UnpackMediaSequenceCalculator"
input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example"
output_stream: "FLOAT_FEATURE_RGB:rgb_feature_vector"
output_stream: "FLOAT_FEATURE_AUDIO:audio_feature_vector"
}
node {
calculator: "ConcatenateFloatVectorCalculator"
input_stream: "rgb_feature_vector"
input_stream: "audio_feature_vector"
output_stream: "feature_vector"
}
node {
calculator: "VectorFloatToTensorCalculator"
input_stream: "feature_vector"
output_stream: "feature_tensor"
}
node {
calculator: "StringToInt32Calculator"
input_side_packet: "segment_size"
output_side_packet: "segment_size_int"
}
node {
calculator: "StringToInt32Calculator"
input_side_packet: "overlap"
output_side_packet: "overlap_int"
}
node {
calculator: "LappedTensorBufferCalculator"
input_stream: "feature_tensor"
output_stream: "lapped_feature_tensor"
input_side_packet: "BUFFER_SIZE:segment_size_int"
input_side_packet: "OVERLAP:overlap_int"
node_options: {
[type.googleapis.com/mediapipe.LappedTensorBufferCalculatorOptions] {
add_batch_dim_to_tensors: true
}
}
}
node {
calculator: "SidePacketToStreamCalculator"
input_side_packet: "segment_size_int"
output_stream: "AT_ZERO:segment_size_int_stream"
}
node {
calculator: "VectorIntToTensorCalculator"
input_stream: "SINGLE_INT:segment_size_int_stream"
output_stream: "TENSOR_OUT:segment_size_tensor"
}
node {
calculator: "PacketClonerCalculator"
input_stream: "segment_size_tensor"
input_stream: "lapped_feature_tensor"
output_stream: "synced_segment_size_tensor"
}
node {
calculator: "TensorFlowSessionFromSavedModelCalculator"
output_side_packet: "SESSION:session"
node_options: {
[type.googleapis.com/mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions]: {
saved_model_path: "/tmp/mediapipe/saved_model"
}
}
}
node: {
calculator: "TensorFlowInferenceCalculator"
input_side_packet: "SESSION:session"
input_stream: "NUM_FRAMES:synced_segment_size_tensor"
input_stream: "RGB_AND_AUDIO:lapped_feature_tensor"
output_stream: "PREDICTIONS:prediction_tensor"
node_options: {
[type.googleapis.com/mediapipe.TensorFlowInferenceCalculatorOptions]: {
batch_size: 32
}
}
}
node {
calculator: "TensorToVectorFloatCalculator"
input_stream: "prediction_tensor"
output_stream: "prediction_vector"
}
node {
calculator: "TopKScoresCalculator"
input_stream: "SCORES:prediction_vector"
output_stream: "TOP_K_INDEXES:top_k_indexes"
output_stream: "TOP_K_SCORES:top_k_scores"
output_stream: "TOP_K_LABELS:top_k_labels"
node_options: {
[type.googleapis.com/mediapipe.TopKScoresCalculatorOptions]: {
top_k: 3
label_map_path: "mediapipe/graphs/youtube8m/label_map.txt"
}
}
}
node {
calculator: "OpenCvVideoDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_video_path"
output_stream: "VIDEO:input_video"
output_stream: "VIDEO_PRESTREAM:input_video_header"
}
node {
calculator: "LabelsToRenderDataCalculator"
input_stream: "LABELS:top_k_labels"
input_stream: "SCORES:top_k_scores"
input_stream: "VIDEO_PRESTREAM:input_video_header"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.LabelsToRenderDataCalculatorOptions]: {
color { r: 255 g: 0 b: 0 }
color { r: 0 g: 255 b: 0 }
color { r: 0 g: 0 b: 255 }
thickness: 2.0
font_height_px: 20
max_num_labels: 3
location: TOP_LEFT
}
}
}
node {
calculator: "PacketClonerCalculator"
input_stream: "render_data"
input_stream: "input_video"
output_stream: "synchronized_render_data"
}
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME:input_video"
input_stream: "synchronized_render_data"
output_stream: "OUTPUT_FRAME:output_video"
}
node {
calculator: "OpenCvVideoEncoderCalculator"
input_stream: "VIDEO:output_video"
input_stream: "VIDEO_PRESTREAM:input_video_header"
input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
node_options: {
[type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: {
codec: "avc1"
video_format: "mp4"
}
}
}

View File

@ -0,0 +1,139 @@
input_side_packet: "desired_segment_size"
input_side_packet: "record_index"
input_side_packet: "tfrecord_path"
output_side_packet: "yt8m_id"
output_stream: "annotation_summary"
node {
calculator: "StringToInt32Calculator"
input_side_packet: "record_index"
output_side_packet: "record_index_int"
}
node {
calculator: "StringToInt32Calculator"
input_side_packet: "desired_segment_size"
output_side_packet: "desired_segment_size_int"
}
node {
calculator: "TFRecordReaderCalculator"
input_side_packet: "TFRECORD_PATH:tfrecord_path"
input_side_packet: "RECORD_INDEX:record_index_int"
output_side_packet: "SEQUENCE_EXAMPLE:yt8m_sequence_example"
}
node {
calculator: "UnpackYt8mSequenceExampleCalculator"
input_side_packet: "YT8M_SEQUENCE_EXAMPLE:yt8m_sequence_example"
input_side_packet: "DESIRED_SEGMENT_SIZE:desired_segment_size_int"
output_side_packet: "YT8M_ID:yt8m_id"
output_side_packet: "SEGMENT_SIZE:segment_size"
output_side_packet: "LAPPED_TENSOR_BUFFER_CALCULATOR_OPTIONS:lapped_tensor_buffer_calculator_options"
output_stream: "QUANTIZED_RGB_FEATURE:quantized_rgb_feature"
output_stream: "QUANTIZED_AUDIO_FEATURE:quantized_audio_feature"
}
node {
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:quantized_rgb_feature"
output_stream: "FLOAT_VECTOR:rgb_feature_vector"
node_options: {
[type.googleapis.com/mediapipe.DequantizeByteArrayCalculatorOptions]: {
max_quantized_value: 2
min_quantized_value: -2
}
}
}
node {
calculator: "DequantizeByteArrayCalculator"
input_stream: "ENCODED:quantized_audio_feature"
output_stream: "FLOAT_VECTOR:audio_feature_vector"
node_options: {
[type.googleapis.com/mediapipe.DequantizeByteArrayCalculatorOptions]: {
max_quantized_value: 2
min_quantized_value: -2
}
}
}
node {
calculator: "ConcatenateFloatVectorCalculator"
input_stream: "rgb_feature_vector"
input_stream: "audio_feature_vector"
output_stream: "feature_vector"
}
node {
calculator: "VectorFloatToTensorCalculator"
input_stream: "feature_vector"
output_stream: "feature_tensor"
}
node {
calculator: "LappedTensorBufferCalculator"
input_stream: "feature_tensor"
input_side_packet: "CALCULATOR_OPTIONS:lapped_tensor_buffer_calculator_options"
output_stream: "lapped_feature_tensor"
}
node {
calculator: "SidePacketToStreamCalculator"
input_side_packet: "segment_size"
output_stream: "AT_ZERO:segment_size_int_stream"
}
node {
calculator: "VectorIntToTensorCalculator"
input_stream: "SINGLE_INT:segment_size_int_stream"
output_stream: "TENSOR_OUT:segment_size_tensor"
}
node {
calculator: "PacketClonerCalculator"
input_stream: "segment_size_tensor"
input_stream: "lapped_feature_tensor"
output_stream: "synced_segment_size_tensor"
}
node {
calculator: "TensorFlowSessionFromSavedModelCalculator"
output_side_packet: "SESSION:session"
node_options: {
[type.googleapis.com/mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions]: {
saved_model_path: "/tmp/mediapipe/saved_model"
}
}
}
node: {
calculator: "TensorFlowInferenceCalculator"
input_side_packet: "SESSION:session"
input_stream: "NUM_FRAMES:synced_segment_size_tensor"
input_stream: "RGB_AND_AUDIO:lapped_feature_tensor"
output_stream: "PREDICTIONS:prediction_tensor"
node_options: {
[type.googleapis.com/mediapipe.TensorFlowInferenceCalculatorOptions]: {
batch_size: 32
}
}
}
node {
calculator: "TensorToVectorFloatCalculator"
input_stream: "prediction_tensor"
output_stream: "prediction_vector"
}
node {
calculator: "TopKScoresCalculator"
input_stream: "SCORES:prediction_vector"
output_stream: "SUMMARY:annotation_summary"
node_options: {
[type.googleapis.com/mediapipe.TopKScoresCalculatorOptions]: {
top_k: 9
label_map_path: "mediapipe/graphs/youtube8m/label_map.txt"
}
}
}

View File

@ -0,0 +1,15 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:#www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0

View File

@ -68,3 +68,10 @@ android_library(
"@com_google_guava_android//jar", "@com_google_guava_android//jar",
], ],
) )
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -150,6 +150,7 @@ public class ExternalTextureConverter implements TextureFrameProducer {
private ExternalTextureRenderer renderer = null; private ExternalTextureRenderer renderer = null;
private long timestampOffset = 0; private long timestampOffset = 0;
private long previousTimestamp = 0; private long previousTimestamp = 0;
private boolean previousTimestampValid = false;
protected int destinationWidth = 0; protected int destinationWidth = 0;
protected int destinationHeight = 0; protected int destinationHeight = 0;
@ -335,11 +336,12 @@ public class ExternalTextureConverter implements TextureFrameProducer {
// ensures that surface texture has the up-to-date timestamp. (Also adjust |timestampOffset| // ensures that surface texture has the up-to-date timestamp. (Also adjust |timestampOffset|
// to ensure that timestamps increase monotonically.) // to ensure that timestamps increase monotonically.)
long textureTimestamp = surfaceTexture.getTimestamp() / NANOS_PER_MICRO; long textureTimestamp = surfaceTexture.getTimestamp() / NANOS_PER_MICRO;
if (textureTimestamp + timestampOffset <= previousTimestamp) { if (previousTimestampValid && textureTimestamp + timestampOffset <= previousTimestamp) {
timestampOffset = previousTimestamp + 1 - textureTimestamp; timestampOffset = previousTimestamp + 1 - textureTimestamp;
} }
outputFrame.setTimestamp(textureTimestamp + timestampOffset); outputFrame.setTimestamp(textureTimestamp + timestampOffset);
previousTimestamp = outputFrame.getTimestamp(); previousTimestamp = outputFrame.getTimestamp();
previousTimestampValid = true;
} }
private void waitUntilReleased(AppTextureFrame frame) { private void waitUntilReleased(AppTextureFrame frame) {

View File

@ -82,3 +82,10 @@ android_library(
"@com_google_guava_android//jar", "@com_google_guava_android//jar",
], ],
) )
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -30,3 +30,10 @@ android_library(
"@com_google_guava_android//jar", "@com_google_guava_android//jar",
], ],
) )
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["**/*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -0,0 +1,157 @@
"""Generate MediaPipe AAR including different variants of .so in jni folder.
Usage:
Create a new mediapipe_aar() target in a BUILD file. For example,
putting the following code into mediapipe/examples/android/aar_demo/BUILD.
```
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
mediapipe_aar(
name = "my_aar",
calculators = ["//mediapipe/calculators/core:pass_through_calculator"],
)
```
Then, run the following Bazel command to generate the AAR.
```
$ bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a mediapipe/examples/android/aar_demo:my_aar
```
Finally, import the AAR into Android Studio.
"""
load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library")
def mediapipe_aar(name, calculators = []):
"""Generate MediaPipe AAR.
Args:
name: the name of the AAR.
calculators: the calculator libraries to be compiled into the .so.
"""
native.cc_binary(
name = "libmediapipe_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
] + calculators,
)
native.cc_library(
name = name + "_mediapipe_jni_lib",
srcs = [":libmediapipe_jni.so"],
alwayslink = 1,
)
native.genrule(
name = name + "_aar_manifest_generator",
outs = ["AndroidManifest.xml"],
cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe">
<uses-sdk
android:minSdkVersion="21"
android:targetSdkVersion="27" />
<application />
</manifest>
""",
)
native.genrule(
name = name + "_calculator_proto_java_src_generator",
srcs = [
"//mediapipe/framework:protos_src",
"@com_google_protobuf_javalite//:well_known_protos",
],
outs = ["CalculatorProto.java"],
cmd = "$(location @com_google_protobuf_javalite//:protoc) " +
"--plugin=protoc-gen-javalite=$(location @com_google_protobuf_javalite//:protoc_gen_javalite) " +
"--proto_path=. --proto_path=$(GENDIR) " +
"--proto_path=$$(pwd)/external/com_google_protobuf_javalite/src " +
"--javalite_out=$$(dirname $(location CalculatorProto.java)) mediapipe/framework/calculator.proto && " +
"mv $$(dirname $(location CalculatorProto.java))/com/google/mediapipe/proto/CalculatorProto.java $$(dirname $(location CalculatorProto.java))",
tools = [
"@com_google_protobuf_javalite//:protoc",
"@com_google_protobuf_javalite//:protoc_gen_javalite",
],
)
android_library(
name = name + "_android_lib",
srcs = [
"//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"CalculatorProto.java",
],
manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
deps = [
":" + name + "_mediapipe_jni_lib",
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework:calculator_profile_java_proto_lite",
"//mediapipe/framework/tool:calculator_graph_template_java_proto_lite",
"//third_party:androidx_annotation",
"//third_party:androidx_appcompat",
"//third_party:androidx_core",
"//third_party:androidx_legacy_support_v4",
"//third_party:camerax_core",
"//third_party:camera2",
"@com_google_code_findbugs//jar",
"@com_google_common_flogger//jar",
"@com_google_common_flogger_system_backend//jar",
"@com_google_guava_android//jar",
"@androidx_lifecycle//jar",
],
)
_aar_with_jni(name, name + "_android_lib")
def _aar_with_jni(name, android_library):
# Generate dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app target below)
native.genrule(
name = name + "_binary_manifest_generator",
outs = [name + "_generated_AndroidManifest.xml"],
cmd = """
cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
package="dummy.package.for.so">
<uses-sdk android:minSdkVersion="21"/>
</manifest>
EOF
""",
)
# Generate dummy apk including .so files.
# We extract out .so files and throw away the apk.
android_binary(
name = name + "_dummy_app",
manifest = name + "_generated_AndroidManifest.xml",
custom_package = "dummy.package.for.so",
deps = [android_library],
)
native.genrule(
name = name,
srcs = [android_library + ".aar", name + "_dummy_app_unsigned.apk"],
outs = [name + ".aar"],
tags = ["manual"],
cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
cd $$(mktemp -d)
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
)

View File

@ -466,6 +466,7 @@ tasks and tracking (or class) fields for tracking information.
|-----|------|------------------------|-------------| |-----|------|------------------------|-------------|
|`CLASS_SEGMENTATION/image/encoded`|feature list bytes|`add_class_segmentation_encoded` / `AddClassSegmentationEncoded`|The encoded image of class labels at each timestep.| |`CLASS_SEGMENTATION/image/encoded`|feature list bytes|`add_class_segmentation_encoded` / `AddClassSegmentationEncoded`|The encoded image of class labels at each timestep.|
|`CLASS_SEGMENTATION/image/timestamp`|feature list int|`add_class_segmentation_timestamp` / `AddClassSegmentationTimestamp`|The timestamp in microseconds for the class labels.| |`CLASS_SEGMENTATION/image/timestamp`|feature list int|`add_class_segmentation_timestamp` / `AddClassSegmentationTimestamp`|The timestamp in microseconds for the class labels.|
|`CLASS_SEGMENTATION/image/multi_encoded`|feature list bytes list|`add_class_segmentation_multi_encoded` / `AddClassSegmentationMultiEncoded`|Storing multiple segmentation masks in case they overlap.|
|`CLASS_SEGMENTATION/image/format`|context bytes|`set_class_segmentation_format` / `SetClassSegmentationFormat`|The encoding format of the class label images.| |`CLASS_SEGMENTATION/image/format`|context bytes|`set_class_segmentation_format` / `SetClassSegmentationFormat`|The encoding format of the class label images.|
|`CLASS_SEGMENTATION/image/height`|context int|`set_class_segmentation_height` / `SetClassSegmentationHeight`|The height of the image in pixels.| |`CLASS_SEGMENTATION/image/height`|context int|`set_class_segmentation_height` / `SetClassSegmentationHeight`|The height of the image in pixels.|
|`CLASS_SEGMENTATION/image/width`|context int|`set_class_segmentation_width` / `SetClassSegmentationWidth`|The width of the image in pixels.| |`CLASS_SEGMENTATION/image/width`|context int|`set_class_segmentation_width` / `SetClassSegmentationWidth`|The width of the image in pixels.|
@ -477,6 +478,7 @@ tasks and tracking (or class) fields for tracking information.
|-----|------|------------------------|-------------| |-----|------|------------------------|-------------|
|`INSTANCE_SEGMENTATION/image/ encoded`|feature list bytes|`add_instance_segmentation_encoded` / `AddInstanceSegmentationEncoded`|The encoded image of object instance labels at each timestep.| |`INSTANCE_SEGMENTATION/image/ encoded`|feature list bytes|`add_instance_segmentation_encoded` / `AddInstanceSegmentationEncoded`|The encoded image of object instance labels at each timestep.|
|`INSTANCE_SEGMENTATION/image/ timestamp`|feature list int|`add_instance_segmentation_timestamp` / `AddInstanceSegmentationTimestamp`|The timestamp in microseconds for the object instance labels.| |`INSTANCE_SEGMENTATION/image/ timestamp`|feature list int|`add_instance_segmentation_timestamp` / `AddInstanceSegmentationTimestamp`|The timestamp in microseconds for the object instance labels.|
|`INSTANCE_SEGMENTATION/image/multi_encoded`|feature list bytes list|`add_instance_segmentation_multi_encoded` / `AddInstanceSegmentationEncoded`|Storing multiple segmentation masks in case they overlap.|
|`INSTANCE_SEGMENTATION/image/ format`|context bytes|`set_instance_segmentation_format` / `SetInstanceSegmentationFormat`|The encoding format of the object instance labels.| |`INSTANCE_SEGMENTATION/image/ format`|context bytes|`set_instance_segmentation_format` / `SetInstanceSegmentationFormat`|The encoding format of the object instance labels.|
|`INSTANCE_SEGMENTATION/image/ height`|context int|`set_instance_segmentation_height` / `SetInstanceSegmentationHeight`|The height of the image in pixels.| |`INSTANCE_SEGMENTATION/image/ height`|context int|`set_instance_segmentation_height` / `SetInstanceSegmentationHeight`|The height of the image in pixels.|
|`INSTANCE_SEGMENTATION/image/ width`|context int|`set_instance_segmentation_width` / `SetInstanceSegmentationWidth`|The width of the image in pixels.| |`INSTANCE_SEGMENTATION/image/ width`|context int|`set_instance_segmentation_width` / `SetInstanceSegmentationWidth`|The width of the image in pixels.|

View File

@ -489,7 +489,9 @@ def _create_image_with_prefix(name, prefix):
prefix=prefix, module_dict=globals()) prefix=prefix, module_dict=globals())
msu.create_int_feature_list(name + "_timestamp", IMAGE_TIMESTAMP_KEY, msu.create_int_feature_list(name + "_timestamp", IMAGE_TIMESTAMP_KEY,
prefix=prefix, module_dict=globals()) prefix=prefix, module_dict=globals())
msu.create_bytes_list_feature_list(name + "_multi_encoded",
IMAGE_MULTI_ENCODED_KEY, prefix=prefix,
module_dict=globals())
FORWARD_FLOW_PREFIX = "FORWARD_FLOW" FORWARD_FLOW_PREFIX = "FORWARD_FLOW"
CLASS_SEGMENTATION_PREFIX = "CLASS_SEGMENTATION" CLASS_SEGMENTATION_PREFIX = "CLASS_SEGMENTATION"
INSTANCE_SEGMENTATION_PREFIX = "INSTANCE_SEGMENTATION" INSTANCE_SEGMENTATION_PREFIX = "INSTANCE_SEGMENTATION"

View File

@ -78,8 +78,10 @@ class MediaSequenceTest(tf.test.TestCase):
ms.set_bbox_parts((b"HEAD", b"TOE"), example) ms.set_bbox_parts((b"HEAD", b"TOE"), example)
# feature lists # feature lists
ms.add_image_encoded(b"test", example) ms.add_image_encoded(b"test", example)
ms.add_image_multi_encoded([b"test", b"test"], example)
ms.add_image_timestamp(47, example) ms.add_image_timestamp(47, example)
ms.add_forward_flow_encoded(b"test", example) ms.add_forward_flow_encoded(b"test", example)
ms.add_forward_flow_multi_encoded([b"test", b"test"], example)
ms.add_forward_flow_timestamp(47, example) ms.add_forward_flow_timestamp(47, example)
ms.add_bbox_ymin((0.47, 0.49), example) ms.add_bbox_ymin((0.47, 0.49), example)
ms.add_bbox_xmin((0.47, 0.49), example) ms.add_bbox_xmin((0.47, 0.49), example)
@ -109,7 +111,9 @@ class MediaSequenceTest(tf.test.TestCase):
ms.add_predicted_bbox_class_string((b"test", b"strings"), example) ms.add_predicted_bbox_class_string((b"test", b"strings"), example)
ms.add_predicted_bbox_timestamp(47, example) ms.add_predicted_bbox_timestamp(47, example)
ms.add_class_segmentation_encoded(b"test", example) ms.add_class_segmentation_encoded(b"test", example)
ms.add_class_segmentation_multi_encoded([b"test", b"test"], example)
ms.add_instance_segmentation_encoded(b"test", example) ms.add_instance_segmentation_encoded(b"test", example)
ms.add_instance_segmentation_multi_encoded([b"test", b"test"], example)
ms.add_class_segmentation_timestamp(47, example) ms.add_class_segmentation_timestamp(47, example)
ms.set_bbox_embedding_dimensions_per_region((47, 49), example) ms.set_bbox_embedding_dimensions_per_region((47, 49), example)
ms.set_bbox_embedding_format(b"test", example) ms.set_bbox_embedding_format(b"test", example)