Project import generated by Copybara.
GitOrigin-RevId: b137378673f7d66d41bcd46e4fc3a0d9ef254894
Parent: a2a63e3876
Commit: 259b48e082
@@ -37,10 +37,15 @@ A web-based visualizer is hosted on [viz.mediapipe.dev](https://viz.mediapipe.dev)
 * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe

 ## Publications
 * [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
 * [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172)

 ## Events
-[Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17-20 in Long Beach, CA
+* [ML Conference, Berlin, 9-11 Dec 2019](https://mlconference.ai/machine-learning-advanced-development/mediapipe-building-real-time-cross-platform-mobile-web-edge-desktop-video-audio-ml-pipelines/)
+* [The 3rd Workshop on YouTube-8M Large-Scale Video Understanding](https://research.google.com/youtube8m/workshop2019/index.html), Seoul, Korea, ICCV 2019
+* [AI DevWorld 2019](https://aidevworld.com) on Oct 10 in San Jose, California
+* [Google Industry Workshop at ICIP 2019](http://2019.ieeeicip.org/?action=page4&id=14#Google) [Presentation](https://docs.google.com/presentation/d/e/2PACX-1vRIBBbO_LO9v2YmvbHHEt1cwyqH6EjDxiILjuT0foXy1E7g6uyh4CesB2DkkEwlRDO9_lWfuKMZx98T/pub?start=false&loop=false&delayms=3000&slide=id.g556cc1a659_0_5) on Sept 24 in Taipei, Taiwan
+* [Open sourced at CVPR 2019](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) on June 17-20 in Long Beach, CA

 ## Alpha Disclaimer
 MediaPipe is currently in alpha for v0.6. We are still making breaking API changes and expect to reach a stable API by v1.0.
@@ -10,7 +10,8 @@ http_archive(
     sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
 )
 load("@bazel_skylib//lib:versions.bzl", "versions")
-versions.check(minimum_bazel_version = "0.24.1")
+versions.check(minimum_bazel_version = "0.24.1",
+               maximum_bazel_version = "0.29.1")

 # ABSL cpp library.
 http_archive(
@@ -26,6 +26,13 @@ proto_library(
    deps = ["//mediapipe/framework:calculator_proto"],
)

proto_library(
    name = "dequantize_byte_array_calculator_proto",
    srcs = ["dequantize_byte_array_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = ["//mediapipe/framework:calculator_proto"],
)

proto_library(
    name = "packet_cloner_calculator_proto",
    srcs = ["packet_cloner_calculator.proto"],
@@ -104,6 +111,14 @@ mediapipe_cc_proto_library(
    deps = [":concatenate_vector_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "dequantize_byte_array_calculator_cc_proto",
    srcs = ["dequantize_byte_array_calculator.proto"],
    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
    visibility = ["//visibility:public"],
    deps = [":dequantize_byte_array_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "quantize_float_vector_calculator_cc_proto",
    srcs = ["quantize_float_vector_calculator.proto"],
@@ -387,6 +402,32 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "string_to_int_calculator",
    srcs = ["string_to_int_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_library(
    name = "side_packet_to_stream_calculator",
    srcs = ["side_packet_to_stream_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

cc_test(
    name = "immediate_mux_calculator_test",
    srcs = ["immediate_mux_calculator_test.cc"],
@@ -558,6 +599,32 @@ cc_test(
    ],
)

cc_library(
    name = "dequantize_byte_array_calculator",
    srcs = ["dequantize_byte_array_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":dequantize_byte_array_calculator_cc_proto",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

cc_test(
    name = "dequantize_byte_array_calculator_test",
    srcs = ["dequantize_byte_array_calculator_test.cc"],
    deps = [
        ":dequantize_byte_array_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
    ],
)

cc_library(
    name = "quantize_float_vector_calculator",
    srcs = ["quantize_float_vector_calculator.cc"],
@@ -0,0 +1,90 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cfloat>

#include "mediapipe/calculators/core/dequantize_byte_array_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"

// Dequantizes a byte array to a vector of floats.
//
// Example config:
// node {
//   calculator: "DequantizeByteArrayCalculator"
//   input_stream: "ENCODED:encoded"
//   output_stream: "FLOAT_VECTOR:float_vector"
//   options {
//     [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
//       max_quantized_value: 2
//       min_quantized_value: -2
//     }
//   }
// }
namespace mediapipe {

class DequantizeByteArrayCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag("ENCODED").Set<std::string>();
    cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    const auto options =
        cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>();
    if (!options.has_max_quantized_value() ||
        !options.has_min_quantized_value()) {
      return ::mediapipe::InvalidArgumentError(
          "Both max_quantized_value and min_quantized_value must be provided "
          "in DequantizeByteArrayCalculatorOptions.");
    }
    float max_quantized_value = options.max_quantized_value();
    float min_quantized_value = options.min_quantized_value();
    if (max_quantized_value < min_quantized_value + FLT_EPSILON) {
      return ::mediapipe::InvalidArgumentError(
          "max_quantized_value must be greater than min_quantized_value.");
    }
    float range = max_quantized_value - min_quantized_value;
    scalar_ = range / 255.0;
    bias_ = (range / 512.0) + min_quantized_value;
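    // A byte value b decodes to b * scalar_ + bias_. For example, with
    // max_quantized_value = 2 and min_quantized_value = -2 the range is 4,
    // so scalar_ = 4 / 255 (about 0.0157) and bias_ = 4 / 512 - 2 (about
    // -1.992): byte 0 decodes to about -1.992 and byte 255 to about 2.008.
    // The extra range / 512 half-step centers each decoded value within its
    // quantization bucket instead of pinning it to the interval endpoints.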
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    const std::string& encoded =
        cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
    std::vector<float> float_vector;
    float_vector.reserve(encoded.length());
    for (int i = 0; i < encoded.length(); ++i) {
      float_vector.push_back(
          static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
    }
    cc->Outputs()
        .Tag("FLOAT_VECTOR")
        .AddPacket(MakePacket<std::vector<float>>(float_vector)
                       .At(cc->InputTimestamp()));
    return ::mediapipe::OkStatus();
  }

 private:
  float scalar_;
  float bias_;
};

REGISTER_CALCULATOR(DequantizeByteArrayCalculator);

}  // namespace mediapipe
@@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message DequantizeByteArrayCalculatorOptions {
  extend CalculatorOptions {
    optional DequantizeByteArrayCalculatorOptions ext = 272316343;
  }

  optional float max_quantized_value = 1;
  optional float min_quantized_value = 2;
}
@@ -0,0 +1,137 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT

namespace mediapipe {

TEST(DequantizeByteArrayCalculatorTest, WrongConfig) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: 2
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::string empty_string;
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(empty_string).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "Both max_quantized_value and min_quantized_value must be provided"));
}

TEST(DequantizeByteArrayCalculatorTest, WrongConfig2) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: -2
            min_quantized_value: 2
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::string empty_string;
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(empty_string).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}

TEST(DequantizeByteArrayCalculatorTest, WrongConfig3) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: 1
            min_quantized_value: 1
          }
        }
      )");
  CalculatorRunner runner(node_config);
  std::string empty_string;
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(empty_string).At(Timestamp(0)));
  auto status = runner.Run();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr(
          "max_quantized_value must be greater than min_quantized_value"));
}

TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "DequantizeByteArrayCalculator"
        input_stream: "ENCODED:encoded"
        output_stream: "FLOAT_VECTOR:float_vector"
        options {
          [mediapipe.DequantizeByteArrayCalculatorOptions.ext]: {
            max_quantized_value: 2
            min_quantized_value: -2
          }
        }
      )");
  CalculatorRunner runner(node_config);
  unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
  runner.MutableInputs()->Tag("ENCODED").packets.push_back(
      MakePacket<std::string>(
          std::string(reinterpret_cast<char const*>(input), 4))
          .At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  const std::vector<Packet>& outputs =
      runner.Outputs().Tag("FLOAT_VECTOR").packets;
  EXPECT_EQ(1, outputs.size());
  const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
  ASSERT_FALSE(result.empty());
  EXPECT_EQ(4, result.size());
  EXPECT_NEAR(0, result[0], 0.01);
  EXPECT_NEAR(2, result[1], 0.01);
  EXPECT_NEAR(-2, result[2], 0.01);
  EXPECT_NEAR(-1.976, result[3], 0.01);

  EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
}

}  // namespace mediapipe
@@ -102,6 +102,12 @@ class PreviousLoopbackCalculator : public CalculatorBase {
        cc->Outputs().Get(loop_out_id_).AddPacket(std::move(previous_loopback));
      }
    }
    if (!main_ts_.empty()) {
      cc->Outputs().Get(loop_out_id_).SetNextTimestampBound(main_ts_.front());
    }
    if (cc->Inputs().Get(main_id_).IsDone() && main_ts_.empty()) {
      cc->Outputs().Get(loop_out_id_).Close();
    }
    return ::mediapipe::OkStatus();
  }

@@ -107,5 +107,96 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
  MP_EXPECT_OK(graph_.WaitUntilDone());
}

// A Calculator that outputs a summary packet in CalculatorBase::Close().
class PacketOnCloseCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    sum_ += cc->Inputs().Index(0).Value().Get<int>();
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Close(CalculatorContext* cc) final {
    cc->Outputs().Index(0).AddPacket(
        MakePacket<int>(sum_).At(Timestamp::Max()));
    return ::mediapipe::OkStatus();
  }

 private:
  int sum_ = 0;
};
REGISTER_CALCULATOR(PacketOnCloseCalculator);

// Demonstrates that all output and input streams in PreviousLoopbackCalculator
// will close as expected when all graph input streams are closed.
TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
  std::vector<Packet> outputs;
  CalculatorGraphConfig graph_config_ =
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: 'in'
        node {
          calculator: 'PreviousLoopbackCalculator'
          input_stream: 'MAIN:in'
          input_stream: 'LOOP:out'
          input_stream_info: { tag_index: 'LOOP' back_edge: true }
          output_stream: 'PREV_LOOP:previous'
        }
        # This calculator synchronizes its inputs as normal, so it is used
        # to check that both "in" and "previous" are ready.
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          input_stream: 'previous'
          output_stream: 'out'
          output_stream: 'previous2'
        }
        node {
          calculator: 'PacketOnCloseCalculator'
          input_stream: 'out'
          output_stream: 'close_out'
        }
      )");
  tool::AddVectorSink("close_out", &graph_config_, &outputs);

  CalculatorGraph graph_;
  MP_ASSERT_OK(graph_.Initialize(graph_config_, {}));
  MP_ASSERT_OK(graph_.StartRun({}));

  auto send_packet = [&graph_](const std::string& input_name, int n) {
    MP_EXPECT_OK(graph_.AddPacketToInputStream(
        input_name, MakePacket<int>(n).At(Timestamp(n))));
  };

  send_packet("in", 1);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1}));

  send_packet("in", 5);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5}));

  send_packet("in", 15);
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 5, 15}));

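  // Closing all graph input streams lets PacketOnCloseCalculator flush its
  // summary packet at Timestamp::Max(), which should still propagate through
  // the loopback before everything shuts down.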
  MP_EXPECT_OK(graph_.CloseAllInputStreams());
  MP_EXPECT_OK(graph_.WaitUntilIdle());
  EXPECT_EQ(TimestampValues(outputs),
            (std::vector<int64>{1, 5, 15, Timestamp::Max().Value()}));

  MP_EXPECT_OK(graph_.WaitUntilDone());
}

}  // anonymous namespace
}  // namespace mediapipe
@@ -0,0 +1,83 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <memory>
#include <set>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

using mediapipe::PacketTypeSet;
using mediapipe::Timestamp;

namespace {
static std::map<std::string, Timestamp>* kTimestampMap = []() {
  auto* res = new std::map<std::string, Timestamp>();
  res->emplace("AT_PRESTREAM", Timestamp::PreStream());
  res->emplace("AT_POSTSTREAM", Timestamp::PostStream());
  res->emplace("AT_ZERO", Timestamp(0));
  return res;
}();

}  // namespace

// Outputs the single input_side_packet at the timestamp specified in the
// output_stream tag. Valid tags are AT_PRESTREAM, AT_POSTSTREAM and AT_ZERO.
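//
// A minimal config sketch (the side packet and stream names below are
// illustrative placeholders, not required by the calculator):
// node {
//   calculator: "SidePacketToStreamCalculator"
//   input_side_packet: "side_packet"
//   output_stream: "AT_PRESTREAM:stream"
// }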
class SidePacketToStreamCalculator : public CalculatorBase {
 public:
  SidePacketToStreamCalculator() = default;
  ~SidePacketToStreamCalculator() override = default;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(SidePacketToStreamCalculator);

::mediapipe::Status SidePacketToStreamCalculator::GetContract(
    CalculatorContract* cc) {
  cc->InputSidePackets().Index(0).SetAny();

  std::set<std::string> tags = cc->Outputs().GetTags();
  RET_CHECK_EQ(tags.size(), 1);

  RET_CHECK_EQ(kTimestampMap->count(*tags.begin()), 1);
  cc->Outputs().Tag(*tags.begin()).SetAny();

  return ::mediapipe::OkStatus();
}

::mediapipe::Status SidePacketToStreamCalculator::Process(
    CalculatorContext* cc) {
  return mediapipe::tool::StatusStop();
}

::mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
  std::set<std::string> tags = cc->Outputs().GetTags();
  RET_CHECK_EQ(tags.size(), 1);
  const std::string& tag = *tags.begin();
  RET_CHECK_EQ(kTimestampMap->count(tag), 1);
  cc->Outputs().Tag(tag).AddPacket(
      cc->InputSidePackets().Index(0).At(kTimestampMap->at(tag)));

  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe
@@ -34,7 +34,9 @@ namespace mediapipe {
 // SplitVectorCalculatorOptions. If the option "element_only" is set to true,
 // all ranges should be of size 1 and all outputs will be elements of type T. If
 // "element_only" is false, ranges can be non-zero in size and all outputs will
-// be of type std::vector<T>.
+// be of type std::vector<T>. If the option "combine_outputs" is set to true,
+// only one output stream can be specified and all ranges of elements will be
+// combined into one vector.
 // To use this class for a particular type T, register a calculator using
 // SplitVectorCalculator<T>.
 template <typename T>
@@ -49,28 +51,47 @@ class SplitVectorCalculator : public CalculatorBase {
     const auto& options =
         cc->Options<::mediapipe::SplitVectorCalculatorOptions>();

-    if (cc->Outputs().NumEntries() != options.ranges_size()) {
-      return ::mediapipe::InvalidArgumentError(
-          "The number of output streams should match the number of ranges "
-          "specified in the CalculatorOptions.");
-    }
-
-    // Set the output types for each output stream.
-    for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
-      if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 ||
-          options.ranges(i).begin() >= options.ranges(i).end()) {
-        return ::mediapipe::InvalidArgumentError(
-            "Indices should be non-negative and begin index should be less "
-            "than the end index.");
-      }
-      if (options.element_only()) {
-        if (options.ranges(i).end() - options.ranges(i).begin() != 1) {
-          return ::mediapipe::InvalidArgumentError(
-              "Since element_only is true, all ranges should be of size 1.");
-        }
-        cc->Outputs().Index(i).Set<T>();
-      } else {
-        cc->Outputs().Index(i).Set<std::vector<T>>();
-      }
+    if (options.combine_outputs()) {
+      RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
+      cc->Outputs().Index(0).Set<std::vector<T>>();
+      for (int i = 0; i < options.ranges_size() - 1; ++i) {
+        for (int j = i + 1; j < options.ranges_size(); ++j) {
+          const auto& range_0 = options.ranges(i);
+          const auto& range_1 = options.ranges(j);
+          if ((range_0.begin() >= range_1.begin() &&
+               range_0.begin() < range_1.end()) ||
+              (range_1.begin() >= range_0.begin() &&
+               range_1.begin() < range_0.end())) {
+            return ::mediapipe::InvalidArgumentError(
+                "Ranges must be non-overlapping when using combine_outputs "
+                "option.");
+          }
+        }
+      }
+    } else {
+      if (cc->Outputs().NumEntries() != options.ranges_size()) {
+        return ::mediapipe::InvalidArgumentError(
+            "The number of output streams should match the number of ranges "
+            "specified in the CalculatorOptions.");
+      }
+
+      // Set the output types for each output stream.
+      for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
+        if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 ||
+            options.ranges(i).begin() >= options.ranges(i).end()) {
+          return ::mediapipe::InvalidArgumentError(
+              "Indices should be non-negative and begin index should be less "
+              "than the end index.");
+        }
+        if (options.element_only()) {
+          if (options.ranges(i).end() - options.ranges(i).begin() != 1) {
+            return ::mediapipe::InvalidArgumentError(
+                "Since element_only is true, all ranges should be of size 1.");
+          }
+          cc->Outputs().Index(i).Set<T>();
+        } else {
+          cc->Outputs().Index(i).Set<std::vector<T>>();
+        }
+      }
     }

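The pairwise check added above is the standard half-open interval overlap test. As a standalone sketch (a hypothetical helper, not part of this change):

// Returns true when half-open ranges [b0, e0) and [b1, e1) share an element.
bool RangesOverlap(int b0, int e0, int b1, int e1) {
  return (b0 >= b1 && b0 < e1) || (b1 >= b0 && b1 < e0);
}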
@@ -83,13 +104,15 @@ class SplitVectorCalculator : public CalculatorBase {
     const auto& options =
         cc->Options<::mediapipe::SplitVectorCalculatorOptions>();

+    element_only_ = options.element_only();
+    combine_outputs_ = options.combine_outputs();
+
     for (const auto& range : options.ranges()) {
       ranges_.push_back({range.begin(), range.end()});
       max_range_end_ = std::max(max_range_end_, range.end());
+      total_elements_ += range.end() - range.begin();
     }

-    element_only_ = options.element_only();
-
     return ::mediapipe::OkStatus();
   }

@@ -97,17 +120,29 @@ class SplitVectorCalculator : public CalculatorBase {
     const auto& input = cc->Inputs().Index(0).Get<std::vector<T>>();
     RET_CHECK_GE(input.size(), max_range_end_);

-    if (element_only_) {
-      for (int i = 0; i < ranges_.size(); ++i) {
-        cc->Outputs().Index(i).AddPacket(
-            MakePacket<T>(input[ranges_[i].first]).At(cc->InputTimestamp()));
-      }
-    } else {
-      for (int i = 0; i < ranges_.size(); ++i) {
-        auto output = absl::make_unique<std::vector<T>>(
-            input.begin() + ranges_[i].first,
-            input.begin() + ranges_[i].second);
-        cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
-      }
+    if (combine_outputs_) {
+      auto output = absl::make_unique<std::vector<T>>();
+      output->reserve(total_elements_);
+      for (int i = 0; i < ranges_.size(); ++i) {
+        auto elements = absl::make_unique<std::vector<T>>(
+            input.begin() + ranges_[i].first,
+            input.begin() + ranges_[i].second);
+        output->insert(output->end(), elements->begin(), elements->end());
+      }
+      cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
+    } else {
+      if (element_only_) {
+        for (int i = 0; i < ranges_.size(); ++i) {
+          cc->Outputs().Index(i).AddPacket(
+              MakePacket<T>(input[ranges_[i].first]).At(cc->InputTimestamp()));
+        }
+      } else {
+        for (int i = 0; i < ranges_.size(); ++i) {
+          auto output = absl::make_unique<std::vector<T>>(
+              input.begin() + ranges_[i].first,
+              input.begin() + ranges_[i].second);
+          cc->Outputs().Index(i).Add(output.release(), cc->InputTimestamp());
+        }
+      }
     }

@@ -117,7 +152,9 @@ class SplitVectorCalculator : public CalculatorBase {
 private:
  std::vector<std::pair<int32, int32>> ranges_;
  int32 max_range_end_ = -1;
  int32 total_elements_ = 0;
  bool element_only_ = false;
  bool combine_outputs_ = false;
};

}  // namespace mediapipe

@@ -37,4 +37,7 @@ message SplitVectorCalculatorOptions {
  // just an element of type T. By default, if a range specifies only one
  // element, it is output as an std::vector<T>.
  optional bool element_only = 2 [default = false];

  // Combines output elements to one vector.
  optional bool combine_outputs = 3 [default = false];
}
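For illustration, a node exercising the new option might look like the following sketch (calculator and stream names are borrowed from the tests below; with these ranges a single combined 4-element vector is emitted per input packet):

node {
  calculator: "SplitTfLiteTensorVectorCalculator"
  input_stream: "tensor_in"
  output_stream: "range_0"
  options {
    [mediapipe.SplitVectorCalculatorOptions.ext] {
      ranges: { begin: 0 end: 2 }
      ranges: { begin: 3 end: 5 }
      combine_outputs: true
    }
  }
}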
@@ -105,6 +105,34 @@ class SplitTfLiteTensorVectorCalculatorTest : public ::testing::Test {
    }
  }

  void ValidateCombinedVectorOutput(std::vector<Packet>& output_packets,
                                    int expected_elements,
                                    std::vector<int>& input_begin_indices,
                                    std::vector<int>& input_end_indices) {
    ASSERT_EQ(1, output_packets.size());
    ASSERT_EQ(input_begin_indices.size(), input_end_indices.size());
    const std::vector<TfLiteTensor>& output_vec =
        output_packets[0].Get<std::vector<TfLiteTensor>>();
    ASSERT_EQ(expected_elements, output_vec.size());
    const int num_ranges = input_begin_indices.size();

    int element_id = 0;
    for (int range_id = 0; range_id < num_ranges; ++range_id) {
      for (int i = input_begin_indices[range_id];
           i < input_end_indices[range_id]; ++i) {
        const int expected_value = i;
        const TfLiteTensor* result = &output_vec[element_id];
        float* result_buffer = result->data.f;
        ASSERT_NE(result_buffer, nullptr);
        ASSERT_EQ(result_buffer, input_buffers_[i]);
        for (int j = 0; j < width * height * channels; ++j) {
          ASSERT_EQ(expected_value, result_buffer[j]);
        }
        element_id++;
      }
    }
  }

  void ValidateElementOutput(std::vector<Packet>& output_packets,
                             int input_begin_index) {
    ASSERT_EQ(1, output_packets.size());
@@ -234,6 +262,65 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOutputStreamCountTest) {
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest,
       InvalidCombineOutputsMultipleOutputsTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
        input_stream: "tensor_in"
        node {
          calculator: "SplitTfLiteTensorVectorCalculator"
          input_stream: "tensor_in"
          output_stream: "range_0"
          output_stream: "range_1"
          options {
            [mediapipe.SplitVectorCalculatorOptions.ext] {
              ranges: { begin: 0 end: 1 }
              ranges: { begin: 2 end: 3 }
              combine_outputs: true
            }
          }
        }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // The graph should fail to initialize because combine_outputs requires
  // exactly one output stream.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest, InvalidOverlappingRangesTest) {
  ASSERT_NE(interpreter_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
        input_stream: "tensor_in"
        node {
          calculator: "SplitTfLiteTensorVectorCalculator"
          input_stream: "tensor_in"
          output_stream: "range_0"
          options {
            [mediapipe.SplitVectorCalculatorOptions.ext] {
              ranges: { begin: 0 end: 3 }
              ranges: { begin: 1 end: 4 }
              combine_outputs: true
            }
          }
        }
          )");

  // Run the graph.
  CalculatorGraph graph;
  // The graph should fail to initialize because the ranges overlap.
  ASSERT_FALSE(graph.Initialize(graph_config).ok());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
  ASSERT_NE(interpreter_, nullptr);

@@ -289,6 +376,53 @@ TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestElementOnly) {
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest, SmokeTestCombiningOutputs) {
  ASSERT_NE(interpreter_, nullptr);

  PrepareTfLiteTensorVector(/*vector_size=*/5);
  ASSERT_NE(input_vec_, nullptr);

  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
          R"(
        input_stream: "tensor_in"
        node {
          calculator: "SplitTfLiteTensorVectorCalculator"
          input_stream: "tensor_in"
          output_stream: "range_0"
          options {
            [mediapipe.SplitVectorCalculatorOptions.ext] {
              ranges: { begin: 0 end: 1 }
              ranges: { begin: 2 end: 3 }
              ranges: { begin: 4 end: 5 }
              combine_outputs: true
            }
          }
        }
          )");
  std::vector<Packet> range_0_packets;
  tool::AddVectorSink("range_0", &graph_config, &range_0_packets);

  // Run the graph.
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "tensor_in", Adopt(input_vec_.release()).At(Timestamp(0))));
  // Wait until the calculator finishes processing.
  MP_ASSERT_OK(graph.WaitUntilIdle());

  std::vector<int> input_begin_indices = {0, 2, 4};
  std::vector<int> input_end_indices = {1, 3, 5};
  ValidateCombinedVectorOutput(range_0_packets, /*expected_elements=*/3,
                               input_begin_indices, input_end_indices);

  // Fully close the graph at the end.
  MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
  MP_ASSERT_OK(graph.WaitUntilDone());
}

TEST_F(SplitTfLiteTensorVectorCalculatorTest,
       ElementOnlyDisablesVectorOutputs) {
  // Prepare a graph to use the SplitTfLiteTensorVectorCalculator.
mediapipe/calculators/core/string_to_int_calculator.cc (new file, 79 lines)
@@ -0,0 +1,79 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sys/types.h>

#include <memory>
#include <string>

#include "absl/strings/numbers.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Calculator that converts a std::string into an integer type, or fails if
// the conversion is not possible.
//
// Example config:
// node {
//   calculator: "StringToIntCalculator"
//   input_side_packet: "string"
//   output_side_packet: "index"
// }
template <typename IntType>
class StringToIntCalculatorTemplate : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->InputSidePackets().Index(0).Set<std::string>();
    cc->OutputSidePackets().Index(0).Set<IntType>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    IntType number;
    if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get<std::string>(),
                          &number)) {
      return ::mediapipe::InvalidArgumentError(
          "The std::string could not be parsed as an integer.");
    }
    cc->OutputSidePackets().Index(0).Set(MakePacket<IntType>(number));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return ::mediapipe::OkStatus();
  }
};

using StringToIntCalculator = StringToIntCalculatorTemplate<int>;
REGISTER_CALCULATOR(StringToIntCalculator);

using StringToUintCalculator = StringToIntCalculatorTemplate<uint>;
REGISTER_CALCULATOR(StringToUintCalculator);

using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
REGISTER_CALCULATOR(StringToInt32Calculator);

using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
REGISTER_CALCULATOR(StringToUint32Calculator);

using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
REGISTER_CALCULATOR(StringToInt64Calculator);

using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
REGISTER_CALCULATOR(StringToUint64Calculator);

}  // namespace mediapipe
@@ -104,6 +104,17 @@ proto_library(
    deps = ["//mediapipe/framework:calculator_proto"],
)

proto_library(
    name = "unpack_media_sequence_calculator_proto",
    srcs = ["unpack_media_sequence_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/core:packet_resampler_calculator_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/util:audio_decoder_proto",
    ],
)

proto_library(
    name = "vector_float_to_tensor_calculator_options_proto",
    srcs = ["vector_float_to_tensor_calculator_options.proto"],
@@ -261,6 +272,17 @@ mediapipe_cc_proto_library(
    deps = [":unpack_media_sequence_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "vector_int_to_tensor_calculator_options_cc_proto",
    srcs = ["vector_int_to_tensor_calculator_options.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
    visibility = ["//visibility:public"],
    deps = [":vector_int_to_tensor_calculator_options_proto"],
)

mediapipe_cc_proto_library(
    name = "vector_float_to_tensor_calculator_options_cc_proto",
    srcs = ["vector_float_to_tensor_calculator_options.proto"],
@@ -621,6 +643,22 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "tfrecord_reader_calculator",
    srcs = ["tfrecord_reader_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tensor_to_vector_float_calculator",
    srcs = ["tensor_to_vector_float_calculator.cc"],
@@ -662,6 +700,20 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "vector_int_to_tensor_calculator",
    srcs = ["vector_int_to_tensor_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":vector_int_to_tensor_calculator_options_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

cc_library(
    name = "vector_float_to_tensor_calculator",
    srcs = ["vector_float_to_tensor_calculator.cc"],
@@ -676,6 +728,20 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "unpack_yt8m_sequence_example_calculator",
    srcs = ["unpack_yt8m_sequence_example_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/port:status",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
    alwayslink = 1,
)

cc_test(
    name = "graph_tensors_packet_generator_test",
    srcs = ["graph_tensors_packet_generator_test.cc"],
@@ -980,6 +1046,20 @@ cc_test(
    ],
)

cc_test(
    name = "vector_int_to_tensor_calculator_test",
    srcs = ["vector_int_to_tensor_calculator_test.cc"],
    deps = [
        ":vector_int_to_tensor_calculator",
        ":vector_int_to_tensor_calculator_options_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "vector_float_to_tensor_calculator_test",
    srcs = ["vector_float_to_tensor_calculator_test.cc"],
@@ -29,6 +29,11 @@

namespace mediapipe {

const char kBufferSize[] = "BUFFER_SIZE";
const char kOverlap[] = "OVERLAP";
const char kTimestampOffset[] = "TIMESTAMP_OFFSET";
const char kCalculatorOptions[] = "CALCULATOR_OPTIONS";

namespace tf = tensorflow;

// Given an input stream of tensors, concatenates the tensors over timesteps.
@@ -72,6 +77,9 @@ class LappedTensorBufferCalculator : public CalculatorBase {
  ::mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor);

  int steps_until_output_;
  int buffer_size_;
  int overlap_;
  int timestamp_offset_;
  std::unique_ptr<CircularBuffer<Timestamp>> timestamp_buffer_;
  std::unique_ptr<CircularBuffer<tf::Tensor>> buffer_;
  LappedTensorBufferCalculatorOptions options_;
@@ -87,6 +95,21 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
      );
  RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
      << "Only one input stream is supported.";

  if (cc->InputSidePackets().HasTag(kBufferSize)) {
    cc->InputSidePackets().Tag(kBufferSize).Set<int>();
  }
  if (cc->InputSidePackets().HasTag(kOverlap)) {
    cc->InputSidePackets().Tag(kOverlap).Set<int>();
  }
  if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
    cc->InputSidePackets().Tag(kTimestampOffset).Set<int>();
  }
  if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
    cc->InputSidePackets()
        .Tag(kCalculatorOptions)
        .Set<LappedTensorBufferCalculatorOptions>();
  }
  cc->Outputs().Index(0).Set<tf::Tensor>(
      // Output tensorflow::Tensor stream with possibly overlapping steps.
  );
@@ -95,16 +118,33 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);

 ::mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) {
   options_ = cc->Options<LappedTensorBufferCalculatorOptions>();
-  RET_CHECK_LT(options_.overlap(), options_.buffer_size());
-  RET_CHECK_GE(options_.timestamp_offset(), 0)
+  if (cc->InputSidePackets().HasTag(kCalculatorOptions)) {
+    options_ = cc->InputSidePackets()
+                   .Tag(kCalculatorOptions)
+                   .Get<LappedTensorBufferCalculatorOptions>();
+  }
+  buffer_size_ = options_.buffer_size();
+  if (cc->InputSidePackets().HasTag(kBufferSize)) {
+    buffer_size_ = cc->InputSidePackets().Tag(kBufferSize).Get<int>();
+  }
+  overlap_ = options_.overlap();
+  if (cc->InputSidePackets().HasTag(kOverlap)) {
+    overlap_ = cc->InputSidePackets().Tag(kOverlap).Get<int>();
+  }
+  timestamp_offset_ = options_.timestamp_offset();
+  if (cc->InputSidePackets().HasTag(kTimestampOffset)) {
+    timestamp_offset_ = cc->InputSidePackets().Tag(kTimestampOffset).Get<int>();
+  }
+
+  RET_CHECK_LT(overlap_, buffer_size_);
+  RET_CHECK_GE(timestamp_offset_, 0)
       << "Negative timestamp_offset is not allowed.";
-  RET_CHECK_LT(options_.timestamp_offset(), options_.buffer_size())
+  RET_CHECK_LT(timestamp_offset_, buffer_size_)
       << "output_frame_num_offset has to be less than buffer_size.";
   timestamp_buffer_ =
-      absl::make_unique<CircularBuffer<Timestamp>>(options_.buffer_size());
-  buffer_ =
-      absl::make_unique<CircularBuffer<tf::Tensor>>(options_.buffer_size());
-  steps_until_output_ = options_.buffer_size();
+      absl::make_unique<CircularBuffer<Timestamp>>(buffer_size_);
+  buffer_ = absl::make_unique<CircularBuffer<tf::Tensor>>(buffer_size_);
+  steps_until_output_ = buffer_size_;
   return ::mediapipe::OkStatus();
 }

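These tags let a graph configure the buffer from side packets instead of (or in addition to) static options; a hypothetical node using them (stream and side packet names are illustrative):

node {
  calculator: "LappedTensorBufferCalculator"
  input_stream: "tensor_stream"
  output_stream: "lapped_tensor_stream"
  input_side_packet: "BUFFER_SIZE:buffer_size"
  input_side_packet: "OVERLAP:overlap"
}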
@@ -128,11 +168,10 @@ REGISTER_CALCULATOR(LappedTensorBufferCalculator);
                                          concatenated.get());
     RET_CHECK(concat_status.ok()) << concat_status.ToString();

-    cc->Outputs().Index(0).Add(
-        concatenated.release(),
-        timestamp_buffer_->Get(options_.timestamp_offset()));
+    cc->Outputs().Index(0).Add(concatenated.release(),
+                               timestamp_buffer_->Get(timestamp_offset_));

-    steps_until_output_ = options_.buffer_size() - options_.overlap();
+    steps_until_output_ = buffer_size_ - overlap_;
   }
   return ::mediapipe::OkStatus();
 }

mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc (new file, 126 lines)
@@ -0,0 +1,126 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <utility>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"

namespace mediapipe {

const char kTFRecordPath[] = "TFRECORD_PATH";
const char kRecordIndex[] = "RECORD_INDEX";
const char kExampleTag[] = "EXAMPLE";
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";

// Reads a tensorflow example/sequence example from a tfrecord file.
// If the "RECORD_INDEX" input side packet is provided, the calculator fetches
// the example/sequence example of the tfrecord file at the target record
// index. Otherwise, the reader always reads the first example/sequence
// example of the tfrecord file.
//
// Example config:
// node {
//   calculator: "TFRecordReaderCalculator"
//   input_side_packet: "TFRECORD_PATH:tfrecord_path"
//   input_side_packet: "RECORD_INDEX:record_index"
//   output_side_packet: "SEQUENCE_EXAMPLE:sequence_example"
// }
class TFRecordReaderCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
};

::mediapipe::Status TFRecordReaderCalculator::GetContract(
    CalculatorContract* cc) {
  cc->InputSidePackets().Tag(kTFRecordPath).Set<std::string>();
  if (cc->InputSidePackets().HasTag(kRecordIndex)) {
    cc->InputSidePackets().Tag(kRecordIndex).Set<int>();
  }

  RET_CHECK(cc->OutputSidePackets().HasTag(kExampleTag) ||
            cc->OutputSidePackets().HasTag(kSequenceExampleTag))
      << "TFRecordReaderCalculator must output either Tensorflow example or "
         "sequence example.";
  if (cc->OutputSidePackets().HasTag(kExampleTag)) {
    cc->OutputSidePackets().Tag(kExampleTag).Set<tensorflow::Example>();
  } else {
    cc->OutputSidePackets()
        .Tag(kSequenceExampleTag)
        .Set<tensorflow::SequenceExample>();
  }
  return ::mediapipe::OkStatus();
}

::mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) {
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile(
      cc->InputSidePackets().Tag(kTFRecordPath).Get<std::string>(), &file);
  RET_CHECK(tf_status.ok())
      << "Failed to open tfrecord file: " << tf_status.error_message();
  tensorflow::io::RecordReader reader(file.get(),
                                      tensorflow::io::RecordReaderOptions());
  tensorflow::uint64 offset = 0;
  std::string example_str;
  const int target_idx =
      cc->InputSidePackets().HasTag(kRecordIndex)
          ? cc->InputSidePackets().Tag(kRecordIndex).Get<int>()
          : 0;
  int current_idx = 0;
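  // RecordReader only supports sequential reads, so skip over the preceding
  // records one at a time until the target index is reached.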
  while (current_idx <= target_idx) {
    tf_status = reader.ReadRecord(&offset, &example_str);
    RET_CHECK(tf_status.ok())
        << "Failed to read tfrecord: " << tf_status.error_message();
    if (current_idx == target_idx) {
      if (cc->OutputSidePackets().HasTag(kExampleTag)) {
        tensorflow::Example tf_example;
        tf_example.ParseFromString(example_str);
        cc->OutputSidePackets()
            .Tag(kExampleTag)
            .Set(MakePacket<tensorflow::Example>(std::move(tf_example)));
      } else {
        tensorflow::SequenceExample tf_sequence_example;
        tf_sequence_example.ParseFromString(example_str);
        cc->OutputSidePackets()
            .Tag(kSequenceExampleTag)
            .Set(MakePacket<tensorflow::SequenceExample>(
                std::move(tf_sequence_example)));
      }
    }
    ++current_idx;
  }

  return ::mediapipe::OkStatus();
}

::mediapipe::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) {
  return ::mediapipe::OkStatus();
}

REGISTER_CALCULATOR(TFRecordReaderCalculator);

}  // namespace mediapipe
@ -0,0 +1,192 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <iterator>
|
||||
|
||||
#include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/example/feature.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
const char kId[] = "id";
|
||||
const char kRgb[] = "rgb";
|
||||
const char kAudio[] = "audio";
|
||||
const char kDesiredSegmentSize[] = "DESIRED_SEGMENT_SIZE";
|
||||
const char kYt8mId[] = "YT8M_ID";
|
||||
const char kYt8mSequenceExample[] = "YT8M_SEQUENCE_EXAMPLE";
|
||||
const char kQuantizedRgbFeature[] = "QUANTIZED_RGB_FEATURE";
|
||||
const char kQuantizedAudioFeature[] = "QUANTIZED_AUDIO_FEATURE";
|
||||
const char kSegmentSize[] = "SEGMENT_SIZE";
|
||||
const char kLappedTensorBufferCalculatorOptions[] =
|
||||
"LAPPED_TENSOR_BUFFER_CALCULATOR_OPTIONS";
|
||||
|
||||
std::string GetQuantizedFeature(
|
||||
const tensorflow::SequenceExample& sequence_example, const std::string& key,
|
||||
int index) {
|
||||
const auto& bytes_list = sequence_example.feature_lists()
|
||||
.feature_list()
|
||||
.at(key)
|
||||
.feature()
|
||||
.Get(index)
|
||||
.bytes_list()
|
||||
.value();
|
||||
CHECK_EQ(1, bytes_list.size());
|
||||
return bytes_list.Get(0);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Unpacks YT8M Sequence Example. Note that the audio feature and rgb feature
|
||||
// output are quantized. DequantizeByteArrayCalculator can do the dequantization
|
||||
// for you.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "UnpackYt8mSequenceExampleCalculator"
|
||||
// input_side_packet: "YT8M_SEQUENCE_EXAMPLE:yt8m_sequence_example"
|
||||
// output_stream: "QUANTIZED_RGB_FEATURE:quantized_rgb_feature"
|
||||
// output_stream: "QUANTIZED_AUDIO_FEATURE:quantized_audio_feature"
|
||||
// }
class UnpackYt8mSequenceExampleCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->InputSidePackets()
        .Tag(kYt8mSequenceExample)
        .Set<tensorflow::SequenceExample>();
    if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
      cc->InputSidePackets().Tag(kDesiredSegmentSize).Set<int>();
    }
    cc->Outputs().Tag(kQuantizedRgbFeature).Set<std::string>();
    cc->Outputs().Tag(kQuantizedAudioFeature).Set<std::string>();
    if (cc->OutputSidePackets().HasTag(kYt8mId)) {
      cc->OutputSidePackets().Tag(kYt8mId).Set<std::string>();
    }
    if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions)) {
      cc->OutputSidePackets()
          .Tag(kLappedTensorBufferCalculatorOptions)
          .Set<::mediapipe::LappedTensorBufferCalculatorOptions>();
    }
    if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
      cc->OutputSidePackets().Tag(kSegmentSize).Set<int>();
    }
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    const tensorflow::SequenceExample& sequence_example =
        cc->InputSidePackets()
            .Tag(kYt8mSequenceExample)
            .Get<tensorflow::SequenceExample>();
    const std::string& yt8m_id =
        sequence_example.context().feature().at(kId).bytes_list().value().Get(
            0);
    if (cc->OutputSidePackets().HasTag(kYt8mId)) {
      cc->OutputSidePackets().Tag(kYt8mId).Set(
          MakePacket<std::string>(yt8m_id));
    }

    int rgb_feature_list_length =
        sequence_example.feature_lists().feature_list().at(kRgb).feature_size();
    int audio_feature_list_length = sequence_example.feature_lists()
                                        .feature_list()
                                        .at(kAudio)
                                        .feature_size();

    if (rgb_feature_list_length != audio_feature_list_length) {
      return ::mediapipe::FailedPreconditionError(absl::StrCat(
          "Data corruption: the lengths of the audio and rgb feature lists "
          "are not equal. Please check the sequence example that contains "
          "yt8m id: ",
          yt8m_id));
    }
    feature_list_length_ = rgb_feature_list_length;
    if (cc->OutputSidePackets().HasTag(kLappedTensorBufferCalculatorOptions) ||
        cc->OutputSidePackets().HasTag(kSegmentSize)) {
      // If the desired segment size is specified, take the min of the length
      // of the feature list and the desired size to be the output segment
      // size.
      int segment_size = feature_list_length_;
      if (cc->InputSidePackets().HasTag(kDesiredSegmentSize)) {
        int desired_segment_size =
            cc->InputSidePackets().Tag(kDesiredSegmentSize).Get<int>();
        RET_CHECK(desired_segment_size > 0)
            << "The desired segment size must be greater than zero.";
        segment_size = std::min(feature_list_length_, desired_segment_size);
      }
      if (cc->OutputSidePackets().HasTag(
              kLappedTensorBufferCalculatorOptions)) {
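        // These options configure a downstream LappedTensorBufferCalculator to
        // emit overlapping windows of segment_size frames: with an overlap of
        // segment_size - 1, the window advances one frame per output, and the
        // timestamp_offset makes each output carry the timestamp of the last
        // frame in its window.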
        auto lapped_tensor_buffer_calculator_options = absl::make_unique<
            ::mediapipe::LappedTensorBufferCalculatorOptions>();
        lapped_tensor_buffer_calculator_options->set_add_batch_dim_to_tensors(
            true);
        lapped_tensor_buffer_calculator_options->set_buffer_size(segment_size);
        lapped_tensor_buffer_calculator_options->set_overlap(segment_size - 1);
        lapped_tensor_buffer_calculator_options->set_timestamp_offset(
            segment_size - 1);
        cc->OutputSidePackets()
            .Tag(kLappedTensorBufferCalculatorOptions)
            .Set(Adopt(lapped_tensor_buffer_calculator_options.release()));
      }
      if (cc->OutputSidePackets().HasTag(kSegmentSize)) {
        cc->OutputSidePackets()
            .Tag(kSegmentSize)
            .Set(MakePacket<int>(segment_size));
      }
    }
    LOG(INFO) << "Reading the sequence example that contains yt8m id: "
              << yt8m_id << ". Feature list length: " << feature_list_length_;
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    if (current_index_ >= feature_list_length_) {
      return ::mediapipe::tool::StatusStop();
    }
    const tensorflow::SequenceExample& sequence_example =
        cc->InputSidePackets()
            .Tag(kYt8mSequenceExample)
            .Get<tensorflow::SequenceExample>();

    // Uses microsecond as the unit of time. In the YT8M dataset, each feature
    // represents a second.
    const Timestamp timestamp = Timestamp(current_index_ * 1000000);
    cc->Outputs()
        .Tag(kQuantizedRgbFeature)
        .AddPacket(
            MakePacket<std::string>(
                GetQuantizedFeature(sequence_example, kRgb, current_index_))
                .At(timestamp));
    cc->Outputs()
        .Tag(kQuantizedAudioFeature)
        .AddPacket(
            MakePacket<std::string>(
                GetQuantizedFeature(sequence_example, kAudio, current_index_))
                .At(timestamp));
    ++current_index_;
    return ::mediapipe::OkStatus();
  }

 private:
  int current_index_ = 0;
  int feature_list_length_ = 0;
};

REGISTER_CALCULATOR(UnpackYt8mSequenceExampleCalculator);

}  // namespace mediapipe

@@ -23,10 +23,12 @@

namespace mediapipe {

namespace tf = ::tensorflow;

namespace {
auto& INPUT_1D = VectorFloatToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorFloatToTensorCalculatorOptions::INPUT_2D;
}  // namespace

namespace tf = ::tensorflow;

// The calculator expects one input (a packet containing a vector<float> or
// vector<vector<float>>) and generates one output (a packet containing a

@@ -0,0 +1,203 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D)
// tf::Tensor.

#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace mediapipe {

const char kVectorInt[] = "VECTOR_INT";
const char kSingleInt[] = "SINGLE_INT";
const char kTensorOut[] = "TENSOR_OUT";

namespace {
auto& INPUT_1D = VectorIntToTensorCalculatorOptions::INPUT_1D;
auto& INPUT_2D = VectorIntToTensorCalculatorOptions::INPUT_2D;
}  // namespace

namespace tf = ::tensorflow;

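// Writes `value` into the (r, c) element of a 2D tensor whose scalar type is
// TensorType.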
template <typename TensorType>
void AssignMatrixValue(int r, int c, int value, tf::Tensor* output_tensor) {
  output_tensor->tensor<TensorType, 2>()(r, c) = value;
}

// The calculator expects one input (a packet containing a single int or
// vector<int> or vector<vector<int>>) and generates one output (a packet
// containing a tf::Tensor containing the same data). The output tensor will be
// either 1D or 2D with dimensions corresponding to the input vector int. It
// will hold DT_INT32 or DT_UINT8 or DT_INT64 values.
//
// Example config:
// node {
//   calculator: "VectorIntToTensorCalculator"
//   input_stream: "SINGLE_INT:segment_size_int_stream"
//   output_stream: "TENSOR_OUT:segment_size_tensor"
// }
//
// or
//
// node {
//   calculator: "VectorIntToTensorCalculator"
//   input_stream: "VECTOR_INT:vector_int_features"
//   output_stream: "TENSOR_OUT:tensor_features"
// }
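//
// The options proto (defined below) selects the input shape and the output
// dtype. A sketch with illustrative values:
// node {
//   calculator: "VectorIntToTensorCalculator"
//   input_stream: "VECTOR_INT:vector_int_features"
//   output_stream: "TENSOR_OUT:tensor_features"
//   options {
//     [mediapipe.VectorIntToTensorCalculatorOptions.ext] {
//       input_size: INPUT_2D
//       tensor_data_type: DT_INT64
//     }
//   }
// }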
class VectorIntToTensorCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  VectorIntToTensorCalculatorOptions options_;
};
REGISTER_CALCULATOR(VectorIntToTensorCalculator);

::mediapipe::Status VectorIntToTensorCalculator::GetContract(
    CalculatorContract* cc) {
  const auto& options = cc->Options<VectorIntToTensorCalculatorOptions>();
  // Start with only one input packet.
  RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
      << "Only one input stream is supported.";
  if (options.input_size() == INPUT_2D) {
    cc->Inputs().Tag(kVectorInt).Set<std::vector<std::vector<int>>>();
  } else if (options.input_size() == INPUT_1D) {
    if (cc->Inputs().HasTag(kSingleInt)) {
      cc->Inputs().Tag(kSingleInt).Set<int>();
    } else {
      cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>();
    }
  } else {
    LOG(FATAL) << "input size not supported";
  }
  RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
      << "Only one output stream is supported.";
  cc->Outputs().Tag(kTensorOut).Set<tf::Tensor>();
  return ::mediapipe::OkStatus();
}

::mediapipe::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) {
  options_ = cc->Options<VectorIntToTensorCalculatorOptions>();
  RET_CHECK(options_.tensor_data_type() == tf::DT_UINT8 ||
            options_.tensor_data_type() == tf::DT_INT32 ||
            options_.tensor_data_type() == tf::DT_INT64)
      << "Output tensor data type is not supported.";
  return ::mediapipe::OkStatus();
}

::mediapipe::Status VectorIntToTensorCalculator::Process(
    CalculatorContext* cc) {
  tf::TensorShape tensor_shape;
  if (options_.input_size() == INPUT_2D) {
    const std::vector<std::vector<int>>& input =
        cc->Inputs()
            .Tag(kVectorInt)
            .Value()
            .Get<std::vector<std::vector<int>>>();

    const int32 rows = input.size();
    CHECK_GE(rows, 1);
    const int32 cols = input[0].size();
    CHECK_GE(cols, 1);
    for (int i = 1; i < rows; ++i) {
      CHECK_EQ(input[i].size(), cols);
    }
    if (options_.transpose()) {
      tensor_shape = tf::TensorShape({cols, rows});
    } else {
      tensor_shape = tf::TensorShape({rows, cols});
    }
    auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
                                                  tensor_shape);
    if (options_.transpose()) {
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          switch (options_.tensor_data_type()) {
            case tf::DT_INT64:
              AssignMatrixValue<tf::int64>(c, r, input[r][c], output.get());
              break;
            case tf::DT_UINT8:
              AssignMatrixValue<uint8>(c, r, input[r][c], output.get());
              break;
            case tf::DT_INT32:
              AssignMatrixValue<int>(c, r, input[r][c], output.get());
              break;
            default:
              LOG(FATAL) << "tensor data type is not supported.";
          }
        }
      }
    } else {
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          switch (options_.tensor_data_type()) {
            case tf::DT_INT64:
              AssignMatrixValue<tf::int64>(r, c, input[r][c], output.get());
              break;
            case tf::DT_UINT8:
              AssignMatrixValue<uint8>(r, c, input[r][c], output.get());
              break;
            case tf::DT_INT32:
              AssignMatrixValue<int>(r, c, input[r][c], output.get());
              break;
            default:
              LOG(FATAL) << "tensor data type is not supported.";
          }
        }
      }
    }
    cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
  } else if (options_.input_size() == INPUT_1D) {
    std::vector<int> input;
    if (cc->Inputs().HasTag(kSingleInt)) {
      input.push_back(cc->Inputs().Tag(kSingleInt).Get<int>());
    } else {
      input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>();
    }
    CHECK_GE(input.size(), 1);
    const int32 length = input.size();
    tensor_shape = tf::TensorShape({length});
    auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
                                                  tensor_shape);
    for (int i = 0; i < length; ++i) {
      switch (options_.tensor_data_type()) {
        case tf::DT_INT64:
          output->tensor<tf::int64, 1>()(i) = input.at(i);
          break;
        case tf::DT_UINT8:
          output->tensor<uint8, 1>()(i) = input.at(i);
          break;
        case tf::DT_INT32:
          output->tensor<int, 1>()(i) = input.at(i);
          break;
        default:
          LOG(FATAL) << "tensor data type is not supported.";
      }
    }
    cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
  } else {
    LOG(FATAL) << "input size not supported";
  }
  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe

@@ -0,0 +1,43 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";
import "tensorflow/core/framework/types.proto";

message VectorIntToTensorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional VectorIntToTensorCalculatorOptions ext = 275364184;
  }
  enum InputSize {
    UNKNOWN = 0;
    INPUT_1D = 1;
    INPUT_2D = 2;
  }

  // If input_size is INPUT_2D, unpack a vector<vector<int>> to a
  // 2d tensor (matrix). If INPUT_1D, convert a single int or vector<int>
  // into a 1d tensor (vector).
  optional InputSize input_size = 1 [default = INPUT_1D];

  // If true, the output tensor is transposed.
  // Otherwise, the output tensor is not transposed.
  // It is ignored if input_size is INPUT_1D.
  optional bool transpose = 2 [default = false];
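
  // The element type of the output tensor. Only DT_INT32, DT_UINT8, and
  // DT_INT64 are accepted by the calculator's Open().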
  optional tensorflow.DataType tensor_data_type = 3 [default = DT_INT32];
}

@@ -0,0 +1,202 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mediapipe {

namespace {

namespace tf = ::tensorflow;

class VectorIntToTensorCalculatorTest : public ::testing::Test {
 protected:
  void SetUpRunner(
      const VectorIntToTensorCalculatorOptions::InputSize input_size,
      const tensorflow::DataType tensor_data_type, const bool transpose,
      const bool single_value) {
    CalculatorGraphConfig::Node config;
    config.set_calculator("VectorIntToTensorCalculator");
    if (single_value) {
      config.add_input_stream("SINGLE_INT:input_int");
    } else {
      config.add_input_stream("VECTOR_INT:input_int");
    }
    config.add_output_stream("TENSOR_OUT:output_tensor");
    auto options = config.mutable_options()->MutableExtension(
        VectorIntToTensorCalculatorOptions::ext);
    options->set_input_size(input_size);
    options->set_transpose(transpose);
    options->set_tensor_data_type(tensor_data_type);
    runner_ = ::absl::make_unique<CalculatorRunner>(config);
  }

  void TestConvertFromVectorVectorInt(const bool transpose) {
    SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_2D,
                tensorflow::DT_INT32, transpose, false);
    auto input = ::absl::make_unique<std::vector<std::vector<int>>>(
        2, std::vector<int>(2));
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        input->at(i).at(j) = i * 2 + j;
      }
    }

    const int64 time = 1234;
    runner_->MutableInputs()
        ->Tag("VECTOR_INT")
        .packets.push_back(Adopt(input.release()).At(Timestamp(time)));

    EXPECT_TRUE(runner_->Run().ok());

    const std::vector<Packet>& output_packets =
        runner_->Outputs().Tag("TENSOR_OUT").packets;
    EXPECT_EQ(1, output_packets.size());
    EXPECT_EQ(time, output_packets[0].Timestamp().Value());
    const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

    EXPECT_EQ(2, output_tensor.dims());
    EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
    const auto matrix = output_tensor.matrix<int>();

    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        if (!transpose) {
          EXPECT_EQ(i * 2 + j, matrix(i, j));
        } else {
          EXPECT_EQ(j * 2 + i, matrix(i, j));
        }
      }
    }
  }

  std::unique_ptr<CalculatorRunner> runner_;
};

TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
  SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
              tensorflow::DT_INT32, false, true);
  const int64 time = 1234;
  runner_->MutableInputs()
      ->Tag("SINGLE_INT")
      .packets.push_back(MakePacket<int>(1).At(Timestamp(time)));

  EXPECT_TRUE(runner_->Run().ok());

  const std::vector<Packet>& output_packets =
      runner_->Outputs().Tag("TENSOR_OUT").packets;
  EXPECT_EQ(1, output_packets.size());
  EXPECT_EQ(time, output_packets[0].Timestamp().Value());
  const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

  EXPECT_EQ(1, output_tensor.dims());
  EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
  const auto vec = output_tensor.vec<int32>();
  EXPECT_EQ(1, vec(0));
}

TEST_F(VectorIntToTensorCalculatorTest, TestOneDim) {
  SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
              tensorflow::DT_INT32, false, false);
  auto input = ::absl::make_unique<std::vector<int>>(5);
  for (int i = 0; i < 5; ++i) {
    input->at(i) = i;
  }
  const int64 time = 1234;
  runner_->MutableInputs()
      ->Tag("VECTOR_INT")
      .packets.push_back(Adopt(input.release()).At(Timestamp(time)));

  EXPECT_TRUE(runner_->Run().ok());

  const std::vector<Packet>& output_packets =
      runner_->Outputs().Tag("TENSOR_OUT").packets;
  EXPECT_EQ(1, output_packets.size());
  EXPECT_EQ(time, output_packets[0].Timestamp().Value());
  const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

  EXPECT_EQ(1, output_tensor.dims());
  EXPECT_EQ(tf::DT_INT32, output_tensor.dtype());
  const auto vec = output_tensor.vec<int32>();

  for (int i = 0; i < 5; ++i) {
    EXPECT_EQ(i, vec(i));
  }
}

TEST_F(VectorIntToTensorCalculatorTest, TestTwoDims) {
  for (bool transpose : {false, true}) {
    TestConvertFromVectorVectorInt(transpose);
  }
}

TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
  SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
              tensorflow::DT_INT64, false, true);
  const int64 time = 1234;
  // Note: the expression here was originally written as `2 ^ 31`, which is
  // bitwise XOR (== 29), not two to the 31st power. Use a large value that
  // still fits in an int to exercise the int64 output path.
  const int kLargeValue = 1 << 30;
  runner_->MutableInputs()
      ->Tag("SINGLE_INT")
      .packets.push_back(MakePacket<int>(kLargeValue).At(Timestamp(time)));

  EXPECT_TRUE(runner_->Run().ok());

  const std::vector<Packet>& output_packets =
      runner_->Outputs().Tag("TENSOR_OUT").packets;
  EXPECT_EQ(1, output_packets.size());
  EXPECT_EQ(time, output_packets[0].Timestamp().Value());
  const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

  EXPECT_EQ(1, output_tensor.dims());
  EXPECT_EQ(tf::DT_INT64, output_tensor.dtype());
  const auto vec = output_tensor.vec<tf::int64>();
  EXPECT_EQ(kLargeValue, vec(0));
}

TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
  SetUpRunner(VectorIntToTensorCalculatorOptions::INPUT_1D,
              tensorflow::DT_UINT8, false, false);
  auto input = ::absl::make_unique<std::vector<int>>(5);
  for (int i = 0; i < 5; ++i) {
    input->at(i) = i;
  }
  const int64 time = 1234;
  runner_->MutableInputs()
      ->Tag("VECTOR_INT")
      .packets.push_back(Adopt(input.release()).At(Timestamp(time)));

  EXPECT_TRUE(runner_->Run().ok());

  const std::vector<Packet>& output_packets =
      runner_->Outputs().Tag("TENSOR_OUT").packets;
  EXPECT_EQ(1, output_packets.size());
  EXPECT_EQ(time, output_packets[0].Timestamp().Value());
  const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

  EXPECT_EQ(1, output_tensor.dims());
  EXPECT_EQ(tf::DT_UINT8, output_tensor.dtype());
  const auto vec = output_tensor.vec<uint8>();

  for (int i = 0; i < 5; ++i) {
    EXPECT_EQ(i, vec(i));
  }
}

}  // namespace
}  // namespace mediapipe

@@ -25,7 +25,8 @@
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"

@@ -45,7 +46,8 @@
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif  // iOS

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
typedef id<MTLBuffer> GpuTensor;

@@ -67,7 +69,8 @@ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>

namespace mediapipe {

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader;

@@ -146,7 +149,8 @@ class TfLiteConverterCalculator : public CalculatorBase {

  std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  mediapipe::GlCalculatorHelper gpu_helper_;
  std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS

@@ -181,7 +185,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);

  if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
  if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
  if (cc->Inputs().HasTag("IMAGE_GPU")) {
    cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
    use_gpu |= true;

@@ -190,7 +194,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);

  if (cc->Outputs().HasTag("TENSORS"))
    cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
  if (cc->Outputs().HasTag("TENSORS_GPU")) {
    cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
    use_gpu |= true;

@@ -198,7 +202,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
#endif  // !MEDIAPIPE_DISABLE_GPU

  if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);

@@ -218,7 +223,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);

  if (cc->Inputs().HasTag("IMAGE_GPU") ||
      cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
    use_gpu_ = true;
#else
    RET_CHECK_FAIL() << "GPU processing not enabled.";

@@ -231,7 +236,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
                cc->Outputs().HasTag("TENSORS_GPU"));
    // Cannot use quantization.
    use_quantized_tensors_ = false;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
    gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];

@@ -264,7 +270,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
}

::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif
#if defined(__APPLE__) && !TARGET_OS_OSX  // iOS

@@ -383,7 +390,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);

::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
    CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  // GpuBuffer to tflite::gpu::GlBuffer conversion.
  const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
  MP_RETURN_IF_ERROR(

@@ -468,7 +476,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
}

::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
  // Get input image sizes.
  const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
  mediapipe::ImageFormat::Format format =

@@ -485,7 +493,8 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
    RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif  // !MEDIAPIPE_DISABLE_GPU

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
        // Device memory.

@@ -27,7 +27,8 @@
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

@@ -52,7 +53,8 @@

namespace {

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
typedef id<MTLBuffer> GpuTensor;

@@ -68,13 +70,14 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; }  // NOLINT
// * Aux
namespace mediapipe {

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer;
#endif

#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
struct GPUData {
  int elements = 1;
  GpuTensor buffer;

@@ -147,7 +150,8 @@ class TfLiteInferenceCalculator : public CalculatorBase {
  std::unique_ptr<tflite::FlatBufferModel> model_;
  TfLiteDelegate* delegate_ = nullptr;

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  mediapipe::GlCalculatorHelper gpu_helper_;
  std::unique_ptr<GPUData> gpu_data_in_;
  std::vector<std::unique_ptr<GPUData>> gpu_data_out_;

@@ -179,7 +183,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

  if (cc->Inputs().HasTag("TENSORS"))
    cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
  if (cc->Inputs().HasTag("TENSORS_GPU")) {
    cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
    use_gpu |= true;

@@ -188,7 +192,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

  if (cc->Outputs().HasTag("TENSORS"))
    cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
  if (cc->Outputs().HasTag("TENSORS_GPU")) {
    cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
    use_gpu |= true;

@@ -206,7 +210,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  use_gpu |= options.use_gpu();

  if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);

@@ -225,7 +230,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  MP_RETURN_IF_ERROR(LoadOptions(cc));

  if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
    gpu_input_ = true;
    gpu_inference_ = true;  // Inference must be on GPU also.
#else

@@ -235,7 +240,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  }

  if (cc->Outputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
    gpu_output_ = true;
    RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
        << "GPU output must also have GPU Input.";

@@ -248,13 +253,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  MP_RETURN_IF_ERROR(LoadModel(cc));

  if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
    gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
    RET_CHECK(gpu_helper_);
#endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
        [this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); }));
#else

@@ -262,6 +269,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
#endif
  }

#if defined(__EMSCRIPTEN__)
  MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif  // __EMSCRIPTEN__

  return ::mediapipe::OkStatus();
}

@@ -269,7 +280,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  // 1. Receive pre-processed tensor inputs.
  if (gpu_input_) {
    // Read GPU input into SSBO.
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    const auto& input_tensors =
        cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
    RET_CHECK_EQ(input_tensors.size(), 1);

@@ -315,7 +327,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

  // 2. Run inference.
  if (gpu_inference_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(
        gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
          RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);

@@ -330,7 +343,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

  // 3. Output processed tensors.
  if (gpu_output_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    // Output result tensors (GPU).
    auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(

@@ -392,7 +406,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
  if (delegate_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
      TfLiteGpuDelegateDelete(delegate_);
      gpu_data_in_.reset();

@@ -456,6 +471,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

  RET_CHECK(interpreter_);

#if defined(__EMSCRIPTEN__)
  interpreter_->SetNumThreads(1);
#endif  // __EMSCRIPTEN__

  if (gpu_output_) {
    use_quantized_tensors_ = false;
  } else {

@@ -471,7 +490,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
    CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  // Configure and create the delegate.
  TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
  options.compile_options.precision_loss_allowed = 1;

@@ -24,7 +24,8 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#if defined(__EMSCRIPTEN__) || defined(__ANDROID__) || \
    (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else

@@ -66,8 +67,8 @@ class TfLiteTensorsToClassificationCalculator : public CalculatorBase {
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  ::mediapipe::TfLiteTensorsToClassificationCalculatorOptions options_;
  int top_k_ = 0;
  double min_score_threshold_ = 0;
  std::unordered_map<int, std::string> label_map_;
  bool label_map_loaded_ = false;
};

@@ -93,15 +94,14 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
    CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));

  auto options = cc->Options<
  options_ = cc->Options<
      ::mediapipe::TfLiteTensorsToClassificationCalculatorOptions>();

  top_k_ = options.top_k();
  min_score_threshold_ = options.min_score_threshold();
  if (options.has_label_map_path()) {
  top_k_ = options_.top_k();
  if (options_.has_label_map_path()) {
    std::string string_path;
    ASSIGN_OR_RETURN(string_path,
                     PathToResourceAsFile(options.label_map_path()));
                     PathToResourceAsFile(options_.label_map_path()));
    std::string label_map_string;
    MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));

@@ -125,9 +125,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
  RET_CHECK_EQ(input_tensors.size(), 1);

  const TfLiteTensor* raw_score_tensor = &input_tensors[0];
  RET_CHECK_EQ(raw_score_tensor->dims->size, 2);
  RET_CHECK_EQ(raw_score_tensor->dims->data[0], 1);
  int num_classes = raw_score_tensor->dims->data[1];
  int num_classes = 1;
  for (int i = 0; i < raw_score_tensor->dims->size; ++i) {
    num_classes *= raw_score_tensor->dims->data[i];
  }

  if (label_map_loaded_) {
    RET_CHECK_EQ(num_classes, label_map_.size());
  }

@@ -135,7 +137,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);

  auto classification_list = absl::make_unique<ClassificationList>();
  for (int i = 0; i < num_classes; ++i) {
    if (raw_scores[i] < min_score_threshold_) {
    if (options_.has_min_score_threshold() &&
        raw_scores[i] < options_.min_score_threshold()) {
      continue;
    }
    Classification* classification = classification_list->add_classification();

@@ -148,6 +151,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);

  // Note that partial_sort will raise an error when top_k_ >
  // classification_list->classification_size().
  CHECK_GE(classification_list->classification_size(), top_k_);
  auto raw_classification_list = classification_list->mutable_classification();
  if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
    std::partial_sort(raw_classification_list->begin(),

@@ -27,7 +27,8 @@
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h"

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"

@@ -55,12 +56,14 @@ constexpr int kNumCoordsPerBox = 4;

namespace mediapipe {

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlShader;
#endif

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS

@@ -70,7 +73,7 @@ typedef id<MTLComputePipelineState> GpuProgram;

namespace {

#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
struct GPUData {
  GpuProgram decode_program;
  GpuProgram score_program;

@@ -169,18 +172,21 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
      const int* detection_classes, std::vector<Detection>* output_detections);
  Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
                               float box_xmax, float score, int class_id,
                               bool flip_vertically);
                               int detection_id, bool flip_vertically);

  int num_classes_ = 0;
  int num_boxes_ = 0;
  int num_coords_ = 0;
  // Unique detection ID per new detection.
  static int next_detection_id_;
  std::set<int> ignore_classes_;

  ::mediapipe::TfLiteTensorsToDetectionsCalculatorOptions options_;
  std::vector<Anchor> anchors_;
  bool side_packet_anchors_{};

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  mediapipe::GlCalculatorHelper gpu_helper_;
  std::unique_ptr<GPUData> gpu_data_;
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS

@@ -193,6 +199,10 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

// Initialization of non-const static member should happen outside class
// definition.
int TfLiteTensorsToDetectionsCalculator::next_detection_id_ = 0;

::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract(
    CalculatorContract* cc) {
  RET_CHECK(!cc->Inputs().GetTags().empty());

@@ -204,7 +214,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
    cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
  }

#if !defined(MEDIAPIPE_DISABLE_GPU)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
  if (cc->Inputs().HasTag("TENSORS_GPU")) {
    cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
    use_gpu |= true;

@@ -222,7 +232,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
  }

  if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);

@@ -238,7 +249,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

  if (cc->Inputs().HasTag("TENSORS_GPU")) {
    gpu_input_ = true;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
    gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];

@@ -400,7 +412,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
}
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
    CalculatorContext* cc, std::vector<Detection>* output_detections) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  const auto& input_tensors =
      cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
  RET_CHECK_GE(input_tensors.size(), 2);

@@ -562,7 +575,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
    CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
#elif defined(__APPLE__) && !TARGET_OS_OSX  // iOS
  gpu_data_.reset();

@@ -672,7 +686,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
    Detection detection = ConvertToDetection(
        detection_boxes[box_offset + 0], detection_boxes[box_offset + 1],
        detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
        detection_scores[i], detection_classes[i], options_.flip_vertically());
        detection_scores[i], detection_classes[i], next_detection_id_,
        options_.flip_vertically());
    // Increment to get next unique detection ID.
    ++next_detection_id_;
    // Add keypoints.
    if (options_.num_keypoints() > 0) {
      auto* location_data = detection.mutable_location_data();

@@ -695,10 +712,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
    float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
    int class_id, bool flip_vertically) {
    int class_id, int detection_id, bool flip_vertically) {
  Detection detection;
  detection.add_score(score);
  detection.add_label_id(class_id);
  detection.set_detection_id(detection_id);

  LocationData* location_data = detection.mutable_location_data();
  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);

@@ -715,7 +733,8 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(

::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
    CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
                                                    -> ::mediapipe::Status {
    gpu_data_ = absl::make_unique<GPUData>();

@@ -21,7 +21,8 @@
namespace mediapipe {

// A calculator for converting TFLite tensors from regression models into
// landmarks.
// landmarks. Note that if the landmarks in the tensor have more than 3
// dimensions, only the first 3 dimensions will be converted to x, y, z.
//
// Input:
//  TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32. Only the first

@@ -122,9 +123,6 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
    num_values *= raw_tensor->dims->data[i];
  }
  const int num_dimensions = num_values / num_landmarks_;
  // Landmarks must have less than 3 dimensions. Otherwise please consider
  // using matrix.
  CHECK_LE(num_dimensions, 3);
  CHECK_GT(num_dimensions, 0);

  const float* raw_landmarks = raw_tensor->data.f;

@@ -28,7 +28,8 @@
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h"

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"

@@ -53,7 +54,8 @@ float Clamp(float val, float min, float max) {

namespace mediapipe {

#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;

@@ -129,7 +131,8 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase {
  int tensor_channels_ = 0;

  bool use_gpu_ = false;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  mediapipe::GlCalculatorHelper gpu_helper_;
  std::unique_ptr<GlProgram> mask_program_with_prev_;
  std::unique_ptr<GlProgram> mask_program_no_prev_;

@@ -159,7 +162,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
  }

  // Inputs GPU.
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  if (cc->Inputs().HasTag("TENSORS_GPU")) {
    cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
    use_gpu |= true;

@@ -178,7 +182,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
  if (cc->Outputs().HasTag("MASK")) {
    cc->Outputs().Tag("MASK").Set<ImageFrame>();
  }
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  if (cc->Outputs().HasTag("MASK_GPU")) {
    cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
    use_gpu |= true;

@@ -186,7 +191,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
#endif  // !MEDIAPIPE_DISABLE_GPU

  if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif  // !MEDIAPIPE_DISABLE_GPU
  }

@@ -199,7 +205,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);

  if (cc->Inputs().HasTag("TENSORS_GPU")) {
    use_gpu_ = true;
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif  // !MEDIAPIPE_DISABLE_GPU
  }

@@ -207,7 +214,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
  MP_RETURN_IF_ERROR(LoadOptions(cc));

  if (use_gpu_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(
        gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
          MP_RETURN_IF_ERROR(InitGpu(cc));

@@ -224,7 +232,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process(
    CalculatorContext* cc) {
  if (use_gpu_) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
    MP_RETURN_IF_ERROR(
        gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
          MP_RETURN_IF_ERROR(ProcessGpu(cc));

@@ -240,7 +249,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);

::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close(
    CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  gpu_helper_.RunInGlContext([this] {
    if (upsample_program_) glDeleteProgram(upsample_program_);
    upsample_program_ = 0;

@@ -367,7 +377,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
  if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) {
    return ::mediapipe::OkStatus();
  }
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  // Get input streams.
  const auto& input_tensors =
      cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();

@@ -453,7 +464,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
}

void TfLiteTensorsToSegmentationCalculator::GlRender() {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  static const GLfloat square_vertices[] = {
      -1.0f, -1.0f,  // bottom left
      1.0f, -1.0f,   // bottom right

@@ -525,7 +537,8 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() {

::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu(
    CalculatorContext* cc) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__) && \
    !defined(__APPLE__)
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
                                                    -> ::mediapipe::Status {
    // A shader to process a segmentation tensor into an output mask,

@@ -14,7 +14,7 @@

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])
package(default_visibility = ["//visibility:public"])

exports_files(["LICENSE"])

@@ -234,6 +234,7 @@ cc_library(
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:vector",
        "//mediapipe/util:annotation_renderer",
        "//mediapipe/util:render_data_cc_proto",
    ] + select({
        "//mediapipe/gpu:disable_gpu": [],
        "//conditions:default": [

@@ -360,6 +361,16 @@ mediapipe_cc_proto_library(
    deps = [":landmark_projection_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "landmarks_to_floats_calculator_cc_proto",
    srcs = ["landmarks_to_floats_calculator.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":landmarks_to_floats_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "rect_transformation_calculator_cc_proto",
    srcs = ["rect_transformation_calculator.proto"],

@@ -372,7 +383,12 @@ mediapipe_cc_proto_library(

cc_library(
    name = "detections_to_rects_calculator",
    srcs = ["detections_to_rects_calculator.cc"],
    srcs = [
        "detections_to_rects_calculator.cc",
    ],
    hdrs = [
        "detections_to_rects_calculator.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":detections_to_rects_calculator_cc_proto",

@@ -454,6 +470,17 @@ proto_library(
    ],
)

proto_library(
    name = "labels_to_render_data_calculator_proto",
    srcs = ["labels_to_render_data_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/util:color_proto",
        "//mediapipe/util:render_data_proto",
    ],
)

proto_library(
    name = "thresholding_calculator_proto",
    srcs = ["thresholding_calculator.proto"],

@@ -483,6 +510,15 @@ proto_library(
    ],
)

proto_library(
    name = "landmarks_to_floats_calculator_proto",
    srcs = ["landmarks_to_floats_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_proto",
    ],
)

proto_library(
    name = "rect_transformation_calculator_proto",
    srcs = ["rect_transformation_calculator.proto"],

@@ -577,6 +613,26 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "labels_to_render_data_calculator",
    srcs = ["labels_to_render_data_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":labels_to_render_data_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_options_cc_proto",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:video_stream_header",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:render_data_cc_proto",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_library(
    name = "rect_to_render_data_calculator",
    srcs = ["rect_to_render_data_calculator.cc"],

@@ -658,6 +714,22 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "landmarks_to_floats_calculator",
    srcs = ["landmarks_to_floats_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":landmarks_to_floats_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "@eigen_archive//:eigen",
    ],
    alwayslink = 1,
)

cc_test(
    name = "detection_letterbox_removal_calculator_test",
    srcs = ["detection_letterbox_removal_calculator_test.cc"],

@@ -714,6 +786,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":top_k_scores_calculator_cc_proto",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",

@@ -750,3 +823,27 @@ cc_test(
        "//mediapipe/framework/port:status",
    ],
)

mediapipe_cc_proto_library(
    name = "labels_to_render_data_calculator_cc_proto",
    srcs = ["labels_to_render_data_calculator.proto"],
    cc_deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/util:color_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":labels_to_render_data_calculator_proto"],
)

cc_library(
    name = "local_file_contents_calculator",
    srcs = ["local_file_contents_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:file_helpers",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

@@ -26,6 +26,7 @@
#include "mediapipe/framework/port/vector.h"
#include "mediapipe/util/annotation_renderer.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

#if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h"

@@ -41,6 +42,8 @@ namespace {
constexpr char kInputFrameTag[] = "INPUT_FRAME";
constexpr char kOutputFrameTag[] = "OUTPUT_FRAME";

constexpr char kInputVectorTag[] = "VECTOR";

constexpr char kInputFrameTagGpu[] = "INPUT_FRAME_GPU";
constexpr char kOutputFrameTagGpu[] = "OUTPUT_FRAME_GPU";

@@ -65,6 +68,9 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
// 2. RenderData proto on variable number of input streams. All the RenderData
//    at a particular timestamp is drawn on the image in the order of their
//    input streams. No tags required.
// 3. std::vector<RenderData> on variable number of input streams. RenderData
//    objects at a particular timestamp are drawn on the image in order of the
//    input vector items. These input streams are tagged with "VECTOR".
//
// Output:
//   1. OUTPUT_FRAME or OUTPUT_FRAME_GPU: A rendered ImageFrame (or GpuBuffer).

@@ -85,6 +91,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
//   input_stream: "render_data_1"
//   input_stream: "render_data_2"
//   input_stream: "render_data_3"
//   input_stream: "VECTOR:0:render_data_vec_0"
//   input_stream: "VECTOR:1:render_data_vec_1"
//   output_stream: "OUTPUT_FRAME:decorated_frames"
//   options {
//     [mediapipe.AnnotationOverlayCalculatorOptions.ext] {

@@ -99,6 +107,8 @@ constexpr int kAnnotationBackgroundColor[] = {100, 101, 102};
//   input_stream: "render_data_1"
//   input_stream: "render_data_2"
//   input_stream: "render_data_3"
//   input_stream: "VECTOR:0:render_data_vec_0"
//   input_stream: "VECTOR:1:render_data_vec_1"
//   output_stream: "OUTPUT_FRAME_GPU:decorated_frames"
//   options {
//     [mediapipe.AnnotationOverlayCalculatorOptions.ext] {

@@ -188,8 +198,16 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
  }

  // Data streams to render.
  for (int i = 0; i < num_render_streams; ++i) {
    cc->Inputs().Index(i).Set<RenderData>();
  for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
       ++id) {
    auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
    std::string tag = tag_and_index.first;
    if (tag == kInputVectorTag) {
      cc->Inputs().Get(id).Set<std::vector<RenderData>>();
    } else if (tag.empty()) {
      // Empty tag defaults to accepting a single object of RenderData type.
      cc->Inputs().Get(id).Set<RenderData>();
    }
  }

  // Rendered image.

@@ -285,12 +303,28 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
  renderer_->AdoptImage(image_mat.get());

  // Render streams onto render target.
  for (int i = 0; i < num_render_streams_; ++i) {
    if (cc->Inputs().Index(i).IsEmpty()) {
  for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
       ++id) {
    auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
    std::string tag = tag_and_index.first;
    if (!tag.empty() && tag != kInputVectorTag) {
      continue;
    }
    const RenderData& render_data = cc->Inputs().Index(i).Get<RenderData>();
    renderer_->RenderDataOnImage(render_data);
    if (cc->Inputs().Get(id).IsEmpty()) {
      continue;
    }
    if (tag.empty()) {
      // Empty tag defaults to accepting a single object of RenderData type.
      const RenderData& render_data = cc->Inputs().Get(id).Get<RenderData>();
      renderer_->RenderDataOnImage(render_data);
    } else {
      RET_CHECK_EQ(kInputVectorTag, tag);
      const std::vector<RenderData>& render_data_vec =
          cc->Inputs().Get(id).Get<std::vector<RenderData>>();
      for (const RenderData& render_data : render_data_vec) {
        renderer_->RenderDataOnImage(render_data);
      }
    }
  }

  if (use_gpu_) {

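The new GetContract and Process loops dispatch on each input stream's tag, so untagged RenderData streams and VECTOR-tagged std::vector<RenderData> streams can be mixed freely on one node. As a minimal sketch of wiring such a node from C++ (the stream names here are invented for illustration; ParseTextProtoOrDie is assumed from mediapipe/framework/port/parse_text_proto.h):

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical graph mixing one plain RenderData stream with one
// VECTOR-tagged std::vector<RenderData> stream, per the usage comments above.
mediapipe::CalculatorGraphConfig MakeOverlayConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"(
    input_stream: "input_image"
    input_stream: "render_data"
    input_stream: "render_data_vec"
    node {
      calculator: "AnnotationOverlayCalculator"
      input_stream: "INPUT_FRAME:input_image"
      input_stream: "render_data"
      input_stream: "VECTOR:0:render_data_vec"
      output_stream: "OUTPUT_FRAME:output_image"
    }
  )");
}
```
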
@@ -19,8 +19,8 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"

#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
    (defined(__APPLE__) && !TARGET_OS_OSX)
#if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
    defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else

@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/detections_to_rects_calculator.h"

#include <cmath>

#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"

@@ -24,8 +26,6 @@

namespace mediapipe {

using mediapipe::DetectionsToRectsCalculatorOptions;

namespace {

constexpr char kDetectionTag[] = "DETECTION";

@@ -36,7 +36,10 @@ constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRectsTag[] = "RECTS";
constexpr char kNormRectsTag[] = "NORM_RECTS";

::mediapipe::Status DetectionToRect(const Detection& detection, Rect* rect) {
}  // namespace

::mediapipe::Status DetectionsToRectsCalculator::DetectionToRect(
    const Detection& detection, Rect* rect) {
  const LocationData location_data = detection.location_data();
  RET_CHECK(location_data.format() == LocationData::BOUNDING_BOX)
      << "Only Detection with formats of BOUNDING_BOX can be converted to Rect";

@@ -48,8 +51,8 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
  return ::mediapipe::OkStatus();
}

::mediapipe::Status DetectionToNormalizedRect(const Detection& detection,
                                              NormalizedRect* rect) {
::mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect(
    const Detection& detection, NormalizedRect* rect) {
  const LocationData location_data = detection.location_data();
  RET_CHECK(location_data.format() == LocationData::RELATIVE_BOUNDING_BOX)
      << "Only Detection with formats of RELATIVE_BOUNDING_BOX can be "

@@ -63,79 +66,6 @@ constexpr char kNormRectsTag[] = "NORM_RECTS";
  return ::mediapipe::OkStatus();
}

// Wraps around an angle in radians to within -M_PI and M_PI.
inline float NormalizeRadians(float angle) {
  return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
}

}  // namespace

// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output can
// be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to
// be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
//   One of the following:
//   DETECTION: A Detection proto.
//   DETECTIONS: An std::vector<Detection>.
//
//   IMAGE_SIZE (optional): A std::pair<int, int> representing image width and
//   height. This is required only when rotation needs to be computed (see
//   calculator options).
//
// Output:
//   One of the following:
//   RECT: A Rect proto.
//   NORM_RECT: A NormalizedRect proto.
//   RECTS: An std::vector<Rect>.
//   NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
//   calculator: "DetectionsToRectsCalculator"
//   input_stream: "DETECTIONS:detections"
//   input_stream: "IMAGE_SIZE:image_size"
//   output_stream: "NORM_RECT:rect"
//   options: {
//     [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
//       rotation_vector_start_keypoint_index: 0
//       rotation_vector_end_keypoint_index: 2
//       rotation_vector_target_angle_degrees: 90
//       output_zero_rect_for_empty_detections: true
//     }
//   }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  float ComputeRotation(const Detection& detection,
                        const std::pair<int, int> image_size);

  DetectionsToRectsCalculatorOptions options_;
  int start_keypoint_index_;
  int end_keypoint_index_;
  float target_angle_;  // In radians.
  bool rotate_;
  bool output_zero_rect_for_empty_detections_;
};
REGISTER_CALCULATOR(DetectionsToRectsCalculator);

::mediapipe::Status DetectionsToRectsCalculator::GetContract(
    CalculatorContract* cc) {
  RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^

@@ -232,6 +162,13 @@ REGISTER_CALCULATOR(DetectionsToRectsCalculator);
          .Tag(kNormRectTag)
          .AddPacket(MakePacket<NormalizedRect>().At(cc->InputTimestamp()));
    }
    if (cc->Outputs().HasTag(kNormRectsTag)) {
      auto rect_vector = absl::make_unique<std::vector<NormalizedRect>>();
      rect_vector->emplace_back(NormalizedRect());
      cc->Outputs()
          .Tag(kNormRectsTag)
          .Add(rect_vector.release(), cc->InputTimestamp());
    }
  }
  return ::mediapipe::OkStatus();
}

@@ -312,4 +249,6 @@ float DetectionsToRectsCalculator::ComputeRotation(
  return NormalizeRadians(rotation);
}

REGISTER_CALCULATOR(DetectionsToRectsCalculator);

}  // namespace mediapipe

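The NormalizeRadians helper that this change moves into the header implements a standard wrap-around:

$$\mathrm{NormalizeRadians}(\theta) = \theta - 2\pi\left\lfloor \frac{\theta + \pi}{2\pi} \right\rfloor,$$

which maps any angle into $[-\pi, \pi)$. For example, $\theta = 3\pi/2$ gives $3\pi/2 - 2\pi\lfloor 5/4 \rfloor = -\pi/2$, so ComputeRotation always returns a bounded rotation regardless of the keypoint geometry.
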
new file: mediapipe/calculators/util/detections_to_rects_calculator.h (105 lines)

@@ -0,0 +1,105 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_

#include <cmath>

#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// A calculator that converts Detection proto to Rect proto.
//
// Detection is the format for encoding one or more detections in an image.
// The input can be a single Detection or std::vector<Detection>. The output can
// be either a single Rect or NormalizedRect, or std::vector<Rect> or
// std::vector<NormalizedRect>. If Rect is used, the LocationData format is
// expected to be BOUNDING_BOX, and if NormalizedRect is used it is expected to
// be RELATIVE_BOUNDING_BOX.
//
// When the input is std::vector<Detection> and the output is a Rect or
// NormalizedRect, only the first detection is converted. When the input is a
// single Detection and the output is a std::vector<Rect> or
// std::vector<NormalizedRect>, the output is a vector of size 1.
//
// Inputs:
//
//   One of the following:
//   DETECTION: A Detection proto.
//   DETECTIONS: An std::vector<Detection>.
//
//   IMAGE_SIZE (optional): A std::pair<int, int> representing image width and
//   height. This is required only when rotation needs to be computed (see
//   calculator options).
//
// Output:
//   One of the following:
//   RECT: A Rect proto.
//   NORM_RECT: A NormalizedRect proto.
//   RECTS: An std::vector<Rect>.
//   NORM_RECTS: An std::vector<NormalizedRect>.
//
// Example config:
// node {
//   calculator: "DetectionsToRectsCalculator"
//   input_stream: "DETECTIONS:detections"
//   input_stream: "IMAGE_SIZE:image_size"
//   output_stream: "NORM_RECT:rect"
//   options: {
//     [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
//       rotation_vector_start_keypoint_index: 0
//       rotation_vector_end_keypoint_index: 2
//       rotation_vector_target_angle_degrees: 90
//       output_zero_rect_for_empty_detections: true
//     }
//   }
// }
class DetectionsToRectsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 protected:
  virtual float ComputeRotation(const ::mediapipe::Detection& detection,
                                const std::pair<int, int> image_size);
  virtual ::mediapipe::Status DetectionToRect(
      const ::mediapipe::Detection& detection, ::mediapipe::Rect* rect);
  virtual ::mediapipe::Status DetectionToNormalizedRect(
      const ::mediapipe::Detection& detection,
      ::mediapipe::NormalizedRect* rect);

  static inline float NormalizeRadians(float angle) {
    return angle - 2 * M_PI * std::floor((angle - (-M_PI)) / (2 * M_PI));
  }

  ::mediapipe::DetectionsToRectsCalculatorOptions options_;
  int start_keypoint_index_;
  int end_keypoint_index_;
  float target_angle_ = 0.0f;  // In radians.
  bool rotate_;
  bool output_zero_rect_for_empty_detections_;
};

}  // namespace mediapipe
#endif  // MEDIAPIPE_CALCULATORS_UTIL_DETECTIONS_TO_RECTS_CALCULATOR_H_

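Exposing ComputeRotation, DetectionToRect, and DetectionToNormalizedRect as protected virtuals is the point of this refactoring: downstream code can specialize the conversion without re-implementing the stream plumbing. A hedged sketch of such a subclass (the class name and the padding logic are invented for illustration):

```cpp
#include "mediapipe/calculators/util/detections_to_rects_calculator.h"

namespace mediapipe {

// Hypothetical subclass that pads every normalized rect by 10% after the
// base conversion, reusing GetContract/Open/Process from the base class.
class PaddedDetectionsToRectsCalculator : public DetectionsToRectsCalculator {
 protected:
  ::mediapipe::Status DetectionToNormalizedRect(
      const Detection& detection, NormalizedRect* rect) override {
    MP_RETURN_IF_ERROR(
        DetectionsToRectsCalculator::DetectionToNormalizedRect(detection,
                                                               rect));
    rect->set_width(rect->width() * 1.1f);
    rect->set_height(rect->height() * 1.1f);
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(PaddedDetectionsToRectsCalculator);

}  // namespace mediapipe
```
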
new file: mediapipe/calculators/util/labels_to_render_data_calculator.cc (181 lines)

@@ -0,0 +1,181 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/video_stream_header.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {

constexpr float kFontHeightScale = 1.25f;

// A calculator that takes in pairs of labels and scores, or classifications,
// and generates render data. Either both "LABELS" and "SCORES" or
// "CLASSIFICATIONS" must be present.
//
// Usage example:
// node {
//   calculator: "LabelsToRenderDataCalculator"
//   input_stream: "LABELS:labels"
//   input_stream: "SCORES:scores"
//   input_stream: "VIDEO_PRESTREAM:video_header"
//   output_stream: "RENDER_DATA:render_data"
//   options {
//     [LabelsToRenderDataCalculatorOptions.ext] {
//       color { r: 255 g: 0 b: 0 }
//       color { r: 0 g: 255 b: 0 }
//       color { r: 0 g: 0 b: 255 }
//       thickness: 2.0
//       font_height_px: 20
//       max_num_labels: 3
//       font_face: 1
//       location: TOP_LEFT
//     }
//   }
// }
class LabelsToRenderDataCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

 private:
  LabelsToRenderDataCalculatorOptions options_;
  int num_colors_ = 0;
  int video_width_ = 0;
  int video_height_ = 0;
  int label_height_px_ = 0;
  int label_left_px_ = 0;
};
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);

::mediapipe::Status LabelsToRenderDataCalculator::GetContract(
    CalculatorContract* cc) {
  if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
    cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
  } else {
    RET_CHECK(cc->Inputs().HasTag("LABELS"))
        << "Must provide input stream \"LABELS\"";
    cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
    if (cc->Inputs().HasTag("SCORES")) {
      cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
    }
  }
  if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
    cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
  }
  cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
  return ::mediapipe::OkStatus();
}

::mediapipe::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
  options_ = cc->Options<LabelsToRenderDataCalculatorOptions>();
  num_colors_ = options_.color_size();
  label_height_px_ = std::ceil(options_.font_height_px() * kFontHeightScale);
  return ::mediapipe::OkStatus();
}

::mediapipe::Status LabelsToRenderDataCalculator::Process(
    CalculatorContext* cc) {
  if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
      cc->InputTimestamp() == Timestamp::PreStream()) {
    const VideoHeader& video_header =
        cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
    video_width_ = video_header.width;
    video_height_ = video_header.height;
    return ::mediapipe::OkStatus();
  } else {
    CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT)
        << "Only TOP_LEFT is supported without VIDEO_PRESTREAM.";
  }

  std::vector<std::string> labels;
  std::vector<float> scores;
  if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
    const ClassificationList& classifications =
        cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
    labels.resize(classifications.classification_size());
    scores.resize(classifications.classification_size());
    for (int i = 0; i < classifications.classification_size(); ++i) {
      labels[i] = classifications.classification(i).label();
      scores[i] = classifications.classification(i).score();
    }
  } else {
    const std::vector<std::string>& label_vector =
        cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
    std::vector<float> score_vector;
    if (cc->Inputs().HasTag("SCORES")) {
      score_vector = cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
      // Scores, when provided, must pair up one-to-one with labels.
      CHECK_EQ(label_vector.size(), score_vector.size());
    }
    labels.resize(label_vector.size());
    scores.resize(label_vector.size());
    for (int i = 0; i < label_vector.size(); ++i) {
      labels[i] = label_vector[i];
      // Guarded so the calculator also works when SCORES is absent.
      if (!score_vector.empty()) {
        scores[i] = score_vector[i];
      }
    }
  }

  RenderData render_data;
  int num_label =
      std::min(static_cast<int>(labels.size()), options_.max_num_labels());
  int label_baseline_px = options_.vertical_offset_px();
  if (options_.location() == LabelsToRenderDataCalculatorOptions::TOP_LEFT) {
    label_baseline_px += label_height_px_;
  } else if (options_.location() ==
             LabelsToRenderDataCalculatorOptions::BOTTOM_LEFT) {
    label_baseline_px += video_height_ - label_height_px_ * (num_label - 1);
  }
  label_left_px_ = options_.horizontal_offset_px();
  for (int i = 0; i < num_label; ++i) {
    auto* label_annotation = render_data.add_render_annotations();
    label_annotation->set_thickness(options_.thickness());
    if (num_colors_ > 0) {
      *(label_annotation->mutable_color()) = options_.color(i % num_colors_);
    } else {
      label_annotation->mutable_color()->set_r(255);
      label_annotation->mutable_color()->set_g(0);
      label_annotation->mutable_color()->set_b(0);
    }

    auto* text = label_annotation->mutable_text();
    std::string display_text = labels[i];
    if (cc->Inputs().HasTag("SCORES")) {
      absl::StrAppend(&display_text, ":", scores[i]);
    }
    text->set_display_text(display_text);
    text->set_font_height(options_.font_height_px());
    text->set_left(label_left_px_);
    text->set_baseline(label_baseline_px + i * label_height_px_);
    text->set_font_face(options_.font_face());
  }
  cc->Outputs()
      .Tag("RENDER_DATA")
      .AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));

  return ::mediapipe::OkStatus();
}
}  // namespace mediapipe

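A hedged sketch of feeding this calculator a ClassificationList from C++; the graph object and the "classifications" stream name are assumptions for illustration, not part of this change:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"

// Builds one ClassificationList packet and pushes it into a running graph
// whose LabelsToRenderDataCalculator reads "CLASSIFICATIONS:classifications".
::mediapipe::Status PushClassifications(mediapipe::CalculatorGraph* graph) {
  mediapipe::ClassificationList list;
  auto* c = list.add_classification();
  c->set_label("cat");
  c->set_score(0.9f);
  return graph->AddPacketToInputStream(
      "classifications",
      mediapipe::MakePacket<mediapipe::ClassificationList>(list).At(
          mediapipe::Timestamp(0)));
}
```

For the vertical layout, note the arithmetic in Process: with the default font_height_px of 50, label_height_px_ is ceil(50 * 1.25) = 63, so with TOP_LEFT and no vertical offset the first three labels are drawn at baselines 63, 126, and 189 pixels.
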
new file: mediapipe/calculators/util/labels_to_render_data_calculator.proto (62 lines)

@@ -0,0 +1,62 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";

message LabelsToRenderDataCalculatorOptions {
  extend CalculatorOptions {
    optional LabelsToRenderDataCalculatorOptions ext = 271660364;
  }

  // Colors for drawing the label(s).
  repeated Color color = 1;

  // Thickness for drawing the label(s).
  optional double thickness = 2 [default = 2];

  // The font height in absolute pixels.
  optional int32 font_height_px = 3 [default = 50];

  // The offset of the starting text in the horizontal direction, in absolute
  // pixels.
  optional int32 horizontal_offset_px = 7 [default = 0];
  // The offset of the starting text in the vertical direction, in absolute
  // pixels.
  optional int32 vertical_offset_px = 8 [default = 0];

  // The maximum number of labels to display.
  optional int32 max_num_labels = 4 [default = 1];

  // Specifies the font for the text. Font must be one of the following from
  // OpenCV:
  // cv::FONT_HERSHEY_SIMPLEX (0)
  // cv::FONT_HERSHEY_PLAIN (1)
  // cv::FONT_HERSHEY_DUPLEX (2)
  // cv::FONT_HERSHEY_COMPLEX (3)
  // cv::FONT_HERSHEY_TRIPLEX (4)
  // cv::FONT_HERSHEY_COMPLEX_SMALL (5)
  // cv::FONT_HERSHEY_SCRIPT_SIMPLEX (6)
  // cv::FONT_HERSHEY_SCRIPT_COMPLEX (7)
  optional int32 font_face = 5 [default = 0];

  // Label location.
  enum Location {
    TOP_LEFT = 0;
    BOTTOM_LEFT = 1;
  }
  optional Location location = 6 [default = TOP_LEFT];
}

new file: mediapipe/calculators/util/landmarks_to_floats_calculator.cc (138 lines)

@@ -0,0 +1,138 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <vector>

#include "Eigen/Core"
#include "mediapipe/calculators/util/landmarks_to_floats_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"

namespace mediapipe {

namespace {

constexpr char kLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kFloatsTag[] = "FLOATS";
constexpr char kMatrixTag[] = "MATRIX";

}  // namespace

// Converts a vector of landmarks to a vector of floats or a matrix.
// Input:
//   NORM_LANDMARKS: An std::vector<NormalizedLandmark>.
//
// Output:
//   FLOATS(optional): A vector of floats from flattened landmarks.
//   MATRIX(optional): A matrix of floats of the landmarks.
//
// Usage example:
// node {
//   calculator: "LandmarksToFloatsCalculator"
//   input_stream: "NORM_LANDMARKS:landmarks"
//   output_stream: "MATRIX:landmark_matrix"
// }
class LandmarksToFloatsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag(kLandmarksTag).Set<std::vector<NormalizedLandmark>>();
    RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
              cc->Outputs().HasTag(kMatrixTag));
    if (cc->Outputs().HasTag(kFloatsTag)) {
      cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
    }
    if (cc->Outputs().HasTag(kMatrixTag)) {
      cc->Outputs().Tag(kMatrixTag).Set<Matrix>();
    }

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));
    const auto& options =
        cc->Options<::mediapipe::LandmarksToFloatsCalculatorOptions>();
    num_dimensions_ = options.num_dimensions();
    // Currently the number of dimensions must be within [1, 3].
    RET_CHECK_GE(num_dimensions_, 1);
    RET_CHECK_LE(num_dimensions_, 3);
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    // Only process if there are input landmarks.
    if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) {
      return ::mediapipe::OkStatus();
    }

    const auto& input_landmarks =
        cc->Inputs().Tag(kLandmarksTag).Get<std::vector<NormalizedLandmark>>();

    if (cc->Outputs().HasTag(kFloatsTag)) {
      auto output_floats = absl::make_unique<std::vector<float>>();
      for (const auto& landmark : input_landmarks) {
        output_floats->emplace_back(landmark.x());
        if (num_dimensions_ > 1) {
          output_floats->emplace_back(landmark.y());
        }
        if (num_dimensions_ > 2) {
          output_floats->emplace_back(landmark.z());
        }
      }

      cc->Outputs()
          .Tag(kFloatsTag)
          .Add(output_floats.release(), cc->InputTimestamp());
    } else {
      auto output_matrix = absl::make_unique<Matrix>();
      output_matrix->setZero(num_dimensions_, input_landmarks.size());
      for (int i = 0; i < input_landmarks.size(); ++i) {
        (*output_matrix)(0, i) = input_landmarks[i].x();
        if (num_dimensions_ > 1) {
          (*output_matrix)(1, i) = input_landmarks[i].y();
        }
        if (num_dimensions_ > 2) {
          (*output_matrix)(2, i) = input_landmarks[i].z();
        }
      }
      cc->Outputs()
          .Tag(kMatrixTag)
          .Add(output_matrix.release(), cc->InputTimestamp());
    }
    return ::mediapipe::OkStatus();
  }

 private:
  int num_dimensions_ = 0;
};
REGISTER_CALCULATOR(LandmarksToFloatsCalculator);

}  // namespace mediapipe

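To make the output layout concrete, a small hedged example (the values are invented) of what the calculator produces for two 2-D landmarks:

```cpp
#include <vector>

#include "mediapipe/framework/formats/landmark.pb.h"

// With num_dimensions = 2, landmarks (0.1, 0.2) and (0.3, 0.4) flatten to
// FLOATS = {0.1f, 0.2f, 0.3f, 0.4f}, i.e. x then y per landmark, while
// MATRIX is 2x2 with one column per landmark:
//   | 0.1  0.3 |
//   | 0.2  0.4 |
std::vector<mediapipe::NormalizedLandmark> MakeExampleLandmarks() {
  std::vector<mediapipe::NormalizedLandmark> landmarks(2);
  landmarks[0].set_x(0.1f);
  landmarks[0].set_y(0.2f);
  landmarks[1].set_x(0.3f);
  landmarks[1].set_y(0.4f);
  return landmarks;
}
```
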
new file: mediapipe/calculators/util/landmarks_to_floats_calculator.proto (28 lines)

@@ -0,0 +1,28 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message LandmarksToFloatsCalculatorOptions {
  extend CalculatorOptions {
    optional LandmarksToFloatsCalculatorOptions ext = 274035660;
  }

  // Number of dimensions to convert. Must be within [1, 3].
  optional int32 num_dimensions = 1 [default = 2];
}

new file: mediapipe/calculators/util/local_file_contents_calculator.cc (57 lines)

@@ -0,0 +1,57 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {
// The calculator takes the path to the local file as an input side packet and
// outputs the contents of that file.
//
// Example config:
// node {
//   calculator: "LocalFileContentsCalculator"
//   input_side_packet: "FILE_PATH:file_path"
//   output_side_packet: "CONTENTS:contents"
// }
class LocalFileContentsCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->InputSidePackets().Tag("FILE_PATH").Set<std::string>();
    cc->OutputSidePackets().Tag("CONTENTS").Set<std::string>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    std::string contents;
    MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
        cc->InputSidePackets().Tag("FILE_PATH").Get<std::string>(), &contents));
    cc->OutputSidePackets()
        .Tag("CONTENTS")
        .Set(MakePacket<std::string>(std::move(contents)));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return ::mediapipe::OkStatus();
  }
};

REGISTER_CALCULATOR(LocalFileContentsCalculator);

}  // namespace mediapipe

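Since the file is read once in Open() and emitted as an output side packet, the path itself usually arrives as an input side packet at StartRun. A minimal hedged sketch (the side-packet name and the path are placeholders invented here):

```cpp
#include <map>
#include <string>

#include "mediapipe/framework/calculator_framework.h"

// Supplies the FILE_PATH input side packet when starting a graph that
// contains a LocalFileContentsCalculator bound to "file_path".
::mediapipe::Status RunWithFilePath(mediapipe::CalculatorGraph* graph) {
  std::map<std::string, mediapipe::Packet> side_packets;
  side_packets["file_path"] =
      mediapipe::MakePacket<std::string>("/path/to/some/local/file.txt");
  return graph->StartRun(side_packets);
}
```
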
@@ -23,13 +23,14 @@

#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/resource_util.h"

#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
    (defined(__APPLE__) && !TARGET_OS_OSX)
#if defined(MEDIAPIPE_LITE) || defined(__EMSCRIPTEN__) || \
    defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else

@@ -37,8 +38,10 @@
#endif

namespace mediapipe {

// A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements.
// labels of the top k elements, classification protos, and a summary
// std::string (in CSV format).
//
// Usage example:
// node {

@@ -47,6 +50,8 @@ namespace mediapipe {
//   output_stream: "TOP_K_INDEXES:top_k_indexes"
//   output_stream: "TOP_K_SCORES:top_k_scores"
//   output_stream: "TOP_K_LABELS:top_k_labels"
//   output_stream: "CLASSIFICATIONS:top_k_classes"
//   output_stream: "SUMMARY:summary"
//   options: {
//     [mediapipe.TopKScoresCalculatorOptions.ext] {
//       top_k: 5

@@ -69,6 +74,7 @@ class TopKScoresCalculator : public CalculatorBase {
  int top_k_ = -1;
  float threshold_ = 0.0;
  std::unordered_map<int, std::string> label_map_;
  bool label_map_loaded_ = false;
};
REGISTER_CALCULATOR(TopKScoresCalculator);

@@ -84,6 +90,12 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
  if (cc->Outputs().HasTag("TOP_K_LABELS")) {
    cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
  }
  if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
    cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
  }
  if (cc->Outputs().HasTag("SUMMARY")) {
    cc->Outputs().Tag("SUMMARY").Set<std::string>();
  }
  return ::mediapipe::OkStatus();
}

@@ -149,7 +161,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
  reverse(top_k_indexes.begin(), top_k_indexes.end());
  reverse(top_k_scores.begin(), top_k_scores.end());

  if (cc->Outputs().HasTag("TOP_K_LABELS")) {
    if (label_map_loaded_) {
      for (int index : top_k_indexes) {
        top_k_labels.push_back(label_map_[index]);
      }

@@ -172,6 +184,35 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
        .AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
                       .At(cc->InputTimestamp()));
  }

  if (cc->Outputs().HasTag("SUMMARY")) {
    std::vector<std::string> results;
    for (int index = 0; index < top_k_indexes.size(); ++index) {
      if (label_map_loaded_) {
        results.push_back(
            absl::StrCat(top_k_labels[index], ":", top_k_scores[index]));
      } else {
        results.push_back(
            absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
      }
    }
    cc->Outputs().Tag("SUMMARY").AddPacket(
        MakePacket<std::string>(absl::StrJoin(results, ","))
            .At(cc->InputTimestamp()));
  }

  if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
    auto classification_list = absl::make_unique<ClassificationList>();
    for (int index = 0; index < top_k_indexes.size(); ++index) {
      Classification* classification =
          classification_list->add_classification();
      classification->set_index(top_k_indexes[index]);
      classification->set_score(top_k_scores[index]);
      if (label_map_loaded_) {
        classification->set_label(top_k_labels[index]);
      }
    }
    // Emit the assembled list on the CLASSIFICATIONS output stream, the tag
    // registered in GetContract.
    cc->Outputs()
        .Tag("CLASSIFICATIONS")
        .Add(classification_list.release(), cc->InputTimestamp());
  }
  return ::mediapipe::OkStatus();
}

@@ -188,6 +229,7 @@ REGISTER_CALCULATOR(TopKScoresCalculator);
  while (std::getline(stream, line)) {
    label_map_[i++] = line;
  }
  label_map_loaded_ = true;
  return ::mediapipe::OkStatus();
}

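For concreteness, the SUMMARY packet is just the "label:score" (or "index:score") pairs joined with commas. A hedged illustration with invented labels:

```cpp
#include <string>
#include <vector>

#include "absl/strings/str_join.h"

// Mirrors the summary assembly above: with a loaded label map and top_k = 3,
// the emitted std::string would look like "cat:0.9,dog:0.05,bird:0.01".
std::string ExampleSummary() {
  std::vector<std::string> results = {"cat:0.9", "dog:0.05", "bird:0.01"};
  return absl::StrJoin(results, ",");
}
```
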
new file: mediapipe/docs/android_archive_library.md (130 lines)

@@ -0,0 +1,130 @@
## MediaPipe Android Archive Library

***Experimental Only***

The MediaPipe Android archive library is a convenient way to use MediaPipe with
Android Studio and Gradle. MediaPipe doesn't publish a general AAR that can be
used by all projects. Instead, developers need to add a mediapipe_aar() target
to generate a custom AAR file for their own projects. This is necessary in
order to include specific resources, such as the MediaPipe calculators needed
by each project.

### Steps to build a MediaPipe AAR

1.  Create a mediapipe_aar() target.

    In the MediaPipe directory, create a new mediapipe_aar() target in a BUILD
    file. You need to figure out which calculators are used in the graph and
    provide the calculator dependencies to the mediapipe_aar(). For example, to
    build an AAR for [face detection gpu](./face_detection_mobile_gpu.md), you
    can put the following code into
    mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/BUILD.

    ```
    load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")

    mediapipe_aar(
        name = "mp_face_detection_aar",
        calculators = ["//mediapipe/graphs/face_detection:mobile_calculators"],
    )
    ```

2.  Run the Bazel build command to generate the AAR.

    ```bash
    bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a //path/to/the/aar/build/file:aar_name
    ```

    For the face detection AAR target we made in step 1, run:

    ```bash
    bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a \
        //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar

    # It should print:
    # Target //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mp_face_detection_aar up-to-date:
    # bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
    ```

3.  (Optional) Save the AAR to your preferred location.

    ```bash
    cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
    /absolute/path/to/your/preferred/location
    ```

### Steps to use a MediaPipe AAR in Android Studio with Gradle

1.  Start Android Studio and go to your project.

2.  Copy the AAR into app/libs.

    ```bash
    cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example/mp_face_detection_aar.aar
    /path/to/your/app/libs/
    ```

    ![Screenshot](images/mobile/aar_location.png)

3.  Make app/src/main/assets and copy assets (graph, model, etc.) into
    app/src/main/assets.

    Build the MediaPipe binary graph and copy the assets into
    app/src/main/assets, e.g., for the face detection graph, you need to build
    and copy
    [the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41),
    [the tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite),
    and
    [the label map](https://github.com/google/mediapipe/blob/master/mediapipe/models/face_detection_front_labelmap.txt).

    ```bash
    bazel build -c opt mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu:binary_graph
    cp bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/facedetectiongpu.binarypb /path/to/your/app/src/main/assets/
    cp mediapipe/models/face_detection_front.tflite /path/to/your/app/src/main/assets/
    cp mediapipe/models/face_detection_front_labelmap.txt /path/to/your/app/src/main/assets/
    ```

    ![Screenshot](images/mobile/assets_location.png)

4.  Make app/src/main/jniLibs and copy the OpenCV JNI libraries into
    app/src/main/jniLibs.

    MediaPipe depends on OpenCV; you will need to copy the precompiled OpenCV
    .so files into app/src/main/jniLibs. You can download the official OpenCV
    Android SDK from
    [here](https://github.com/opencv/opencv/releases/download/4.1.0/opencv-4.1.0-android-sdk.zip)
    and run:

    ```bash
    cp -R ~/Downloads/OpenCV-android-sdk/sdk/native/libs/arm* /path/to/your/app/src/main/jniLibs/
    ```

    ![Screenshot](images/mobile/android_studio_opencv_location.png)

5.  Modify app/build.gradle to add MediaPipe dependencies and the MediaPipe
    AAR.

    ```
    dependencies {
        implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar'])
        implementation 'androidx.appcompat:appcompat:1.0.2'
        implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
        testImplementation 'junit:junit:4.12'
        androidTestImplementation 'androidx.test.ext:junit:1.1.0'
        androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.1'
        // MediaPipe deps
        implementation 'com.google.flogger:flogger:0.3.1'
        implementation 'com.google.flogger:flogger-system-backend:0.3.1'
        implementation 'com.google.code.findbugs:jsr305:3.0.2'
        implementation 'com.google.guava:guava:27.0.1-android'
        // CameraX core library
        def camerax_version = "1.0.0-alpha06"
        implementation "androidx.camera:camera-core:$camerax_version"
        implementation "androidx.camera:camera-camera2:$camerax_version"
    }
    ```

6.  Follow our Android app examples to use MediaPipe in Android Studio for your
    use case. If you are looking for an example, a working face detection
    example can be found
    [here](https://github.com/jiuqiant/mediapipe_aar_example).

@@ -96,8 +96,9 @@ using the MediaPipe C++ APIs.

### Feature Extraction for YouTube-8M Challenge

[Feature Extraction for YouTube-8M Challenge](./youtube_8m.md) shows how to use
MediaPipe to prepare training data for the YouTube-8M Challenge.
[Feature Extraction and Model Inference for YouTube-8M Challenge](./youtube_8m.md)
shows how to use MediaPipe to prepare training data for the YouTube-8M Challenge
and run model inference with the baseline model.

### Preparing Data Sets with MediaSequence

@@ -36,10 +36,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions

$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible.
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
    --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt
```

@@ -60,11 +59,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions

$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
    --calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt
```

@@ -35,11 +35,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
#INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142
#INFO: Build completed successfully, 12210 total actions

$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
$ bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
    --calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt
```

@@ -35,10 +35,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
#INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307
#INFO: Build completed successfully, 12517 total actions

$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible.
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
    --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt
```

@@ -59,11 +58,10 @@ $ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
#INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b
#INFO: Build completed successfully, 22455 total actions

$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on.
# Any errors are likely due to your webcam not being accessible,
# or GPU drivers not being set up properly.
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
    --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
```

new binary file: mediapipe/docs/images/mobile/aar_location.png (35 KiB)
new binary file: mediapipe/docs/images/mobile/android_studio_opencv_location.png (75 KiB)
new binary file: mediapipe/docs/images/mobile/assets_location.png (56 KiB)

@ -24,7 +24,8 @@ Choose your operating system:
|
|||
To build and run Android apps:
|
||||
|
||||
- [Setting up Android SDK and NDK](#setting-up-android-sdk-and-ndk)
|
||||
- [Setting up Android Studio with MediaPipe](#setting-up-android-studio-with-mediapipe)
|
||||
- [Using MediaPipe with Gradle](#using-mediapipe-with-gradle)
|
||||
- [Using MediaPipe with Bazel](#using-mediapipe-with-bazel)
|
||||
|
||||
To build and run iOS apps:
|
||||
|
||||
|
@ -41,19 +42,11 @@ To build and run iOS apps:
|
|||
$ cd mediapipe
|
||||
```
|
||||
|
||||
2. Install Bazel (0.24.1 and above required).
|
||||
2. Install Bazel (version between 0.24.1 and 0.29.1).
|
||||
|
||||
Option 1. Use package manager tool to install the latest version of Bazel.
|
||||
|
||||
```bash
|
||||
$ sudo apt-get install bazel
|
||||
|
||||
# Run 'bazel version' to check version of bazel installed
|
||||
```
|
||||
|
||||
Option 2. Follow Bazel's
|
||||
Follow the official
|
||||
[documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
|
||||
to install any version of Bazel manually.
|
||||
to install Bazel manually. Note that MediaPipe doesn't support Bazel 1.0.0+ yet.
|
||||
|
||||
3. Install OpenCV and FFmpeg.
|
||||
|
||||
|
@ -75,10 +68,10 @@ To build and run iOS apps:
|
|||
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
|
||||
to manually build OpenCV from source code.
|
||||
|
||||
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
|
||||
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
|
||||
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
|
||||
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
|
||||
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
|
@ -159,11 +152,11 @@ To build and run iOS apps:
|
|||
$ cd mediapipe
|
||||
```
|
||||
|
||||
2. Install Bazel (0.24.1 and above required).
|
||||
2. Install Bazel (version between 0.24.1 and 0.29.1).
|
||||
|
||||
Follow Bazel's
|
||||
Follow the official
|
||||
[documentation](https://docs.bazel.build/versions/master/install-redhat.html)
|
||||
to install Bazel manually.
|
||||
to install Bazel manually. Note that MediaPipe doesn't support Bazel 1.0.0+ yet.
|
||||
|
||||
3. Install OpenCV.
|
||||
|
||||
|
@ -178,10 +171,10 @@ To build and run iOS apps:
|
|||
|
||||
Option 2. Build OpenCV from source code.
|
||||
|
||||
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
|
||||
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
|
||||
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
|
||||
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
|
||||
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
|
@ -237,7 +230,7 @@ To build and run iOS apps:
|
|||
|
||||
* Install [Homebrew](https://brew.sh).
|
||||
* Install [Xcode](https://developer.apple.com/xcode/) and its Command Line
|
||||
Tools.
|
||||
Tools by `xcode-select install`.
|
||||
|
||||
2. Checkout MediaPipe repository.
|
||||
|
||||
|
@ -247,19 +240,24 @@ To build and run iOS apps:
|
|||
$ cd mediapipe
|
||||
```
|
||||
|
||||
3. Install Bazel (0.24.1 and above required).
|
||||
3. Install Bazel (version between 0.24.1 and 0.29.1).
|
||||
|
||||
Option 1. Use package manager tool to install the latest version of Bazel.
|
||||
Option 1. Use a package manager to install Bazel 0.29.1.
|
||||
|
||||
```bash
|
||||
$ brew install bazel
|
||||
# If Bazel 1.0.0+ was installed.
|
||||
$ brew uninstall bazel
|
||||
|
||||
# Install Bazel 0.29.1
|
||||
$ brew install https://raw.githubusercontent.com/bazelbuild/homebrew-tap/223ffb570c21c0a2af251afc6df9dec0214c6e74/Formula/bazel.rb
|
||||
$ brew link bazel
|
||||
|
||||
# Run 'bazel version' to check version of bazel installed
|
||||
```
|
||||
|
||||
Option 2. Follow Bazel's
|
||||
Option 2. Follow the official
|
||||
[documentation](https://docs.bazel.build/versions/master/install-os-x.html#install-with-installer-mac-os-x)
|
||||
to install any version of Bazel manually.
|
||||
to install Bazel manually. Note that MediaPipe doesn't support Bazel 1.0.0+ yet.
|
||||
|
||||
4. Install OpenCV and FFmpeg.
|
||||
|
||||
|
@ -281,7 +279,7 @@ To build and run iOS apps:
|
|||
$ port install opencv
|
||||
```
|
||||
|
||||
Note: when using MacPorts, please edit the [`WORKSAPCE`],
|
||||
Note: when using MacPorts, please edit the [`WORKSPACE`],
|
||||
[`opencv_macos.BUILD`], and [`ffmpeg_macos.BUILD`] files like the following:
|
||||
|
||||
```bash
|
||||
|
@ -419,10 +417,10 @@ To build and run iOS apps:
|
|||
[documentation](https://docs.opencv.org/3.4.6/d7/d9f/tutorial_linux_install.html)
|
||||
to manually build OpenCV from source code.
|
||||
|
||||
Note: You may need to modify [`WORKSAPCE`] and [`opencv_linux.BUILD`] to
|
||||
Note: You may need to modify [`WORKSPACE`] and [`opencv_linux.BUILD`] to
|
||||
point MediaPipe to your own OpenCV libraries, e.g., if OpenCV 4 is installed
|
||||
in "/usr/local/", you need to update the "linux_opencv" new_local_repository
|
||||
rule in [`WORKSAPCE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
rule in [`WORKSPACE`] and "opencv" cc_library rule in [`opencv_linux.BUILD`]
|
||||
like the following:
|
||||
|
||||
```bash
|
||||
|
@ -589,10 +587,20 @@ Please verify all the necessary packages are installed.
|
|||
* Android SDK Tools 26.1.1
|
||||
* Android NDK 17c or above
|
||||
|
||||
### Setting up Android Studio with MediaPipe
|
||||
### Using MediaPipe with Gradle
|
||||
|
||||
The steps below use Android Studio 3.5 to build and install a MediaPipe example
|
||||
app.
|
||||
MediaPipe can be used within an existing project, such as a Gradle project,
|
||||
using the MediaPipe AAR target defined in mediapipe_aar.bzl. Please see the
|
||||
separate [MediaPipe Android Archive Library](./android_archive_library.md)
|
||||
documentation.
|
||||
|
||||
### Using MediaPipe with Bazel
|
||||
|
||||
The MediaPipe project can be imported into Android Studio using the Bazel plugins.
|
||||
This allows the MediaPipe examples and demos to be built and modified in Android
|
||||
Studio. To incorporate MediaPipe into an existing Android Studio project, see:
|
||||
"Using MediaPipe with Gradle". The steps below use Android Studio 3.5 to build
|
||||
and install a MediaPipe example app.
|
||||
|
||||
1. Install and launch Android Studio 3.5.
|
||||
|
||||
|
@ -682,7 +690,7 @@ app.
|
|||
* Press the `[+]` button to add the new configuration.
|
||||
* Select `Run` to run the example app on the connected Android device.
|
||||
|
||||
[`WORKSAPCE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
|
||||
[`WORKSPACE`]: https://github.com/google/mediapipe/tree/master/WORKSPACE
|
||||
[`opencv_linux.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_linux.BUILD
|
||||
[`opencv_macos.BUILD`]: https://github.com/google/mediapipe/tree/master/third_party/opencv_macos.BUILD
|
||||
[`ffmpeg_macos.BUILD`]:https://github.com/google/mediapipe/tree/master/third_party/ffmpeg_macos.BUILD
|
||||
|
|
|
@ -35,10 +35,9 @@ $ bazel build -c opt \
|
|||
# INFO: 2675 processes: 2673 linux-sandbox, 2 local.
|
||||
# INFO: Build completed successfully, 2807 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# Replace <input video path> and <output video path>.
|
||||
# You can find a test video in mediapipe/examples/desktop/object_detection.
|
||||
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \
|
||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \
|
||||
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
|
||||
```
|
||||
|
@ -200,10 +199,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
|
|||
# INFO: 711 processes: 710 linux-sandbox, 1 local.
|
||||
# INFO: Build completed successfully, 734 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# Replace <input video path> and <output video path>.
|
||||
# You can find a test video in mediapipe/examples/desktop/object_detection.
|
||||
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \
|
||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \
|
||||
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
|
||||
```
|
||||
|
@ -224,10 +222,9 @@ $ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
|
|||
#INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b
|
||||
#INFO: Build completed successfully, 12154 total actions
|
||||
|
||||
$ export GLOG_logtostderr=1
|
||||
# This will open up your webcam as long as it is connected and on
|
||||
# Any errors are likely due to your webcam not being accessible
|
||||
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
|
||||
$ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
|
||||
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt
|
||||
```
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
## Extracting Video Features for YouTube-8M Challenge
|
||||
# Feature Extraction and Model Inference for YouTube-8M Challenge
|
||||
|
||||
MediaPipe is a useful and general framework for media processing that can assist
|
||||
with research, development, and deployment of ML models. This example focuses on
|
||||
model development by demonstrating how to prepare training data for the
|
||||
YouTube-8M Challenge.
|
||||
model development by demonstrating how to prepare training data and do model
|
||||
inference for the YouTube-8M Challenge.
|
||||
|
||||
## Extracting Video Features for YouTube-8M Challenge
|
||||
|
||||
[YouTube-8M Challenge](https://www.kaggle.com/c/youtube8m-2019) is an annual
|
||||
video classification challenge hosted by Google. Over the last two years, the
|
||||
|
@ -29,14 +31,14 @@ videos.
|
|||
|
||||
### Steps to run the YouTube-8M feature extraction graph
|
||||
|
||||
1. Checkout the mediapipe repository
|
||||
1. Check out the mediapipe repository.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/google/mediapipe.git
|
||||
cd mediapipe
|
||||
```
|
||||
|
||||
2. Download the PCA and model data
|
||||
2. Download the PCA and model data.
|
||||
|
||||
```bash
|
||||
mkdir /tmp/mediapipe
|
||||
|
@ -49,7 +51,7 @@ videos.
|
|||
tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz
|
||||
```
|
||||
|
||||
3. Get the VGGish frozen graph
|
||||
3. Get the VGGish frozen graph.
|
||||
|
||||
Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed
|
||||
with the TensorFlow 1.14+ package installed.
|
||||
|
@ -60,24 +62,103 @@ videos.
|
|||
python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph
|
||||
```
|
||||
|
||||
4. Generate a MediaSequence metadata from the input video
|
||||
4. Generate MediaSequence metadata from the input video.
|
||||
|
||||
Note: the output file is /tmp/mediapipe/metadata.tfrecord
|
||||
|
||||
```bash
|
||||
# change clip_end_time_sec to match the length of your video.
|
||||
python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
|
||||
--path_to_input_video=/absolute/path/to/the/local/video/file
|
||||
--path_to_input_video=/absolute/path/to/the/local/video/file \
|
||||
--clip_end_time_sec=120
|
||||
```
|
||||
|
||||
5. Run the MediaPipe binary to extract the features
|
||||
5. Run the MediaPipe binary to extract the features.
|
||||
|
||||
```bash
|
||||
bazel build -c opt \
|
||||
--define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \
|
||||
mediapipe/examples/desktop/youtube8m:extract_yt8m_features
|
||||
|
||||
./bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features
|
||||
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
|
||||
--calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
|
||||
--input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \
|
||||
--output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord
|
||||
```
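
To sanity-check the extracted features before moving on, the sketch below reads the output record back with TensorFlow. This is an illustrative addition, not part of the original steps; it assumes the TensorFlow 1.14+ package already required by steps 3 and 4.

```python
# Minimal sketch: inspect the SequenceExample written by the extractor.
import tensorflow as tf

for record in tf.compat.v1.io.tf_record_iterator("/tmp/mediapipe/output.tfrecord"):
    example = tf.train.SequenceExample.FromString(record)
    # The context carries clip-level metadata; the feature lists carry the
    # per-frame visual and audio features.
    print(example.context)
    break
```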
|
||||
|
||||
## Model Inference for YouTube-8M Challenge
|
||||
|
||||
MediaPipe can help you do model inference for the YouTube-8M Challenge with both
|
||||
local videos and the YouTube-8M dataset. To visualize
|
||||
[the graph for local videos](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt)
|
||||
and
|
||||
[the graph for the YouTube-8M dataset](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt),
|
||||
copy the text specification of the graph and paste it into
|
||||
[MediaPipe Visualizer](https://viz.mediapipe.dev/). We use the baseline model
|
||||
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
|
||||
in our example, but the model inference pipeline is highly customizable. You
|
||||
are welcome to add new calculators or use your own machine learning models to do
|
||||
the inference for both local videos and the dataset.
|
||||
|
||||
### Steps to run the YouTube-8M model inference graph with Web Interface
|
||||
|
||||
1. Copy the baseline model
|
||||
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
|
||||
to a local directory.
|
||||
|
||||
```bash
|
||||
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
|
||||
|
||||
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
|
||||
```
|
||||
|
||||
2. Build the inference binary.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
|
||||
mediapipe/examples/desktop/youtube8m:model_inference
|
||||
```
|
||||
|
||||
3. Run the python web server.
|
||||
|
||||
Note: this step requires the absl-py package (`pip install absl-py`).
|
||||
|
||||
```bash
|
||||
python mediapipe/examples/desktop/youtube8m/viewer/server.py --root `pwd`
|
||||
```
|
||||
|
||||
Navigate to localhost:8008 in a web browser.
|
||||
[Here](https://drive.google.com/file/d/19GSvdAAuAlACpBhHOaqMWZ_9p8bLUYKh/view?usp=sharing)
|
||||
is a demo video showing the steps to use this web application. Also please
|
||||
read
|
||||
[youtube8m/README.md](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m/README.md)
|
||||
if you prefer to run the underlying model_inference binary from the command line.
|
||||
|
||||
### Steps to run the YouTube-8M model inference graph with a local video
|
||||
|
||||
1. Make sure you have the output tfrecord from the feature extraction pipeline.
|
||||
|
||||
2. Copy the baseline model
|
||||
[(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view)
|
||||
to a local directory.
|
||||
|
||||
```bash
|
||||
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
|
||||
|
||||
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
|
||||
```
|
||||
|
||||
3. Build and run the inference binary.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
|
||||
mediapipe/examples/desktop/youtube8m:model_inference
|
||||
|
||||
# segment_size is the length, in seconds, of each window of frames (see the sketch after this block).
|
||||
# overlap is the number of seconds adjacent segments share.
|
||||
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
|
||||
--calculator_graph_config_file=mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt \
|
||||
--input_side_packets=input_sequence_example_path=/tmp/mediapipe/output.tfrecord,input_video_path=/absolute/path/to/the/local/video/file,output_video_path=/tmp/mediapipe/annotated_video.mp4,segment_size=5,overlap=4
|
||||
```
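
As a quick illustration of how segment_size and overlap interact (a sketch added for clarity, using a hypothetical 10-second clip), the windows advance by segment_size - overlap seconds:

```python
# Sketch: segment windows implied by segment_size=5 and overlap=4.
segment_size = 5
overlap = 4
clip_length_sec = 10  # hypothetical clip length
step = segment_size - overlap  # a new segment starts every `step` seconds
windows = [(start, start + segment_size)
           for start in range(0, clip_length_sec - segment_size + 1, step)]
print(windows)  # [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
```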
|
||||
|
||||
4. View the annotated video.
|
||||
|
|
|
@ -27,7 +27,9 @@ cc_library(
|
|||
"//mediapipe/framework/port:file_helpers",
|
||||
"//mediapipe/framework/port:map_util",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -13,14 +13,23 @@
|
|||
// limitations under the License.
|
||||
//
|
||||
// A simple main function to run a MediaPipe graph.
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/commandlineflags.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/map_util.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
|
||||
DEFINE_string(
|
||||
calculator_graph_config_file, "",
|
||||
|
@ -31,14 +40,72 @@ DEFINE_string(input_side_packets, "",
|
|||
"for the CalculatorGraph. All values will be treated as the "
|
||||
"string type even if they represent doubles, floats, etc.");
|
||||
|
||||
// Local file output flags.
|
||||
// Output stream
|
||||
DEFINE_string(output_stream, "",
|
||||
"The output stream to output to the local file in csv format.");
|
||||
DEFINE_string(output_stream_file, "",
|
||||
"The name of the local file to output all packets sent to "
|
||||
"the stream specified with --output_stream. ");
|
||||
DEFINE_bool(strip_timestamps, false,
|
||||
"If true, only the packet contents (without timestamps) will be "
|
||||
"written into the local file.");
|
||||
// Output side packets
|
||||
DEFINE_string(output_side_packets, "",
|
||||
"A CSV of output side packets to output to local file.");
|
||||
DEFINE_string(output_side_packets_file, "",
|
||||
"The name of the local file to output all side packets specified "
|
||||
"with --output_side_packets. ");
|
||||
|
||||
::mediapipe::Status OutputStreamToLocalFile(
|
||||
::mediapipe::OutputStreamPoller& poller) {
|
||||
std::ofstream file;
|
||||
file.open(FLAGS_output_stream_file);
|
||||
::mediapipe::Packet packet;
|
||||
while (poller.Next(&packet)) {
|
||||
std::string output_data;
|
||||
if (!FLAGS_strip_timestamps) {
|
||||
absl::StrAppend(&output_data, packet.Timestamp().Value(), ",");
|
||||
}
|
||||
absl::StrAppend(&output_data, packet.Get<std::string>(), "\n");
|
||||
file << output_data;
|
||||
}
|
||||
file.close();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status OutputSidePacketsToLocalFile(
|
||||
::mediapipe::CalculatorGraph& graph) {
|
||||
if (!FLAGS_output_side_packets.empty() &&
|
||||
!FLAGS_output_side_packets_file.empty()) {
|
||||
std::ofstream file;
|
||||
file.open(FLAGS_output_side_packets_file);
|
||||
std::vector<std::string> side_packet_names =
|
||||
absl::StrSplit(FLAGS_output_side_packets, ',');
|
||||
for (const std::string& side_packet_name : side_packet_names) {
|
||||
ASSIGN_OR_RETURN(auto status_or_packet,
|
||||
graph.GetOutputSidePacket(side_packet_name));
|
||||
file << absl::StrCat(side_packet_name, ":",
|
||||
status_or_packet.Get<std::string>(), "\n");
|
||||
}
|
||||
file.close();
|
||||
} else {
|
||||
RET_CHECK(FLAGS_output_side_packets.empty() &&
|
||||
FLAGS_output_side_packets_file.empty())
|
||||
<< "--output_side_packets and --output_side_packets_file should be "
|
||||
"specified in pair.";
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status RunMPPGraph() {
|
||||
std::string calculator_graph_config_contents;
|
||||
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
|
||||
MP_RETURN_IF_ERROR(::mediapipe::file::GetContents(
|
||||
FLAGS_calculator_graph_config_file, &calculator_graph_config_contents));
|
||||
LOG(INFO) << "Get calculator graph config contents: "
|
||||
<< calculator_graph_config_contents;
|
||||
mediapipe::CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
|
||||
::mediapipe::CalculatorGraphConfig config =
|
||||
::mediapipe::ParseTextProtoOrDie<::mediapipe::CalculatorGraphConfig>(
|
||||
calculator_graph_config_contents);
|
||||
std::map<std::string, ::mediapipe::Packet> input_side_packets;
|
||||
std::vector<std::string> kv_pairs =
|
||||
|
@ -51,10 +118,23 @@ DEFINE_string(input_side_packets, "",
|
|||
::mediapipe::MakePacket<std::string>(name_and_value[1]);
|
||||
}
|
||||
LOG(INFO) << "Initialize the calculator graph.";
|
||||
mediapipe::CalculatorGraph graph;
|
||||
::mediapipe::CalculatorGraph graph;
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets));
|
||||
LOG(INFO) << "Start running the calculator graph.";
|
||||
return graph.Run();
|
||||
if (!FLAGS_output_stream.empty() && !FLAGS_output_stream_file.empty()) {
|
||||
ASSIGN_OR_RETURN(auto poller,
|
||||
graph.AddOutputStreamPoller(FLAGS_output_stream));
|
||||
LOG(INFO) << "Start running the calculator graph.";
|
||||
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||
MP_RETURN_IF_ERROR(OutputStreamToLocalFile(poller));
|
||||
} else {
|
||||
RET_CHECK(FLAGS_output_stream.empty() && FLAGS_output_stream_file.empty())
|
||||
<< "--output_stream and --output_stream_file should be specified in "
|
||||
"pair.";
|
||||
LOG(INFO) << "Start running the calculator graph.";
|
||||
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
|
||||
return OutputSidePacketsToLocalFile(graph);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
|
|
@ -33,3 +33,14 @@ cc_binary(
|
|||
"@org_tensorflow//tensorflow/core:direct_session",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "model_inference",
|
||||
deps = [
|
||||
"//mediapipe/examples/desktop:simple_run_graph_main",
|
||||
"//mediapipe/graphs/youtube8m:yt8m_inference_calculators_deps",
|
||||
# TODO: Figure out the minimum set of the kernels needed by this example.
|
||||
"@org_tensorflow//tensorflow/core:all_kernels",
|
||||
"@org_tensorflow//tensorflow/core:direct_session",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
### Steps to run the YouTube-8M feature extraction graph
|
||||
|
||||
1. Checkout the mediapipe repository
|
||||
1. Check out the mediapipe repository.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/google/mediapipe.git
|
||||
cd mediapipe
|
||||
```
|
||||
|
||||
2. Download the PCA and model data
|
||||
2. Download the PCA and model data.
|
||||
|
||||
```bash
|
||||
mkdir /tmp/mediapipe
|
||||
|
@ -20,7 +20,7 @@
|
|||
tar -xvf /tmp/mediapipe/inception-2015-12-05.tgz
|
||||
```
|
||||
|
||||
3. Get the VGGish frozen graph
|
||||
3. Get the VGGish frozen graph.
|
||||
|
||||
Note: To run step 3 and step 4, you must have Python 2.7 or 3.5+ installed
|
||||
with the TensorFlow 1.14+ package installed.
|
||||
|
@ -31,26 +31,114 @@
|
|||
python -m mediapipe.examples.desktop.youtube8m.generate_vggish_frozen_graph
|
||||
```
|
||||
|
||||
4. Generate a MediaSequence metadata from the input video
|
||||
4. Generate MediaSequence metadata from the input video.
|
||||
|
||||
Note: the output file is /tmp/mediapipe/metadata.tfrecord
|
||||
|
||||
```bash
|
||||
# change clip_end_time_sec to match the length of your video.
|
||||
python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
|
||||
--path_to_input_video=/absolute/path/to/the/local/video/file \
|
||||
--clip_start_time_sec=0 \
|
||||
--clip_end_time_sec=10
|
||||
--clip_end_time_sec=120
|
||||
```
|
||||
|
||||
5. Run the MediaPipe binary to extract the features
|
||||
5. Run the MediaPipe binary to extract the features.
|
||||
|
||||
```bash
|
||||
bazel build -c opt \
|
||||
--define MEDIAPIPE_DISABLE_GPU=1 --define no_aws_support=true \
|
||||
mediapipe/examples/desktop/youtube8m:extract_yt8m_features
|
||||
|
||||
./bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
|
||||
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/extract_yt8m_features \
|
||||
--calculator_graph_config_file=mediapipe/graphs/youtube8m/feature_extraction.pbtxt \
|
||||
--input_side_packets=input_sequence_example=/tmp/mediapipe/metadata.tfrecord \
|
||||
--output_side_packets=output_sequence_example=/tmp/mediapipe/output.tfrecord
|
||||
```
|
||||
|
||||
### Steps to run the YouTube-8M inference graph with the YT8M dataset
|
||||
|
||||
1. Download the YT8M dataset.
|
||||
|
||||
For example, download one shard of the training data:
|
||||
|
||||
```bash
|
||||
curl http://us.data.yt8m.org/2/frame/train/trainpj.tfrecord --output /tmp/mediapipe/trainpj.tfrecord
|
||||
```
|
||||
|
||||
2. Copy the baseline model [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view) to a local directory.
|
||||
|
||||
```bash
|
||||
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
|
||||
|
||||
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
|
||||
```
|
||||
|
||||
3. Build and run the inference binary.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
|
||||
mediapipe/examples/desktop/youtube8m:model_inference
|
||||
|
||||
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
|
||||
--calculator_graph_config_file=mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt \
|
||||
--input_side_packets=tfrecord_path=/tmp/mediapipe/trainpj.tfrecord,record_index=0,desired_segment_size=5 \
|
||||
--output_stream=annotation_summary \
|
||||
--output_stream_file=/tmp/summary \
|
||||
--output_side_packets=yt8m_id \
|
||||
--output_side_packets_file=/tmp/yt8m_id
|
||||
```
|
||||
|
||||
### Steps to run the YouTube-8M model inference graph with Web Interface
|
||||
|
||||
1. Copy the baseline model [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view) to a local directory.
|
||||
|
||||
|
||||
```bash
|
||||
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
|
||||
|
||||
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
|
||||
```
|
||||
|
||||
2. Build the inference binary.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
|
||||
mediapipe/examples/desktop/youtube8m:model_inference
|
||||
```
|
||||
|
||||
3. Run the python web server.
|
||||
|
||||
Note: this step requires the absl-py package (`pip install absl-py`).
|
||||
|
||||
```bash
|
||||
python mediapipe/examples/desktop/youtube8m/viewer/server.py --root `pwd`
|
||||
```
|
||||
|
||||
Navigate to localhost:8008 in a web browser.
|
||||
|
||||
### Steps to run the YouTube-8M model inference graph with a local video
|
||||
|
||||
1. Make sure you have the output tfrecord from the feature extraction pipeline.
|
||||
|
||||
2. Copy the baseline model [(model card)](https://drive.google.com/file/d/1xTCi9-Nm9dt2KIk8WR0dDFrIssWawyXy/view) to a local directory.
|
||||
|
||||
```bash
|
||||
curl -o /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz data.yt8m.org/models/baseline/saved_model.tar.gz
|
||||
|
||||
tar -xvf /tmp/mediapipe/yt8m_baseline_saved_model.tar.gz -C /tmp/mediapipe
|
||||
```
|
||||
|
||||
3. Build and run the inference binary.
|
||||
|
||||
```bash
|
||||
bazel build -c opt --define='MEDIAPIPE_DISABLE_GPU=1' \
|
||||
mediapipe/examples/desktop/youtube8m:model_inference
|
||||
|
||||
# segment_size is the length, in seconds, of each window of frames.
|
||||
# overlap is the number of seconds adjacent segments share.
|
||||
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference \
|
||||
--calculator_graph_config_file=mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt \
|
||||
--input_side_packets=input_sequence_example_path=/tmp/mediapipe/output.tfrecord,input_video_path=/absolute/path/to/the/local/video/file,output_video_path=/tmp/mediapipe/annotated_video.mp4,segment_size=5,overlap=4
|
||||
```
|
||||
|
||||
4. View the annotated video.
|
||||
|
|
262  mediapipe/examples/desktop/youtube8m/viewer/server.py  (new file)
|
@ -0,0 +1,262 @@
|
|||
"""Server for YouTube8M Model Inference Demo.
|
||||
|
||||
Serves up both the static files for the website and provides a service that
|
||||
fetches the video id and timestamp-based labels for a video analyzed in a
|
||||
tfrecord file.
|
||||
|
||||
"""
|
||||
from __future__ import print_function
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import http.client
|
||||
import http.server
|
||||
from six.moves.urllib import parse
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool("show_label_at_center", False,
|
||||
"Show labels at the center of the segment.")
|
||||
flags.DEFINE_integer("port", 8008, "Port that the API is served over.")
|
||||
flags.DEFINE_string("tmp_dir", "/tmp/mediapipe",
|
||||
"Temporary asset storage location.")
|
||||
flags.DEFINE_string("root", "", "MediaPipe root directory.")
|
||||
# binary, pbtxt, label_map paths are relative to 'root' path
|
||||
flags.DEFINE_string(
|
||||
"binary",
|
||||
"bazel-bin/mediapipe/examples/desktop/youtube8m/model_inference",
|
||||
"Inference binary location.")
|
||||
flags.DEFINE_string(
|
||||
"pbtxt",
|
||||
"mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt",
|
||||
"Default pbtxt graph file.")
|
||||
flags.DEFINE_string("label_map", "mediapipe/graphs/youtube8m/label_map.txt",
|
||||
"Default label map text file.")
|
||||
|
||||
|
||||
class HTTPServerV6(http.server.HTTPServer):
|
||||
address_family = socket.AF_INET6
|
||||
|
||||
|
||||
class Youtube8MRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
"""Static file server with /healthz support."""
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith("/healthz"):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/plain")
|
||||
self.send_header("Content-length", 2)
|
||||
self.end_headers()
|
||||
self.wfile.write("ok")
|
||||
if self.path.startswith("/video"):
|
||||
parsed_params = parse.urlparse(self.path)
|
||||
url_params = parse.parse_qs(parsed_params.query)
|
||||
|
||||
tfrecord_path = ""
|
||||
segment_size = 5
|
||||
|
||||
print(url_params)
|
||||
if "file" in url_params:
|
||||
tfrecord_path = url_params["file"][0]
|
||||
if "segments" in url_params:
|
||||
segment_size = int(url_params["segments"][0])
|
||||
|
||||
self.fetch(tfrecord_path, segment_size)
|
||||
|
||||
else:
|
||||
if self.path == "/":
|
||||
self.path = "/index.html"
|
||||
# Default to serve up a local file
|
||||
self.path = "/static" + self.path
|
||||
http.server.SimpleHTTPRequestHandler.do_GET(self)
|
||||
|
||||
def report_error(self, msg):
|
||||
"""Simplifies sending out a string as a 500 http response."""
|
||||
self.send_response(500)
|
||||
self.send_header("Content-type", "text/plain")
|
||||
self.end_headers()
|
||||
if sys.version_info[0] < 3:
|
||||
self.wfile.write(str(msg).encode("utf-8"))
|
||||
else:
|
||||
self.wfile.write(bytes(msg, "utf-8"))
|
||||
|
||||
def report_missing_files(self, files):
|
||||
"""Sends out 500 response with missing files."""
|
||||
accumulate = ""
|
||||
for file_path in files:
|
||||
if not os.path.exists(file_path):
|
||||
accumulate = "%s '%s'" % (accumulate, file_path)
|
||||
|
||||
if accumulate:
|
||||
self.report_error("Could not find:%s" % accumulate)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def fetch(self, path, segment_size):
|
||||
"""Returns the video id and labels for a tfrecord at a provided index."""
|
||||
|
||||
print("Received request. File=", path, "Segment Size =", segment_size)
|
||||
|
||||
if (self.report_missing_files([
|
||||
"%s/%s" % (FLAGS.root, FLAGS.pbtxt),
|
||||
"%s/%s" % (FLAGS.root, FLAGS.binary),
|
||||
"%s/%s" % (FLAGS.root, FLAGS.label_map)
|
||||
])):
|
||||
return
|
||||
|
||||
# Parse the youtube video id off the end of the link or as a standalone id.
|
||||
filename_match = re.match(
|
||||
"(?:.*youtube.*v=)?([a-zA-Z-0-9_]{2})([a-zA-Z-0-9_]+)", path)
|
||||
tfrecord_url = filename_match.expand(r"data.yt8m.org/2/j/r/\1/\1\2.js")
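# e.g. the id "huGVGe3Afng" expands to "data.yt8m.org/2/j/r/hu/huGVGe3Afng.js"
# (the first two characters of the id select the shard directory).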
|
||||
|
||||
print("Trying to get tfrecord via", tfrecord_url)
|
||||
|
||||
connection = http.client.HTTPConnection("data.yt8m.org")
|
||||
connection.request("GET", tfrecord_url)
|
||||
response = connection.getresponse()
|
||||
|
||||
response_object = json.loads(response.read())
|
||||
filename = response_object["filename_raw"]
|
||||
index = response_object["index"]
|
||||
|
||||
print("TFRecord discovered: ", filename, ", index", index)
|
||||
|
||||
output_file = r"%s/%s" % (FLAGS.tmp_dir, filename)
|
||||
tfrecord_url = r"http://us.data.yt8m.org/2/frame/train/%s" % filename
|
||||
|
||||
connection = http.client.HTTPConnection("us.data.yt8m.org")
|
||||
connection.request("HEAD",
|
||||
filename_match.expand(r"/2/frame/train/%s" % filename))
|
||||
response = connection.getresponse()
|
||||
if response.getheader("Content-Type") != "application/octet-stream":
|
||||
self.report_error("Filename '%s' is invalid." % path)
|
||||
|
||||
print(output_file, "exists on yt8m.org. Did we fetch this before?")
|
||||
|
||||
if not os.path.exists(output_file):
|
||||
print(output_file, "doesn't exist locally, download it now.")
|
||||
return_code = subprocess.call(
|
||||
["curl", "--output", output_file, tfrecord_url],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
if return_code:
|
||||
self.report_error("Could not retrieve contents from %s" % tfrecord_url)
|
||||
return
|
||||
else:
|
||||
print(output_file, "exist locally, reuse it.")
|
||||
|
||||
print("Run the graph...")
|
||||
process = subprocess.Popen([
|
||||
"%s/%s" % (FLAGS.root, FLAGS.binary),
|
||||
"--calculator_graph_config_file=%s/%s" % (FLAGS.root, FLAGS.pbtxt),
|
||||
"--input_side_packets=tfrecord_path=%s" % output_file +
|
||||
",record_index=%d" % index + ",desired_segment_size=%d" % segment_size,
|
||||
"--output_stream=annotation_summary",
|
||||
"--output_stream_file=%s/labels" % FLAGS.tmp_dir,
|
||||
"--output_side_packets=yt8m_id",
|
||||
"--output_side_packets_file=%s/yt8m_id" % FLAGS.tmp_dir
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
stdout_str, stderr_str = process.communicate()
|
||||
process.wait()
|
||||
|
||||
if stderr_str and "success" not in str(stderr_str).lower():
|
||||
self.report_error("Error executing server binary: \n%s" % stderr_str)
|
||||
return
|
||||
|
||||
f = open("%s/yt8m_id" % FLAGS.tmp_dir, "r")
|
||||
contents = f.read()
|
||||
print("yt8m_id is", contents[-5:-1])
|
||||
|
||||
curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (contents[-5:-3],
|
||||
contents[-5:-1])
|
||||
print("Grab labels from", curl_arg)
|
||||
process = subprocess.Popen(["curl", curl_arg],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
stdout = process.communicate()
|
||||
process.wait()
|
||||
|
||||
stdout_str = stdout[0].decode("utf-8")
|
||||
|
||||
match = re.match(""".+"([^"]+)"[^"]+""", stdout_str)
|
||||
final_results = {
|
||||
"video_id": match.group(1),
|
||||
"link": "https://www.youtube.com/watch?v=%s" % match.group(1),
|
||||
"entries": []
|
||||
}
|
||||
f = open("%s/labels" % FLAGS.tmp_dir, "r")
|
||||
lines = f.readlines()
|
||||
show_at_center = FLAGS.show_label_at_center
|
||||
|
||||
print("%s/labels" % FLAGS.tmp_dir, "holds", len(lines), "entries")
|
||||
for line in lines:
|
||||
entry = {"labels": []}
|
||||
final_results["entries"].append(entry)
|
||||
first = True
|
||||
for column in line.split(","):
|
||||
if first:
|
||||
subtract = segment_size / 2.0 if show_at_center else 0.0
|
||||
entry["time"] = float(int(column)) / 1000000.0 - subtract
|
||||
first = False
|
||||
else:
|
||||
label_score = re.match("(.+):([0-9.]+).*", column)
|
||||
if label_score:
|
||||
score = float(label_score.group(2))
|
||||
entry["labels"].append({
|
||||
"label": label_score.group(1),
|
||||
"score": score
|
||||
})
|
||||
else:
|
||||
print("empty score")
|
||||
|
||||
response_json = json.dumps(final_results, indent=2, separators=(",", ": "))
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
if sys.version_info[0] < 3:
|
||||
self.wfile.write(str(response_json).encode("utf-8"))
|
||||
else:
|
||||
self.wfile.write(bytes(response_json, "utf-8"))
|
||||
|
||||
|
||||
def update_pbtxt():
|
||||
"""Update graph.pbtxt to use full path to label_map.txt."""
|
||||
edited_line = ""
|
||||
lines = []
|
||||
with open("%s/%s" % (FLAGS.root, FLAGS.pbtxt), "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
if "label_map_path" in line:
|
||||
kv = line.split(":")
|
||||
edited_line = kv[0] + (": \"%s/%s\"\n" % (FLAGS.root, FLAGS.label_map))
|
||||
with open("%s/%s" % (FLAGS.root, FLAGS.pbtxt), "w") as f:
|
||||
for line in lines:
|
||||
if "label_map_path" in line:
|
||||
f.write(edited_line)
|
||||
else:
|
||||
f.write(line)
|
||||
|
||||
|
||||
def main(unused_args):
|
||||
dname = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(dname)
|
||||
if not FLAGS.root:
|
||||
print("Must specify MediaPipe root directory: --root `pwd`")
|
||||
return
|
||||
update_pbtxt()
|
||||
port = FLAGS.port
|
||||
print("Listening on port %s" % port) # pylint: disable=superfluous-parens
|
||||
server = HTTPServerV6(("::", int(port)), Youtube8MRequestHandler)
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
|
@ -0,0 +1,96 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>MediaPipe: YouTube8M Model Inference Demo</title>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
||||
<base href="/">
|
||||
<script src="main.js"></script>
|
||||
<link rel="stylesheet"
|
||||
href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css"
|
||||
integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T"
|
||||
crossorigin="anonymous">
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="container-fluid">
|
||||
<h2>
|
||||
MediaPipe: YouTube8M Model Inference Demo
|
||||
</h2>
|
||||
<form id="form">
|
||||
<div class="row">
|
||||
<div class="card m-2" style="width: 640px;">
|
||||
<div>
|
||||
<div style="position:relative;">
|
||||
<iframe id="ytplayer" style="display:none;" type="text/html" width="640" height="320"
|
||||
src="https://www.youtube.com/embed/M7lc1UVf-VE?enablejsapi=1" frameborder="0"
|
||||
enablejsapi="1"></iframe>
|
||||
<div id="cover" class="bg-warning"
|
||||
style="width:640px; height:320px;">
|
||||
</div>
|
||||
<div id="spinner" class="bg-warning"
|
||||
style="display: none; width:640px; height:320px;">
|
||||
<div class="spinner-border" role="status"
|
||||
style="position:relative; left:300px; top:130px;">
|
||||
<span class="sr-only">Loading...</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-body shadow">
|
||||
<div class="row mb-2">
|
||||
<ul class="nav">
|
||||
<li class="nav-item">
|
||||
<a class="nav-link"
|
||||
href="https://research.google.com/youtube8m/explore.html"
|
||||
target="_">Explore Videos</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="file">YouTube video ID</label>
|
||||
<input type="text" class="form-control" name="file" id="file"
|
||||
placeholder="Enter a YouTube link or a YouTube ID">
|
||||
<small class="form-text text-muted">
|
||||
e.g., either "https://youtube.com/watch?v=huGVGe3Afng" or "huGVGe3Afng" will work.
|
||||
</small>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label id="segments_label" for="segments">Segment Size</label>
|
||||
<input type="range" min="1" max="300" step="1" value="5"
|
||||
class="form-control-range" name="segments" id="segments">
|
||||
</div>
|
||||
<button type="submit" class="btn btn-primary">Submit</button>
|
||||
<div id="error_msg" style="visibility:hidden;" class="alert alert-danger mt-2"
|
||||
role="alert"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card m-2 shadow">
|
||||
<div class="card-body">
|
||||
<div class="form-group">
|
||||
<label id="threshold_label" for="threshold">Score Threshold</label>
|
||||
<input type="range" min="0" max="0.99" step="0.01" value="0.2"
|
||||
class="form-control-range" name="threshold" id="threshold">
|
||||
</div>
|
||||
<h5>
|
||||
Labels
|
||||
</h5>
|
||||
<textarea id="feedback" style="height:320px; width:500px;"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</body>
|
||||
<script async src="https://www.youtube.com/iframe_api"></script>
|
||||
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.7/umd/popper.min.js"
|
||||
integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1"
|
||||
crossorigin="anonymous"></script>
|
||||
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/js/bootstrap.min.js"
|
||||
integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM"
|
||||
crossorigin="anonymous"></script>
|
||||
</html>
|
||||
|
||||
|
217  mediapipe/examples/desktop/youtube8m/viewer/static/main.js  (new file)
|
@ -0,0 +1,217 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2019 The MediaPipe Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
const STATE_PLAYER=0;
|
||||
const STATE_COVER=1;
|
||||
const STATE_SPINNER=2;
|
||||
|
||||
/**
|
||||
* Looks up the value of a url parameter.
|
||||
*
|
||||
* @param {string} param The name of the parameter.
|
||||
* @return {?string} The parameter value or null if there is no such parameter.
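* e.g. for "?file=abc&segments=5" (illustrative values),
*     getUrlParameter('segments') returns "5".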
|
||||
*/
|
||||
var getUrlParameter = function(param) {
|
||||
const url = decodeURIComponent(window.location.search.substring(1));
|
||||
const url_parts = url.split('&');
|
||||
for (var i = 0; i < url_parts.length; i++) {
|
||||
const param_name = url_parts[i].split(/=(.*)/);
|
||||
if (param_name[0] === param) {
|
||||
return param_name[1] === undefined ? null : param_name[1];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Sets the fields in the form to match the values of the URL parameters.
|
||||
*/
|
||||
const updateFormFromURL = function() {
|
||||
const form_elements = document.getElementById('form').elements;
|
||||
const url = decodeURIComponent(window.location.search.substring(1));
|
||||
const url_parts = url.split('&');
|
||||
for (var i = 0; i < url_parts.length; i++) {
|
||||
const p = url_parts[i].split(/=(.*)/);
|
||||
if (p.length >= 2) {
|
||||
if (form_elements[p[0]]) {
|
||||
form_elements[p[0]].value = decodeURIComponent(p[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let player = null;
|
||||
let intervalID = undefined;
|
||||
let entries = [];
|
||||
|
||||
/**
|
||||
* Constructs the embedded YouTube player.
|
||||
*/
|
||||
window.onYouTubeIframeAPIReady = () => {
|
||||
player = new YT.Player('ytplayer', {
|
||||
events: {
|
||||
'onReady': onPlayerReady,
|
||||
'onStateChange': onStateChange
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Listens for YouTube video events. When video is playing, periodically checks
|
||||
* the time signature and updates the feedback with labels. When video stops,
|
||||
* shuts off interval timer to save cycles.
|
||||
* @param {!Event} event YouTube API Event.
|
||||
*/
|
||||
function onStateChange(event) {
|
||||
if (event.data === 1) {
|
||||
// Youtube switched to playing.
|
||||
intervalID = setInterval(function(){
|
||||
const currentTime = player.getCurrentTime();
|
||||
let winner = undefined;
|
||||
let first = undefined;
|
||||
for (entry of entries) {
|
||||
if (!first) {
|
||||
first = entry.labels;
|
||||
}
|
||||
if (entry.time < currentTime) {
|
||||
winner = entry.labels;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!winner) {
|
||||
winner = first;
|
||||
}
|
||||
const threshold =
|
||||
document.getElementById('form').elements['threshold'].value;
|
||||
let message = "";
|
||||
for (var label of winner) {
|
||||
if (label.score >= threshold) {
|
||||
message = `${message}${label.label} (score: ${label.score})\n`;
|
||||
}
|
||||
}
|
||||
$("textarea#feedback").val(message);
|
||||
});
|
||||
} else {
|
||||
if (intervalID) {
|
||||
clearInterval(intervalID);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Turns elements of the player on and off to reflect the state of the "app".
|
||||
* @param {number} state One of STATE_COVER | STATE_SPINNER | STATE_PLAYER.
|
||||
*/
|
||||
function showState(state) {
|
||||
switch(state) {
|
||||
case STATE_COVER:
|
||||
$('#cover').show();
|
||||
$('#spinner').hide();
|
||||
$('#ytplayer').hide();
|
||||
break;
|
||||
case STATE_SPINNER:
|
||||
$('#cover').hide();
|
||||
$('#spinner').show();
|
||||
$('#ytplayer').hide();
|
||||
break;
|
||||
case STATE_PLAYER:
|
||||
default:
|
||||
$('#cover').hide();
|
||||
$('#spinner').hide();
|
||||
$('#ytplayer').show();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hide error field and clear its message.
|
||||
*/
|
||||
function hideError() {
|
||||
$('#error_msg').css("visibility", "hidden").text('');
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the error to visible and set its message.
|
||||
* @param {string} msg Error message as a string.
|
||||
*/
|
||||
function showError(msg) {
|
||||
$('#error_msg').css("visibility", "visible").text(msg);
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides numeric feedback for the slider.
|
||||
*/
|
||||
function connectSlider() {
|
||||
$('#threshold_label').text(
|
||||
`Score Threshold (${$('#threshold')[0].value})`);
|
||||
$('#threshold').on('input', () => {
|
||||
$('#threshold_label').text(
|
||||
`Score Threshold (${$('#threshold')[0].value})`);
|
||||
});
|
||||
$('#segments_label').text(
|
||||
`Segment Size (${$('#segments')[0].value})`);
|
||||
$('#segments').on('input', () => {
|
||||
$('#segments_label').text(
|
||||
`Segment Size (${$('#segments')[0].value})`);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve video information from backend.
|
||||
* @param {string} filePath name of a tfrecord file.
|
||||
* @param {number} segments desired segment size in seconds (1-300)
|
||||
*/
|
||||
function fetchVideo(filePath, segments) {
|
||||
const url = "/video?file=" + filePath + "&segments=" + segments;
|
||||
$.ajax({
|
||||
url: url,
|
||||
success: function(result) {
|
||||
const videoId = result["video_id"];
|
||||
player.loadVideoById(videoId);
|
||||
entries = result['entries'];
|
||||
showState(STATE_PLAYER);
|
||||
},
|
||||
error: (err) => {
|
||||
showState(STATE_COVER);
|
||||
console.log(err);
|
||||
showError(err.responseText);
|
||||
},
|
||||
datatype: "json"
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when the embedded YouTube player has finished loading. It loads the
|
||||
* requested video into the player and calls the backend /video API to retrieve
|
||||
* the frame-level data for that video.
|
||||
*/
|
||||
function onPlayerReady() {
|
||||
const filePath = getUrlParameter('file') || "";
|
||||
const segments = parseInt(getUrlParameter('segments')) || 0;
|
||||
|
||||
updateFormFromURL();
|
||||
hideError();
|
||||
connectSlider();
|
||||
|
||||
if (!filePath) {
|
||||
return;
|
||||
}
|
||||
|
||||
showState(STATE_SPINNER);
|
||||
fetchVideo(filePath, segments);
|
||||
}
|
|
@ -688,6 +688,12 @@ cc_library(
|
|||
cc_library(
|
||||
name = "demangle",
|
||||
hdrs = ["demangle.h"],
|
||||
defines = select({
|
||||
"//mediapipe/framework/profiler:android_release": [
|
||||
"MEDIAPIPE_HAS_CXA_DEMANGLE=0",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
@ -1713,3 +1719,10 @@ cc_test(
|
|||
"//mediapipe/framework/tool/testdata:dub_quad_test_subgraph",
|
||||
],
|
||||
)
|
||||
|
||||
# Expose the proto source files for building mediapipe AAR.
|
||||
filegroup(
|
||||
name = "protos_src",
|
||||
srcs = glob(["*.proto"]),
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
|
|
@ -756,7 +756,7 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// Shows that when fixed-size-input-stream-hanlder drops packets,
|
||||
// Shows that when fixed-size-input-stream-handler drops packets,
|
||||
// no timestamp bounds are announced.
|
||||
TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) {
|
||||
// LambdaCalculator with FixedSizeInputStreamHandler will drop packets
|
||||
|
@ -876,5 +876,93 @@ TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) {
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// A Calculator that outputs only the last packet from its input stream.
|
||||
class LastPacketCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Outputs().Index(0).SetAny();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
::mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
::mediapipe::Status Process(CalculatorContext* cc) final {
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(cc->InputTimestamp());
|
||||
last_packet_ = cc->Inputs().Index(0).Value();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
::mediapipe::Status Close(CalculatorContext* cc) final {
|
||||
cc->Outputs().Index(0).AddPacket(last_packet_);
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
Packet last_packet_;
|
||||
};
|
||||
REGISTER_CALCULATOR(LastPacketCalculator);
|
||||
|
||||
// Shows that the last packet in an input stream can be detected.
|
||||
TEST(CalculatorGraphBoundsTest, LastPacketCheck) {
|
||||
// LastPacketCalculator emits only the last input stream packet.
|
||||
// It emits a timestamp bound after the arrival of a successor input stream
|
||||
// packet or input stream close. The output "last_output" shows the
|
||||
// last packet, and "output" shows the timestamp bounds.
|
||||
CalculatorGraphConfig config =
|
||||
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
input_stream: 'input'
|
||||
output_stream: 'output'
|
||||
output_stream: 'last_output'
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'input'
|
||||
output_stream: 'input_2'
|
||||
}
|
||||
node {
|
||||
calculator: 'LastPacketCalculator'
|
||||
input_stream: 'input_2'
|
||||
output_stream: 'last_packet'
|
||||
}
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'input'
|
||||
input_stream: 'last_packet'
|
||||
output_stream: 'output'
|
||||
output_stream: 'last_output'
|
||||
}
|
||||
)");
|
||||
CalculatorGraph graph;
|
||||
std::vector<Packet> output_packets;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) {
|
||||
output_packets.push_back(p);
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
std::vector<Packet> last_output_packets;
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream("last_output", [&](const Packet& p) {
|
||||
last_output_packets.push_back(p);
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Add four packets into the graph.
|
||||
constexpr int kNumInputs = 4;
|
||||
for (int i = 0; i < kNumInputs; ++i) {
|
||||
Packet p = MakePacket<int>(33).At(Timestamp(i));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_EQ(i, output_packets.size());
|
||||
EXPECT_EQ(0, last_output_packets.size());
|
||||
}
|
||||
|
||||
// Shutdown the graph.
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_EQ(kNumInputs, output_packets.size());
|
||||
EXPECT_EQ(1, last_output_packets.size());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -743,5 +743,66 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
|
|||
}
|
||||
}
|
||||
|
||||
typedef std::string HugeModel;
|
||||
|
||||
// Generates an output-side-packet once for each calculator-graph.
|
||||
class OutputSidePacketCachedCalculator : public CalculatorBase {
|
||||
public:
|
||||
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||
cc->OutputSidePackets().Index(0).Set<HugeModel>();
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Open(CalculatorContext* cc) final {
|
||||
cc->OutputSidePackets().Index(0).Set(MakePacket<HugeModel>(
|
||||
R"(An expensive side-packet created only once per graph)"));
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
::mediapipe::Status Process(CalculatorContext* cc) final {
|
||||
LOG(FATAL) << "Not reached.";
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
};
|
||||
REGISTER_CALCULATOR(OutputSidePacketCachedCalculator);
|
||||
|
||||
// Returns true if two packets share the same underlying data holder.
|
||||
bool Equals(Packet p1, Packet p2) {
|
||||
return packet_internal::GetHolder(p1) == packet_internal::GetHolder(p2);
|
||||
}
|
||||
|
||||
TEST(CalculatorGraph, OutputSidePacketCached) {
|
||||
CalculatorGraphConfig config =
|
||||
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||
node {
|
||||
calculator: "OutputSidePacketCachedCalculator"
|
||||
output_side_packet: "model"
|
||||
}
|
||||
node {
|
||||
calculator: "SidePacketToStreamPacketCalculator"
|
||||
input_side_packet: "model"
|
||||
output_stream: "output"
|
||||
}
|
||||
)");
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
std::vector<Packet> output_packets;
|
||||
MP_ASSERT_OK(graph.ObserveOutputStream(
|
||||
"output", [&output_packets](const Packet& packet) {
|
||||
output_packets.push_back(packet);
|
||||
return ::mediapipe::OkStatus();
|
||||
}));
|
||||
|
||||
// Run the graph three times.
|
||||
for (int run = 0; run < 3; ++run) {
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
ASSERT_EQ(3, output_packets.size());
|
||||
for (int run = 0; run < output_packets.size(); ++run) {
|
||||
EXPECT_TRUE(Equals(output_packets[0], output_packets[run]));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -391,6 +391,38 @@ void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) {
|
|||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Returns the Packet sent to an OutputSidePacket, or an empty packet
|
||||
// if none is available.
|
||||
const Packet GetPacket(const OutputSidePacket& out) {
|
||||
auto impl = dynamic_cast<const OutputSidePacketImpl*>(&out);
|
||||
return (impl == nullptr) ? Packet() : impl->GetPacket();
|
||||
}
|
||||
|
||||
// Resends the output-side-packets from the previous graph run.
|
||||
::mediapipe::Status ResendSidePackets(CalculatorContext* cc) {
|
||||
auto& outs = cc->OutputSidePackets();
|
||||
for (CollectionItemId id = outs.BeginId(); id < outs.EndId(); ++id) {
|
||||
Packet packet = GetPacket(outs.Get(id));
|
||||
if (!packet.IsEmpty()) {
|
||||
// OutputSidePacket::Set re-announces the side-packet to its mirrors.
|
||||
outs.Get(id).Set(packet);
|
||||
}
|
||||
}
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CalculatorNode::OutputsAreConstant(CalculatorContext* cc) {
|
||||
if (cc->Inputs().NumEntries() > 0 || cc->Outputs().NumEntries() > 0) {
|
||||
return false;
|
||||
}
|
||||
if (input_side_packet_handler_.InputSidePacketsChanged()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
::mediapipe::Status CalculatorNode::OpenNode() {
|
||||
VLOG(2) << "CalculatorNode::OpenNode() for " << DebugName();
|
||||
|
||||
|
@ -407,8 +439,9 @@ void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) {
|
|||
default_context, Timestamp::Unstarted());
|
||||
|
||||
::mediapipe::Status result;
|
||||
|
||||
{
|
||||
if (OutputsAreConstant(default_context)) {
|
||||
result = ResendSidePackets(default_context);
|
||||
} else {
|
||||
MEDIAPIPE_PROFILING(OPEN, default_context);
|
||||
LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context);
|
||||
result = calculator_->Open(default_context);
|
||||
|
@ -494,7 +527,10 @@ void CalculatorNode::CloseOutputStreams(OutputStreamShardSet* outputs) {
|
|||
|
||||
::mediapipe::Status result;
|
||||
|
||||
{
|
||||
if (OutputsAreConstant(default_context)) {
|
||||
// Do nothing.
|
||||
result = ::mediapipe::OkStatus();
|
||||
} else {
|
||||
MEDIAPIPE_PROFILING(CLOSE, default_context);
|
||||
LegacyCalculatorSupport::Scoped<CalculatorContext> s(default_context);
|
||||
result = calculator_->Close(default_context);
|
||||
|
@ -770,7 +806,10 @@ std::string CalculatorNode::DebugName() const {
|
|||
|
||||
VLOG(2) << "Calling Calculator::Process() for node: " << DebugName();
|
||||
|
||||
{
|
||||
if (OutputsAreConstant(calculator_context)) {
|
||||
// Do nothing.
|
||||
result = ::mediapipe::OkStatus();
|
||||
} else {
|
||||
MEDIAPIPE_PROFILING(PROCESS, calculator_context);
|
||||
LegacyCalculatorSupport::Scoped<CalculatorContext> s(
|
||||
calculator_context);
|
||||
|
|
|
@ -280,6 +280,9 @@ class CalculatorNode {
|
|||
// Get a std::string describing the input streams.
|
||||
std::string DebugInputStreamNames() const;
|
||||
|
||||
// Returns true if all outputs will be identical to the previous graph run.
|
||||
bool OutputsAreConstant(CalculatorContext* cc);
|
||||
|
||||
// The calculator.
|
||||
std::unique_ptr<CalculatorBase> calculator_;
|
||||
// Keeps data which a Calculator subclass needs access to.
|
||||
|
|
|
@@ -240,6 +240,22 @@ class Collection {
    return tag_map_->EndId(tag);
  }

  // Equal Collections contain equal mappings and equal elements.
  bool operator==(const Collection<T>& other) const {
    if (tag_map_->Mapping() != other.TagMap()->Mapping()) {
      return false;
    }
    for (CollectionItemId id = BeginId(); id < EndId(); ++id) {
      if (Get(id) != other.Get(id)) {
        return false;
      }
    }
    return true;
  }
  bool operator!=(const Collection<T>& other) const {
    return !(*this == other);
  }

 private:
  // An iterator which is identical to ItType** except that the
  // dereference operator (operator*) does a double dereference and
@@ -15,23 +15,25 @@
#ifndef MEDIAPIPE_FRAMEWORK_DEMANGLE_H_
#define MEDIAPIPE_FRAMEWORK_DEMANGLE_H_

#ifndef MEDIAPIPE_HAS_CXA_DEMANGLE
// We only support some compilers that support __cxa_demangle.
// TODO: Checks if Android NDK has fixed this issue or not.
#if defined(__ANDROID__) && (defined(__i386__) || defined(__x86_64__))
#define HAS_CXA_DEMANGLE 0
#define MEDIAPIPE_HAS_CXA_DEMANGLE 0
#elif (__GNUC__ >= 4 || (__GNUC__ >= 3 && __GNUC_MINOR__ >= 4)) && \
    !defined(__mips__)
#define HAS_CXA_DEMANGLE 1
#define MEDIAPIPE_HAS_CXA_DEMANGLE 1
#elif defined(__clang__) && !defined(_MSC_VER)
#define HAS_CXA_DEMANGLE 1
#define MEDIAPIPE_HAS_CXA_DEMANGLE 1
#else
#define HAS_CXA_DEMANGLE 0
#define MEDIAPIPE_HAS_CXA_DEMANGLE 0
#endif
#endif

#include <stdlib.h>

#include <string>
#if HAS_CXA_DEMANGLE
#if MEDIAPIPE_HAS_CXA_DEMANGLE
#include <cxxabi.h>
#endif

@@ -65,7 +67,7 @@ namespace mediapipe {
inline std::string Demangle(const char* mangled) {
  int status = 0;
  char* demangled = nullptr;
#if HAS_CXA_DEMANGLE
#if MEDIAPIPE_HAS_CXA_DEMANGLE
  demangled = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
#endif
  std::string out;
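A small usage sketch of the renamed macro path (illustrative only): Demangle() returns a human-readable type name when __cxa_demangle is available and falls back to the mangled input otherwise.

```
#include <iostream>
#include <string>
#include <typeinfo>

#include "mediapipe/framework/demangle.h"

int main() {
  // Prints a demangled name when MEDIAPIPE_HAS_CXA_DEMANGLE is 1;
  // otherwise the raw mangled name is printed unchanged.
  std::cout << mediapipe::Demangle(typeid(std::string).name()) << std::endl;
  return 0;
}
```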
@@ -15,10 +15,9 @@
# Description:
#   The dependencies of mediapipe.

licenses(["notice"])  # Apache 2.0

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_py_proto_library")

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])
@@ -66,5 +66,9 @@ message ImageFormat {
    // LAB, interleaved: one byte for L, then one byte for a, then one
    // byte for b for each pixel.
    LAB8 = 10;

    // sBGRA, interleaved: one byte for B, one byte for G, one byte for R,
    // one byte for alpha or unused. This is the N32 format for Skia.
    SBGRA = 11;
  }
}
@@ -279,6 +279,8 @@ int ImageFrame::NumberOfChannelsForFormat(ImageFormat::Format format) {
      return 1;
    case ImageFormat::LAB8:
      return 3;
    case ImageFormat::SBGRA:
      return 4;
    default:
      LOG(FATAL) << InvalidFormatString(format);
  }

@@ -304,6 +306,8 @@ int ImageFrame::ChannelSizeForFormat(ImageFormat::Format format) {
      return sizeof(float);
    case ImageFormat::LAB8:
      return sizeof(uint8);
    case ImageFormat::SBGRA:
      return sizeof(uint8);
    default:
      LOG(FATAL) << InvalidFormatString(format);
  }

@@ -329,6 +333,8 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) {
      return 4;
    case ImageFormat::LAB8:
      return 1;
    case ImageFormat::SBGRA:
      return 1;
    default:
      LOG(FATAL) << InvalidFormatString(format);
  }
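Taken together, the two patched helpers determine a format's pixel stride; a minimal sketch (the helper name is made up) of the relationship, e.g. SBGRA yields 4 channels × 1 byte = 4 bytes per pixel:

```
#include "mediapipe/framework/formats/image_frame.h"

// Hypothetical helper: bytes per pixel = channel count * channel size.
int BytesPerPixel(mediapipe::ImageFormat::Format format) {
  return mediapipe::ImageFrame::NumberOfChannelsForFormat(format) *
         mediapipe::ImageFrame::ChannelSizeForFormat(format);  // SBGRA -> 4
}
```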
@@ -59,6 +59,9 @@ int GetMatType(const mediapipe::ImageFormat::Format format) {
    case mediapipe::ImageFormat::LAB8:
      type = CV_8U;
      break;
    case mediapipe::ImageFormat::SBGRA:
      type = CV_8U;
      break;
    default:
      // Invalid or unknown; default to uchar.
      type = CV_8U;
@@ -32,3 +32,8 @@ message NormalizedLandmark {
  optional float y = 2;
  optional float z = 3;
}

// Group of NormalizedLandmark protos.
message NormalizedLandmarkList {
  repeated NormalizedLandmark landmark = 1;
}
@@ -27,6 +27,7 @@ namespace mediapipe {
    std::function<void()> input_side_packets_ready_callback,
    std::function<void(::mediapipe::Status)> error_callback) {
  int missing_input_side_packet_count;
  prev_input_side_packets_ = std::move(input_side_packets_);
  ASSIGN_OR_RETURN(
      input_side_packets_,
      tool::FillPacketSet(*input_side_packet_types, all_side_packets,

@@ -41,6 +42,12 @@ namespace mediapipe {
  return ::mediapipe::OkStatus();
}

bool InputSidePacketHandler::InputSidePacketsChanged() {
  return prev_input_side_packets_ == nullptr ||
         input_side_packets_ == nullptr ||
         *input_side_packets_ != *prev_input_side_packets_;
}

void InputSidePacketHandler::Set(CollectionItemId id, const Packet& packet) {
  ::mediapipe::Status status = SetInternal(id, packet);
  if (!status.ok()) {
@@ -52,6 +52,10 @@ class InputSidePacketHandler {

  const PacketSet& InputSidePackets() const { return *input_side_packets_; }

  // Returns true if the set of input side packets has changed since the
  // previous run.
  bool InputSidePacketsChanged();

  // Returns the number of missing input side packets.
  int MissingInputSidePacketCount() const {
    return missing_input_side_packet_count_.load(std::memory_order_relaxed);

@@ -68,6 +72,7 @@ class InputSidePacketHandler {
  const PacketTypeSet* input_side_packet_types_;

  std::unique_ptr<PacketSet> input_side_packets_;
  std::unique_ptr<PacketSet> prev_input_side_packets_;

  std::atomic<int> missing_input_side_packet_count_{0};
@@ -30,7 +30,7 @@ namespace mediapipe {
void OutputSidePacketImpl::PrepareForRun(
    std::function<void(::mediapipe::Status)> error_callback) {
  error_callback_ = std::move(error_callback);
  packet_ = Packet();
  initialized_ = false;
}

void OutputSidePacketImpl::Set(const Packet& packet) {

@@ -47,7 +47,7 @@ void OutputSidePacketImpl::AddMirror(
}

::mediapipe::Status OutputSidePacketImpl::SetInternal(const Packet& packet) {
  if (!packet_.IsEmpty()) {
  if (initialized_) {
    return ::mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC)
           << "Output side packet \"" << name_ << "\" was already set.";
  }

@@ -72,6 +72,7 @@ void OutputSidePacketImpl::AddMirror(
}

  packet_ = packet;
  initialized_ = true;
  for (const auto& mirror : mirrors_) {
    mirror.input_side_packet_handler->Set(mirror.id, packet_);
  }
@@ -80,6 +80,7 @@ class OutputSidePacketImpl : public OutputSidePacket {
  const PacketType* packet_type_;
  std::function<void(::mediapipe::Status)> error_callback_;
  Packet packet_;
  bool initialized_ = false;

  std::vector<Mirror> mirrors_;
};
@@ -653,6 +653,14 @@ Packet PointToForeign(const T* ptr) {
  return packet_internal::Create(new packet_internal::ForeignHolder<T>(ptr));
}

// Equal Packets refer to the same memory contents, like equal pointers.
inline bool operator==(const Packet& p1, const Packet& p2) {
  return packet_internal::GetHolder(p1) == packet_internal::GetHolder(p2);
}
inline bool operator!=(const Packet& p1, const Packet& p2) {
  return !(p1 == p2);
}

}  // namespace mediapipe

#endif  // MEDIAPIPE_FRAMEWORK_PACKET_H_
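The new operators compare holder identity, not payload values; combined with Collection's element-wise operator== above, this is what lets InputSidePacketsChanged() detect a re-supplied side packet cheaply. A minimal sketch of the semantics (illustrative only):

```
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/logging.h"

void PacketEqualityDemo() {
  mediapipe::Packet a = mediapipe::MakePacket<int>(42);
  mediapipe::Packet b = a;                               // shares a's holder
  mediapipe::Packet c = mediapipe::MakePacket<int>(42);  // distinct holder
  CHECK(a == b);  // same holder -> equal
  CHECK(a != c);  // same value but different holder -> not equal
}
```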
@@ -28,4 +28,22 @@
#define MEDIAPIPE_MOBILE
#endif

#if !defined(MEDIAPIPE_ANDROID) && defined(__ANDROID__)
#define MEDIAPIPE_ANDROID
#endif

#if defined(__APPLE__)
#include "TargetConditionals.h"  // for TARGET_OS_*
#if !defined(MEDIAPIPE_IOS) && !TARGET_OS_OSX
#define MEDIAPIPE_IOS
#endif
#endif

// These platforms do not support OpenGL ES Compute Shaders (v3.1 and up),
// but can still run OpenGL ES 3.0 and below.
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \
    (defined(__APPLE__) || defined(__EMSCRIPTEN__))
#define MEDIAPIPE_DISABLE_GL_COMPUTE
#endif

#endif  // MEDIAPIPE_FRAMEWORK_PORT_H_
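A short sketch of how the new guard is meant to be consumed (illustrative; the helper and the branch strings are made up): code selects the compute-shader path only when the macro is left undefined.

```
#include <string>

#include "mediapipe/framework/port.h"

// Hypothetical helper illustrating the intended use of the guard.
std::string GpuBackendName() {
#if defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
  return "OpenGL ES 3.0 fallback";  // Apple and Emscripten builds land here.
#else
  return "OpenGL ES 3.1+ compute";
#endif
}
```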
@@ -247,25 +247,45 @@ TEST_F(GraphProfilerTestPeer, InitializeConfig) {
  // Checks histogram_interval_size_usec and num_histogram_intervals.
  CalculatorProfile actual =
      GetCalculatorProfilesMap()->find(kDummyTestCalculatorName)->second;
  ASSERT_EQ(actual.name(), kDummyTestCalculatorName);
  ASSERT_FALSE(actual.has_open_runtime());
  ASSERT_FALSE(actual.has_close_runtime());

  ASSERT_EQ(actual.process_runtime().interval_size_usec(), 1000);
  ASSERT_EQ(actual.process_runtime().num_intervals(), 3);

  ASSERT_EQ(actual.process_input_latency().interval_size_usec(), 1000);
  ASSERT_EQ(actual.process_input_latency().num_intervals(), 3);

  ASSERT_EQ(actual.process_output_latency().interval_size_usec(), 1000);
  ASSERT_EQ(actual.process_output_latency().num_intervals(), 3);

  ASSERT_EQ(actual.input_stream_profiles().size(), 1);
  ASSERT_EQ(actual.input_stream_profiles(0).name(), "input_stream");
  ASSERT_FALSE(actual.input_stream_profiles(0).back_edge());
  ASSERT_EQ(actual.input_stream_profiles(0).latency().interval_size_usec(),
            1000);
  ASSERT_EQ(actual.input_stream_profiles(0).latency().num_intervals(), 3);
  EXPECT_THAT(actual, EqualsProto(R"(
    name: "DummyTestCalculator"
    process_runtime {
      total: 0
      interval_size_usec: 1000
      num_intervals: 3
      count: 0
      count: 0
      count: 0
    }
    process_input_latency {
      total: 0
      interval_size_usec: 1000
      num_intervals: 3
      count: 0
      count: 0
      count: 0
    }
    process_output_latency {
      total: 0
      interval_size_usec: 1000
      num_intervals: 3
      count: 0
      count: 0
      count: 0
    }
    input_stream_profiles {
      name: "input_stream"
      back_edge: false
      latency {
        total: 0
        interval_size_usec: 1000
        num_intervals: 3
        count: 0
        count: 0
        count: 0
      }
    }
  )"));
}

// Tests that Initialize() uses the ProfilerConfig in the graph definition.
@@ -291,16 +311,17 @@ TEST_F(GraphProfilerTestPeer, InitializeConfigWithoutStreamLatency) {
  // Checks histogram_interval_size_usec and num_histogram_intervals.
  CalculatorProfile actual =
      GetCalculatorProfilesMap()->find(kDummyTestCalculatorName)->second;
  ASSERT_EQ(actual.name(), kDummyTestCalculatorName);
  ASSERT_FALSE(actual.has_open_runtime());
  ASSERT_FALSE(actual.has_close_runtime());

  ASSERT_EQ(actual.process_runtime().interval_size_usec(), 1000);
  ASSERT_EQ(actual.process_runtime().num_intervals(), 3);

  ASSERT_FALSE(actual.has_process_input_latency());
  ASSERT_FALSE(actual.has_process_output_latency());
  ASSERT_EQ(actual.input_stream_profiles().size(), 0);
  EXPECT_THAT(actual, EqualsProto(R"(
    name: "DummyTestCalculator"
    process_runtime {
      total: 0
      interval_size_usec: 1000
      num_intervals: 3
      count: 0
      count: 0
      count: 0
    }
  )"));
}

// Tests that Initialize() reads all the configs defined in the graph
@@ -633,10 +654,11 @@ TEST_F(GraphProfilerTestPeer, SetOpenRuntime) {
  simulation_clock->ThreadFinish();

  ASSERT_EQ(profiles.size(), 1);
  ASSERT_EQ(profiles[0].open_runtime(), 100);
  ASSERT_FALSE(profiles[0].has_close_runtime());
  ASSERT_THAT(profiles[0].process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));
  EXPECT_THAT(profiles[0], Partially(EqualsProto(R"(
    name: "DummyTestCalculator"
    open_runtime: 100
    process_runtime { total: 0 }
  )")));
  // Checks packets_info_ map hasn't changed.
  ASSERT_EQ(GetPacketsInfoMap()->size(), 0);
}
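The rewritten assertions rely on gMock proto matchers; a minimal sketch of the pattern (assuming EqualsProto and Partially are available the way these tests import them; headers are assumptions): a full EqualsProto must describe every set field, while Partially() ignores unlisted ones.

```
#include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/port/gtest.h"

// Sketch only; assumes the EqualsProto/Partially matchers used above
// are in scope, as they are in the profiler tests.
TEST(ProtoMatcherSketch, FullVersusPartialMatch) {
  mediapipe::CalculatorProfile profile;
  profile.set_name("DummyTestCalculator");
  profile.set_open_runtime(100);

  // A plain EqualsProto must list every field that is set...
  EXPECT_THAT(profile, EqualsProto(R"(
    name: "DummyTestCalculator"
    open_runtime: 100
  )"));
  // ...while Partially() ignores fields the expectation leaves out.
  EXPECT_THAT(profile, Partially(EqualsProto(R"(
    open_runtime: 100
  )")));
}
```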
@@ -688,14 +710,29 @@ TEST_F(GraphProfilerTestPeer, SetOpenRuntimeWithStreamLatency) {
  ASSERT_EQ(profiles.size(), 2);
  CalculatorProfile source_profile =
      GetProfileWithName(profiles, "source_calc");
  ASSERT_EQ(source_profile.open_runtime(), 150);
  ASSERT_FALSE(source_profile.has_close_runtime());
  ASSERT_THAT(source_profile.process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));
  ASSERT_THAT(source_profile.process_input_latency(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));
  ASSERT_THAT(source_profile.process_output_latency(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));

  EXPECT_THAT(source_profile, EqualsProto(R"(
    name: "source_calc"
    open_runtime: 150
    process_runtime {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
    process_input_latency {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
    process_output_latency {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
  )"));

  // Check packets_info_ map has been updated.
  ASSERT_EQ(GetPacketsInfoMap()->size(), 1);
@@ -736,11 +773,16 @@ TEST_F(GraphProfilerTestPeer, SetCloseRuntime) {
  std::vector<CalculatorProfile> profiles = Profiles();
  simulation_clock->ThreadFinish();

  ASSERT_EQ(profiles.size(), 1);
  ASSERT_FALSE(profiles[0].open_runtime());
  ASSERT_EQ(profiles[0].close_runtime(), 100);
  ASSERT_THAT(profiles[0].process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));
  EXPECT_THAT(profiles[0], EqualsProto(R"(
    name: "DummyTestCalculator"
    close_runtime: 100
    process_runtime {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
  )"));
}

// Tests that SetCloseRuntime() updates |close_runtime| and doesn't affect other
@@ -789,11 +831,39 @@ TEST_F(GraphProfilerTestPeer, SetCloseRuntimeWithStreamLatency) {
  ASSERT_EQ(profiles.size(), 2);
  CalculatorProfile source_profile =
      GetProfileWithName(profiles, "source_calc");
  ASSERT_FALSE(source_profile.open_runtime());
  ASSERT_EQ(source_profile.close_runtime(), 100);
  ASSERT_THAT(source_profile.process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));
  ASSERT_EQ(GetPacketsInfoMap()->size(), 1);

  EXPECT_THAT(source_profile, EqualsProto(R"(
    name: "source_calc"
    close_runtime: 100
    process_runtime {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
    process_input_latency {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
    process_output_latency {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 0
    }
    input_stream_profiles {
      name: "input_stream"
      back_edge: false
      latency {
        total: 0
        interval_size_usec: 1000000
        num_intervals: 1
        count: 0
      }
    }
  )"));
  PacketInfo expected_packet_info = {0,
                                     /*production_time_usec=*/1000 + 100,
                                     /*source_process_start_usec=*/1000 + 0};
@@ -933,10 +1003,15 @@ TEST_F(GraphProfilerTestPeer, AddProcessSample) {
  simulation_clock->ThreadFinish();

  ASSERT_EQ(profiles.size(), 1);
  ASSERT_THAT(profiles[0].process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/150, {1}))));
  ASSERT_FALSE(profiles[0].has_open_runtime());
  ASSERT_FALSE(profiles[0].has_close_runtime());
  EXPECT_THAT(profiles[0], EqualsProto(R"(
    name: "DummyTestCalculator"
    process_runtime {
      total: 150
      interval_size_usec: 1000000
      num_intervals: 1
      count: 1
    }
  )"));
  // Checks packets_info_ map hasn't changed.
  ASSERT_EQ(GetPacketsInfoMap()->size(), 0);
}
@@ -985,12 +1060,27 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) {
  ASSERT_EQ(profiles.size(), 2);
  CalculatorProfile source_profile =
      GetProfileWithName(profiles, "source_calc");
  ASSERT_THAT(source_profile.process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/150, {1}))));
  ASSERT_THAT(source_profile.process_input_latency(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {1}))));
  ASSERT_THAT(source_profile.process_output_latency(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/150, {1}))));

  EXPECT_THAT(profiles[0], Partially(EqualsProto(R"(
    process_runtime {
      total: 150
      interval_size_usec: 1000000
      num_intervals: 1
      count: 1
    }
    process_input_latency {
      total: 0
      interval_size_usec: 1000000
      num_intervals: 1
      count: 1
    }
    process_output_latency {
      total: 150
      interval_size_usec: 1000000
      num_intervals: 1
      count: 1
    }
  )")));

  // Check packets_info_ map has been updated.
  ASSERT_EQ(GetPacketsInfoMap()->size(), 1);
@@ -1019,22 +1109,24 @@ TEST_F(GraphProfilerTestPeer, AddProcessSampleWithStreamLatency) {

  CalculatorProfile consumer_profile =
      GetProfileWithName(profiles, "consumer_calc");
  ASSERT_THAT(consumer_profile.process_runtime(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/250, {1}))));
  ASSERT_THAT(consumer_profile.process_input_latency(),
              Partially(EqualsProto(CreateTimeHistogram(
                  /*total=*/2000 - when_source_started, {1}))));
  ASSERT_THAT(consumer_profile.process_output_latency(),
              Partially(EqualsProto(CreateTimeHistogram(
                  /*total=*/2000 + 250 - when_source_started, {1}))));
  ASSERT_EQ(consumer_profile.input_stream_profiles().size(), 2);
  // Latency for "stream_0" should not have changed since it was empty.
  ASSERT_THAT(consumer_profile.input_stream_profiles(0).latency(),
              Partially(EqualsProto(CreateTimeHistogram(/*total=*/0, {0}))));
  // For "stream_1":
  ASSERT_THAT(consumer_profile.input_stream_profiles(1).latency(),
              Partially(EqualsProto(CreateTimeHistogram(
                  /*total=*/2000 - when_source_finished, {1}))));

  // Process input latency total = 2000 (end) - 1000 (when source started) =
  // 1000. Process output latency total = 2000 (end) + 250 - 1000 (when source
  // started) = 1250. Latency for "stream_0" should not have changed since it
  // was empty. For "stream_1": 2000 (end) - 1150 (when source finished) = 850.
  EXPECT_THAT(consumer_profile, Partially(EqualsProto(R"(
    name: "consumer_calc"
    process_input_latency { total: 1000 }
    process_output_latency { total: 1250 }
    input_stream_profiles {
      name: "stream_0"
      latency { total: 0 }
    }
    input_stream_profiles {
      name: "stream_1"
      latency { total: 850 }
    }
  )")));

  // Check that the packets_info_ map entry for PacketId({"stream_1", 100})
  // has not yet been garbage collected.
@@ -39,9 +39,20 @@ inline const void* GetPacketDataId(const HolderBase* holder) {
struct TraceEvent {
  using EventType = GraphTrace::EventType;
  // GraphTrace::EventType constants, repeated here to match GraphProfilerStub.
  static const EventType UNKNOWN, OPEN, PROCESS, CLOSE, NOT_READY,
      READY_FOR_PROCESS, READY_FOR_CLOSE, THROTTLED, UNTHROTTLED, CPU_TASK_USER,
      CPU_TASK_SYSTEM, GPU_TASK, DSP_TASK, TPU_TASK;
  static constexpr EventType UNKNOWN = GraphTrace::UNKNOWN;
  static constexpr EventType OPEN = GraphTrace::OPEN;
  static constexpr EventType PROCESS = GraphTrace::PROCESS;
  static constexpr EventType CLOSE = GraphTrace::CLOSE;
  static constexpr EventType NOT_READY = GraphTrace::NOT_READY;
  static constexpr EventType READY_FOR_PROCESS = GraphTrace::READY_FOR_PROCESS;
  static constexpr EventType READY_FOR_CLOSE = GraphTrace::READY_FOR_CLOSE;
  static constexpr EventType THROTTLED = GraphTrace::THROTTLED;
  static constexpr EventType UNTHROTTLED = GraphTrace::UNTHROTTLED;
  static constexpr EventType CPU_TASK_USER = GraphTrace::CPU_TASK_USER;
  static constexpr EventType CPU_TASK_SYSTEM = GraphTrace::CPU_TASK_SYSTEM;
  static constexpr EventType GPU_TASK = GraphTrace::GPU_TASK;
  static constexpr EventType DSP_TASK = GraphTrace::DSP_TASK;
  static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
  absl::Time event_time;
  EventType event_type = UNKNOWN;
  bool is_finish = false;
@@ -385,21 +385,21 @@ void TraceBuilder::CreateLog(const TraceBuffer& buffer, absl::Time begin_time,
}
void TraceBuilder::Clear() { impl_->Clear(); }

// Defined here since inline constants fail to link in android builds.
const TraceEvent::EventType  //
    TraceEvent::UNKNOWN = GraphTrace::UNKNOWN,
    TraceEvent::OPEN = GraphTrace::OPEN,
    TraceEvent::PROCESS = GraphTrace::PROCESS,
    TraceEvent::CLOSE = GraphTrace::CLOSE,
    TraceEvent::NOT_READY = GraphTrace::NOT_READY,
    TraceEvent::READY_FOR_PROCESS = GraphTrace::READY_FOR_PROCESS,
    TraceEvent::READY_FOR_CLOSE = GraphTrace::READY_FOR_CLOSE,
    TraceEvent::THROTTLED = GraphTrace::THROTTLED,
    TraceEvent::UNTHROTTLED = GraphTrace::UNTHROTTLED,
    TraceEvent::CPU_TASK_USER = GraphTrace::CPU_TASK_USER,
    TraceEvent::CPU_TASK_SYSTEM = GraphTrace::CPU_TASK_SYSTEM,
    TraceEvent::GPU_TASK = GraphTrace::GPU_TASK,
    TraceEvent::DSP_TASK = GraphTrace::DSP_TASK,
    TraceEvent::TPU_TASK = GraphTrace::TPU_TASK;
// Defined here since constexpr requires out-of-class definition until C++17.
const TraceEvent::EventType  //
    TraceEvent::UNKNOWN,            //
    TraceEvent::OPEN,               //
    TraceEvent::PROCESS,            //
    TraceEvent::CLOSE,              //
    TraceEvent::NOT_READY,          //
    TraceEvent::READY_FOR_PROCESS,  //
    TraceEvent::READY_FOR_CLOSE,    //
    TraceEvent::THROTTLED,          //
    TraceEvent::UNTHROTTLED,        //
    TraceEvent::CPU_TASK_USER,      //
    TraceEvent::CPU_TASK_SYSTEM,    //
    TraceEvent::GPU_TASK,           //
    TraceEvent::DSP_TASK,           //
    TraceEvent::TPU_TASK;

}  // namespace mediapipe
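The comment in the new definition refers to a standard C++14 rule; a compact sketch of the idiom (a generic example, not MediaPipe code): an ODR-used static constexpr member still needs one out-of-class definition until C++17, where such members become implicitly inline.

```
struct Limits {
  static constexpr int kMax = 64;  // in-class declaration with initializer
};

// Required pre-C++17 whenever kMax is ODR-used (e.g. bound to a reference);
// the definition repeats no initializer. In C++17 this line is redundant.
constexpr int Limits::kMax;

const int& r = Limits::kMax;  // ODR-use that triggers the requirement
```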
@@ -127,6 +127,11 @@ class TagMap {
  std::vector<std::string> names_;
};

// Equal TagData structs define equal id ranges.
inline bool operator==(const TagMap::TagData& d1, const TagMap::TagData& d2) {
  return d1.id == d2.id && d1.count == d2.count;
}

}  // namespace tool
}  // namespace mediapipe
@@ -567,6 +567,10 @@ class TemplateExpanderImpl {
      result = AsDict(args);
    } else if (expr.op() == "list") {
      result = AsList(args);
    } else if (expr.op() == "size") {
      return AsArgument(static_cast<double>(
          args[0].has_dict() ? args[0].mutable_dict()->arg_size()
                             : args[0].mutable_element()->size()));
    }
    return result;
  }
@@ -1318,8 +1318,8 @@ bool IsInfixOperator(const std::string& token) {
// A function-style operator, including a for or if expression.
bool IsFunctionOperator(const std::string& token) {
  static auto kTokens = new std::set<std::string>{
      "min",    "max",       "for",       "if",   "!",
      "concat", "lowercase", "uppercase", "dict", "list",
      "min",       "max",       "for",  "if",   "!",    "concat",
      "lowercase", "uppercase", "size", "dict", "list",
  };
  return kTokens->count(token) > 0;
}
@@ -101,6 +101,10 @@ static const GLfloat kBasicTextureVertices[] = {
    1.0f, 1.0f,  // top right
};

// Places a texture on kBasicSquareVertices, flipped horizontally.
static const GLfloat kBasicTextureVerticesFlipX[] = {
    V4(kBasicTextureVertices, 1, 0, 3, 2)};

// Places a texture on kBasicSquareVertices, flipped vertically.
static const GLfloat kBasicTextureVerticesFlipY[] = {
    V4(kBasicTextureVertices, 2, 3, 0, 1)};
@@ -44,3 +44,30 @@ cc_library(
        "//mediapipe/calculators/video:opencv_video_decoder_calculator",
    ],
)

cc_library(
    name = "yt8m_inference_calculators_deps",
    deps = [
        "//mediapipe/calculators/core:concatenate_vector_calculator",
        "//mediapipe/calculators/core:dequantize_byte_array_calculator",
        "//mediapipe/calculators/core:packet_cloner_calculator",
        "//mediapipe/calculators/core:side_packet_to_stream_calculator",
        "//mediapipe/calculators/core:string_to_int_calculator",
        "//mediapipe/calculators/tensorflow:lapped_tensor_buffer_calculator",
        "//mediapipe/calculators/tensorflow:string_to_sequence_example_calculator",
        "//mediapipe/calculators/tensorflow:tensor_to_vector_float_calculator",
        "//mediapipe/calculators/tensorflow:tensorflow_inference_calculator",
        "//mediapipe/calculators/tensorflow:tensorflow_session_from_saved_model_calculator",
        "//mediapipe/calculators/tensorflow:tfrecord_reader_calculator",
        "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator",
        "//mediapipe/calculators/tensorflow:unpack_yt8m_sequence_example_calculator",
        "//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator",
        "//mediapipe/calculators/tensorflow:vector_int_to_tensor_calculator",
        "//mediapipe/calculators/util:annotation_overlay_calculator",
        "//mediapipe/calculators/util:labels_to_render_data_calculator",
        "//mediapipe/calculators/util:local_file_contents_calculator",
        "//mediapipe/calculators/util:top_k_scores_calculator",
        "//mediapipe/calculators/video:opencv_video_decoder_calculator",
        "//mediapipe/calculators/video:opencv_video_encoder_calculator",
    ],
)
3862  mediapipe/graphs/youtube8m/label_map.txt  (new file)
File diff suppressed because it is too large

178  mediapipe/graphs/youtube8m/local_video_model_inference.pbtxt  (new file)
@@ -0,0 +1,178 @@
input_side_packet: "input_sequence_example_path"
input_side_packet: "input_video_path"
input_side_packet: "output_video_path"
input_side_packet: "segment_size"
input_side_packet: "overlap"

node {
  calculator: "LocalFileContentsCalculator"
  input_side_packet: "FILE_PATH:input_sequence_example_path"
  output_side_packet: "CONTENTS:input_sequence_example"
}

node {
  calculator: "StringToSequenceExampleCalculator"
  input_side_packet: "STRING:input_sequence_example"
  output_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example"
}

node {
  calculator: "UnpackMediaSequenceCalculator"
  input_side_packet: "SEQUENCE_EXAMPLE:parsed_sequence_example"
  output_stream: "FLOAT_FEATURE_RGB:rgb_feature_vector"
  output_stream: "FLOAT_FEATURE_AUDIO:audio_feature_vector"
}

node {
  calculator: "ConcatenateFloatVectorCalculator"
  input_stream: "rgb_feature_vector"
  input_stream: "audio_feature_vector"
  output_stream: "feature_vector"
}

node {
  calculator: "VectorFloatToTensorCalculator"
  input_stream: "feature_vector"
  output_stream: "feature_tensor"
}

node {
  calculator: "StringToInt32Calculator"
  input_side_packet: "segment_size"
  output_side_packet: "segment_size_int"
}

node {
  calculator: "StringToInt32Calculator"
  input_side_packet: "overlap"
  output_side_packet: "overlap_int"
}

node {
  calculator: "LappedTensorBufferCalculator"
  input_stream: "feature_tensor"
  output_stream: "lapped_feature_tensor"
  input_side_packet: "BUFFER_SIZE:segment_size_int"
  input_side_packet: "OVERLAP:overlap_int"
  node_options: {
    [type.googleapis.com/mediapipe.LappedTensorBufferCalculatorOptions] {
      add_batch_dim_to_tensors: true
    }
  }
}

node {
  calculator: "SidePacketToStreamCalculator"
  input_side_packet: "segment_size_int"
  output_stream: "AT_ZERO:segment_size_int_stream"
}

node {
  calculator: "VectorIntToTensorCalculator"
  input_stream: "SINGLE_INT:segment_size_int_stream"
  output_stream: "TENSOR_OUT:segment_size_tensor"
}

node {
  calculator: "PacketClonerCalculator"
  input_stream: "segment_size_tensor"
  input_stream: "lapped_feature_tensor"
  output_stream: "synced_segment_size_tensor"
}

node {
  calculator: "TensorFlowSessionFromSavedModelCalculator"
  output_side_packet: "SESSION:session"
  node_options: {
    [type.googleapis.com/mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions]: {
      saved_model_path: "/tmp/mediapipe/saved_model"
    }
  }
}

node: {
  calculator: "TensorFlowInferenceCalculator"
  input_side_packet: "SESSION:session"
  input_stream: "NUM_FRAMES:synced_segment_size_tensor"
  input_stream: "RGB_AND_AUDIO:lapped_feature_tensor"
  output_stream: "PREDICTIONS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TensorFlowInferenceCalculatorOptions]: {
      batch_size: 32
    }
  }
}

node {
  calculator: "TensorToVectorFloatCalculator"
  input_stream: "prediction_tensor"
  output_stream: "prediction_vector"
}

node {
  calculator: "TopKScoresCalculator"
  input_stream: "SCORES:prediction_vector"
  output_stream: "TOP_K_INDEXES:top_k_indexes"
  output_stream: "TOP_K_SCORES:top_k_scores"
  output_stream: "TOP_K_LABELS:top_k_labels"
  node_options: {
    [type.googleapis.com/mediapipe.TopKScoresCalculatorOptions]: {
      top_k: 3
      label_map_path: "mediapipe/graphs/youtube8m/label_map.txt"
    }
  }
}

node {
  calculator: "OpenCvVideoDecoderCalculator"
  input_side_packet: "INPUT_FILE_PATH:input_video_path"
  output_stream: "VIDEO:input_video"
  output_stream: "VIDEO_PRESTREAM:input_video_header"
}

node {
  calculator: "LabelsToRenderDataCalculator"
  input_stream: "LABELS:top_k_labels"
  input_stream: "SCORES:top_k_scores"
  input_stream: "VIDEO_PRESTREAM:input_video_header"
  output_stream: "RENDER_DATA:render_data"
  node_options: {
    [type.googleapis.com/mediapipe.LabelsToRenderDataCalculatorOptions]: {
      color { r: 255 g: 0 b: 0 }
      color { r: 0 g: 255 b: 0 }
      color { r: 0 g: 0 b: 255 }
      thickness: 2.0
      font_height_px: 20
      max_num_labels: 3
      location: TOP_LEFT
    }
  }
}

node {
  calculator: "PacketClonerCalculator"
  input_stream: "render_data"
  input_stream: "input_video"
  output_stream: "synchronized_render_data"
}

node {
  calculator: "AnnotationOverlayCalculator"
  input_stream: "INPUT_FRAME:input_video"
  input_stream: "synchronized_render_data"
  output_stream: "OUTPUT_FRAME:output_video"
}

node {
  calculator: "OpenCvVideoEncoderCalculator"
  input_stream: "VIDEO:output_video"
  input_stream: "VIDEO_PRESTREAM:input_video_header"
  input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
  node_options: {
    [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: {
      codec: "avc1"
      video_format: "mp4"
    }
  }
}
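For orientation, a minimal C++ sketch (file paths and values are placeholders, not part of this commit; the status macro names are assumed from the framework code above) of feeding the five side packets this graph declares through the CalculatorGraph API:

```
#include <map>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

::mediapipe::Status RunLocalVideoInference(const std::string& graph_pbtxt) {
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
          graph_pbtxt);
  mediapipe::CalculatorGraph graph;
  RETURN_IF_ERROR(graph.Initialize(config));
  std::map<std::string, mediapipe::Packet> side_packets;
  // All five side packets are strings; the paths here are placeholders.
  side_packets["input_sequence_example_path"] =
      mediapipe::MakePacket<std::string>("/tmp/features.pb");
  side_packets["input_video_path"] =
      mediapipe::MakePacket<std::string>("/tmp/input.mp4");
  side_packets["output_video_path"] =
      mediapipe::MakePacket<std::string>("/tmp/annotated.mp4");
  side_packets["segment_size"] = mediapipe::MakePacket<std::string>("5");
  side_packets["overlap"] = mediapipe::MakePacket<std::string>("2");
  RETURN_IF_ERROR(graph.StartRun(side_packets));
  return graph.WaitUntilDone();
}
```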
139  mediapipe/graphs/youtube8m/yt8m_dataset_model_inference.pbtxt  (new file)
@@ -0,0 +1,139 @@
input_side_packet: "desired_segment_size"
input_side_packet: "record_index"
input_side_packet: "tfrecord_path"
output_side_packet: "yt8m_id"
output_stream: "annotation_summary"

node {
  calculator: "StringToInt32Calculator"
  input_side_packet: "record_index"
  output_side_packet: "record_index_int"
}

node {
  calculator: "StringToInt32Calculator"
  input_side_packet: "desired_segment_size"
  output_side_packet: "desired_segment_size_int"
}

node {
  calculator: "TFRecordReaderCalculator"
  input_side_packet: "TFRECORD_PATH:tfrecord_path"
  input_side_packet: "RECORD_INDEX:record_index_int"
  output_side_packet: "SEQUENCE_EXAMPLE:yt8m_sequence_example"
}

node {
  calculator: "UnpackYt8mSequenceExampleCalculator"
  input_side_packet: "YT8M_SEQUENCE_EXAMPLE:yt8m_sequence_example"
  input_side_packet: "DESIRED_SEGMENT_SIZE:desired_segment_size_int"
  output_side_packet: "YT8M_ID:yt8m_id"
  output_side_packet: "SEGMENT_SIZE:segment_size"
  output_side_packet: "LAPPED_TENSOR_BUFFER_CALCULATOR_OPTIONS:lapped_tensor_buffer_calculator_options"
  output_stream: "QUANTIZED_RGB_FEATURE:quantized_rgb_feature"
  output_stream: "QUANTIZED_AUDIO_FEATURE:quantized_audio_feature"
}

node {
  calculator: "DequantizeByteArrayCalculator"
  input_stream: "ENCODED:quantized_rgb_feature"
  output_stream: "FLOAT_VECTOR:rgb_feature_vector"
  node_options: {
    [type.googleapis.com/mediapipe.DequantizeByteArrayCalculatorOptions]: {
      max_quantized_value: 2
      min_quantized_value: -2
    }
  }
}

node {
  calculator: "DequantizeByteArrayCalculator"
  input_stream: "ENCODED:quantized_audio_feature"
  output_stream: "FLOAT_VECTOR:audio_feature_vector"
  node_options: {
    [type.googleapis.com/mediapipe.DequantizeByteArrayCalculatorOptions]: {
      max_quantized_value: 2
      min_quantized_value: -2
    }
  }
}

node {
  calculator: "ConcatenateFloatVectorCalculator"
  input_stream: "rgb_feature_vector"
  input_stream: "audio_feature_vector"
  output_stream: "feature_vector"
}

node {
  calculator: "VectorFloatToTensorCalculator"
  input_stream: "feature_vector"
  output_stream: "feature_tensor"
}

node {
  calculator: "LappedTensorBufferCalculator"
  input_stream: "feature_tensor"
  input_side_packet: "CALCULATOR_OPTIONS:lapped_tensor_buffer_calculator_options"
  output_stream: "lapped_feature_tensor"
}

node {
  calculator: "SidePacketToStreamCalculator"
  input_side_packet: "segment_size"
  output_stream: "AT_ZERO:segment_size_int_stream"
}

node {
  calculator: "VectorIntToTensorCalculator"
  input_stream: "SINGLE_INT:segment_size_int_stream"
  output_stream: "TENSOR_OUT:segment_size_tensor"
}

node {
  calculator: "PacketClonerCalculator"
  input_stream: "segment_size_tensor"
  input_stream: "lapped_feature_tensor"
  output_stream: "synced_segment_size_tensor"
}

node {
  calculator: "TensorFlowSessionFromSavedModelCalculator"
  output_side_packet: "SESSION:session"
  node_options: {
    [type.googleapis.com/mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions]: {
      saved_model_path: "/tmp/mediapipe/saved_model"
    }
  }
}

node: {
  calculator: "TensorFlowInferenceCalculator"
  input_side_packet: "SESSION:session"
  input_stream: "NUM_FRAMES:synced_segment_size_tensor"
  input_stream: "RGB_AND_AUDIO:lapped_feature_tensor"
  output_stream: "PREDICTIONS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TensorFlowInferenceCalculatorOptions]: {
      batch_size: 32
    }
  }
}

node {
  calculator: "TensorToVectorFloatCalculator"
  input_stream: "prediction_tensor"
  output_stream: "prediction_vector"
}

node {
  calculator: "TopKScoresCalculator"
  input_stream: "SCORES:prediction_vector"
  output_stream: "SUMMARY:annotation_summary"
  node_options: {
    [type.googleapis.com/mediapipe.TopKScoresCalculatorOptions]: {
      top_k: 9
      label_map_path: "mediapipe/graphs/youtube8m/label_map.txt"
    }
  }
}
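The two DequantizeByteArrayCalculator nodes above map each feature byte back into [-2, 2]; a sketch of the linear dequantization (assuming the standard mapping used for YT8M features; the helper name is made up):

```
#include <string>
#include <vector>

// Hypothetical helper: byte b -> min + b * (max - min) / 255.
std::vector<float> DequantizeBytes(const std::string& bytes, float min_value,
                                   float max_value) {
  const float scale = (max_value - min_value) / 255.0f;
  std::vector<float> result;
  result.reserve(bytes.size());
  for (unsigned char b : bytes) {
    result.push_back(b * scale + min_value);  // 0 -> -2.0f, 255 -> 2.0f
  }
  return result;
}
```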
15  mediapipe/java/com/google/mediapipe/BUILD  (new file)

@@ -0,0 +1,15 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])  # Apache 2.0
@@ -68,3 +68,10 @@ android_library(
        "@com_google_guava_android//jar",
    ],
)

# Expose the java source files for building mediapipe AAR.
filegroup(
    name = "java_src",
    srcs = glob(["*.java"]),
    visibility = ["//mediapipe:__subpackages__"],
)
@@ -150,6 +150,7 @@ public class ExternalTextureConverter implements TextureFrameProducer {
  private ExternalTextureRenderer renderer = null;
  private long timestampOffset = 0;
  private long previousTimestamp = 0;
  private boolean previousTimestampValid = false;

  protected int destinationWidth = 0;
  protected int destinationHeight = 0;

@@ -335,11 +336,12 @@ public class ExternalTextureConverter implements TextureFrameProducer {
    // ensures that surface texture has the up-to-date timestamp. (Also adjust |timestampOffset|
    // to ensure that timestamps increase monotonically.)
    long textureTimestamp = surfaceTexture.getTimestamp() / NANOS_PER_MICRO;
    if (textureTimestamp + timestampOffset <= previousTimestamp) {
    if (previousTimestampValid && textureTimestamp + timestampOffset <= previousTimestamp) {
      timestampOffset = previousTimestamp + 1 - textureTimestamp;
    }
    outputFrame.setTimestamp(textureTimestamp + timestampOffset);
    previousTimestamp = outputFrame.getTimestamp();
    previousTimestampValid = true;
  }

  private void waitUntilReleased(AppTextureFrame frame) {
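The Java change above only adjusts timestampOffset once a previous timestamp actually exists, so the very first frame no longer inflates the offset. The same guard, sketched in C++ for consistency with the other examples here:

```
#include <cstdint>

// Sketch of the monotonic-timestamp guard from the Java change above.
struct MonotonicTimestamper {
  int64_t offset = 0;
  int64_t previous = 0;
  bool previous_valid = false;

  int64_t Adjust(int64_t timestamp) {
    // Only bump the offset when a valid previous timestamp exists.
    if (previous_valid && timestamp + offset <= previous) {
      offset = previous + 1 - timestamp;  // keep output strictly increasing
    }
    previous = timestamp + offset;
    previous_valid = true;
    return previous;
  }
};
```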
|
@ -82,3 +82,10 @@ android_library(
|
|||
"@com_google_guava_android//jar",
|
||||
],
|
||||
)
|
||||
|
||||
# Expose the java source files for building mediapipe AAR.
|
||||
filegroup(
|
||||
name = "java_src",
|
||||
srcs = glob(["*.java"]),
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
|
|
@@ -30,3 +30,10 @@ android_library(
        "@com_google_guava_android//jar",
    ],
)

# Expose the java source files for building mediapipe AAR.
filegroup(
    name = "java_src",
    srcs = glob(["**/*.java"]),
    visibility = ["//mediapipe:__subpackages__"],
)
157  mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl  (new file)

@@ -0,0 +1,157 @@
"""Generate MediaPipe AAR including different variants of .so in jni folder.

Usage:

Create a new mediapipe_aar() target in a BUILD file. For example,
putting the following code into mediapipe/examples/android/aar_demo/BUILD.

```
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")

mediapipe_aar(
    name = "my_aar",
    calculators = ["//mediapipe/calculators/core:pass_through_calculator"],
)
```

Then, run the following Bazel command to generate the AAR.

```
$ bazel build -c opt --fat_apk_cpu=arm64-v8a,armeabi-v7a mediapipe/examples/android/aar_demo:my_aar
```

Finally, import the AAR into Android Studio.

"""

load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library")

def mediapipe_aar(name, calculators = []):
    """Generate MediaPipe AAR.

    Args:
      name: the name of the AAR.
      calculators: the calculator libraries to be compiled into the .so.
    """
    native.cc_binary(
        name = "libmediapipe_jni.so",
        linkshared = 1,
        linkstatic = 1,
        deps = [
            "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        ] + calculators,
    )

    native.cc_library(
        name = name + "_mediapipe_jni_lib",
        srcs = [":libmediapipe_jni.so"],
        alwayslink = 1,
    )

    native.genrule(
        name = name + "_aar_manifest_generator",
        outs = ["AndroidManifest.xml"],
        cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe">
  <uses-sdk
      android:minSdkVersion="21"
      android:targetSdkVersion="27" />
  <application />
</manifest>
EOF
""",
    )

    native.genrule(
        name = name + "_calculator_proto_java_src_generator",
        srcs = [
            "//mediapipe/framework:protos_src",
            "@com_google_protobuf_javalite//:well_known_protos",
        ],
        outs = ["CalculatorProto.java"],
        cmd = "$(location @com_google_protobuf_javalite//:protoc) " +
              "--plugin=protoc-gen-javalite=$(location @com_google_protobuf_javalite//:protoc_gen_javalite) " +
              "--proto_path=. --proto_path=$(GENDIR) " +
              "--proto_path=$$(pwd)/external/com_google_protobuf_javalite/src " +
              "--javalite_out=$$(dirname $(location CalculatorProto.java)) mediapipe/framework/calculator.proto && " +
              "mv $$(dirname $(location CalculatorProto.java))/com/google/mediapipe/proto/CalculatorProto.java $$(dirname $(location CalculatorProto.java))",
        tools = [
            "@com_google_protobuf_javalite//:protoc",
            "@com_google_protobuf_javalite//:protoc_gen_javalite",
        ],
    )

    android_library(
        name = name + "_android_lib",
        srcs = [
            "//mediapipe/java/com/google/mediapipe/components:java_src",
            "//mediapipe/java/com/google/mediapipe/framework:java_src",
            "//mediapipe/java/com/google/mediapipe/glutil:java_src",
            "CalculatorProto.java",
        ],
        manifest = "AndroidManifest.xml",
        proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
        deps = [
            ":" + name + "_mediapipe_jni_lib",
            "//mediapipe/framework:calculator_java_proto_lite",
            "//mediapipe/framework:calculator_profile_java_proto_lite",
            "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite",
            "//third_party:androidx_annotation",
            "//third_party:androidx_appcompat",
            "//third_party:androidx_core",
            "//third_party:androidx_legacy_support_v4",
            "//third_party:camerax_core",
            "//third_party:camera2",
            "@com_google_code_findbugs//jar",
            "@com_google_common_flogger//jar",
            "@com_google_common_flogger_system_backend//jar",
            "@com_google_guava_android//jar",
            "@androidx_lifecycle//jar",
        ],
    )

    _aar_with_jni(name, name + "_android_lib")

def _aar_with_jni(name, android_library):
    # Generate dummy AndroidManifest.xml for dummy apk usage
    # (dummy apk is generated by <name>_dummy_app target below)
    native.genrule(
        name = name + "_binary_manifest_generator",
        outs = [name + "_generated_AndroidManifest.xml"],
        cmd = """
cat > $(OUTS) <<EOF
<manifest
    xmlns:android="http://schemas.android.com/apk/res/android"
    package="dummy.package.for.so">
  <uses-sdk android:minSdkVersion="21"/>
</manifest>
EOF
""",
    )

    # Generate dummy apk including .so files.
    # We extract out .so files and throw away the apk.
    android_binary(
        name = name + "_dummy_app",
        manifest = name + "_generated_AndroidManifest.xml",
        custom_package = "dummy.package.for.so",
        deps = [android_library],
    )

    native.genrule(
        name = name,
        srcs = [android_library + ".aar", name + "_dummy_app_unsigned.apk"],
        outs = [name + ".aar"],
        tags = ["manual"],
        cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
cd $$(mktemp -d)
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
    )
@@ -466,6 +466,7 @@ tasks and tracking (or class) fields for tracking information.
|-----|------|------------------------|-------------|
|`CLASS_SEGMENTATION/image/encoded`|feature list bytes|`add_class_segmentation_encoded` / `AddClassSegmentationEncoded`|The encoded image of class labels at each timestep.|
|`CLASS_SEGMENTATION/image/timestamp`|feature list int|`add_class_segmentation_timestamp` / `AddClassSegmentationTimestamp`|The timestamp in microseconds for the class labels.|
|`CLASS_SEGMENTATION/image/multi_encoded`|feature list bytes list|`add_class_segmentation_multi_encoded` / `AddClassSegmentationMultiEncoded`|Stores multiple segmentation masks in case they overlap.|
|`CLASS_SEGMENTATION/image/format`|context bytes|`set_class_segmentation_format` / `SetClassSegmentationFormat`|The encoding format of the class label images.|
|`CLASS_SEGMENTATION/image/height`|context int|`set_class_segmentation_height` / `SetClassSegmentationHeight`|The height of the image in pixels.|
|`CLASS_SEGMENTATION/image/width`|context int|`set_class_segmentation_width` / `SetClassSegmentationWidth`|The width of the image in pixels.|

@@ -477,6 +478,7 @@ tasks and tracking (or class) fields for tracking information.
|-----|------|------------------------|-------------|
|`INSTANCE_SEGMENTATION/image/encoded`|feature list bytes|`add_instance_segmentation_encoded` / `AddInstanceSegmentationEncoded`|The encoded image of object instance labels at each timestep.|
|`INSTANCE_SEGMENTATION/image/timestamp`|feature list int|`add_instance_segmentation_timestamp` / `AddInstanceSegmentationTimestamp`|The timestamp in microseconds for the object instance labels.|
|`INSTANCE_SEGMENTATION/image/multi_encoded`|feature list bytes list|`add_instance_segmentation_multi_encoded` / `AddInstanceSegmentationMultiEncoded`|Stores multiple segmentation masks in case they overlap.|
|`INSTANCE_SEGMENTATION/image/format`|context bytes|`set_instance_segmentation_format` / `SetInstanceSegmentationFormat`|The encoding format of the object instance labels.|
|`INSTANCE_SEGMENTATION/image/height`|context int|`set_instance_segmentation_height` / `SetInstanceSegmentationHeight`|The height of the image in pixels.|
|`INSTANCE_SEGMENTATION/image/width`|context int|`set_instance_segmentation_width` / `SetInstanceSegmentationWidth`|The width of the image in pixels.|
@@ -489,7 +489,9 @@ def _create_image_with_prefix(name, prefix):
      prefix=prefix, module_dict=globals())
  msu.create_int_feature_list(name + "_timestamp", IMAGE_TIMESTAMP_KEY,
                              prefix=prefix, module_dict=globals())

  msu.create_bytes_list_feature_list(name + "_multi_encoded",
                                     IMAGE_MULTI_ENCODED_KEY, prefix=prefix,
                                     module_dict=globals())
FORWARD_FLOW_PREFIX = "FORWARD_FLOW"
CLASS_SEGMENTATION_PREFIX = "CLASS_SEGMENTATION"
INSTANCE_SEGMENTATION_PREFIX = "INSTANCE_SEGMENTATION"
@@ -78,8 +78,10 @@ class MediaSequenceTest(tf.test.TestCase):
    ms.set_bbox_parts((b"HEAD", b"TOE"), example)
    # feature lists
    ms.add_image_encoded(b"test", example)
    ms.add_image_multi_encoded([b"test", b"test"], example)
    ms.add_image_timestamp(47, example)
    ms.add_forward_flow_encoded(b"test", example)
    ms.add_forward_flow_multi_encoded([b"test", b"test"], example)
    ms.add_forward_flow_timestamp(47, example)
    ms.add_bbox_ymin((0.47, 0.49), example)
    ms.add_bbox_xmin((0.47, 0.49), example)

@@ -109,7 +111,9 @@ class MediaSequenceTest(tf.test.TestCase):
    ms.add_predicted_bbox_class_string((b"test", b"strings"), example)
    ms.add_predicted_bbox_timestamp(47, example)
    ms.add_class_segmentation_encoded(b"test", example)
    ms.add_class_segmentation_multi_encoded([b"test", b"test"], example)
    ms.add_instance_segmentation_encoded(b"test", example)
    ms.add_instance_segmentation_multi_encoded([b"test", b"test"], example)
    ms.add_class_segmentation_timestamp(47, example)
    ms.set_bbox_embedding_dimensions_per_region((47, 49), example)
    ms.set_bbox_embedding_format(b"test", example)