Merge branch 'ios-task-files' into ios-task

Prianka Liz Kariat 2022-12-23 16:10:57 +05:30
commit 7f7776ef80
89 changed files with 1189 additions and 891 deletions

View File

@ -1,6 +1,6 @@
---
name: "Solution (legacy) Issue"
about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc.
labels: type:support
---

View File

@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic
# For static images:
IMAGE_FILES = []
BG_COLOR = (192, 192, 192) # gray
with mp_holistic.Holistic(
static_image_mode=True,
model_complexity=2,

View File

@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
proto_library(
name = "mfcc_mel_calculators_proto",
srcs = ["mfcc_mel_calculators.proto"],

View File

@ -567,7 +567,7 @@ cc_library(
name = "packet_thinner_calculator",
srcs = ["packet_thinner_calculator.cc"],
deps = [
"//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
":packet_thinner_calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:video_stream_header",
@ -584,7 +584,7 @@ cc_test(
srcs = ["packet_thinner_calculator_test.cc"],
deps = [
":packet_thinner_calculator",
"//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
":packet_thinner_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:video_stream_header",
@ -762,7 +762,7 @@ cc_library(
srcs = ["packet_resampler_calculator.cc"],
hdrs = ["packet_resampler_calculator.h"],
deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
":packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/deps:mathutil",
@ -786,7 +786,7 @@ cc_test(
],
deps = [
":packet_resampler_calculator",
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
":packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:video_stream_header",
@ -852,10 +852,10 @@ cc_test(
name = "flow_limiter_calculator_test",
srcs = ["flow_limiter_calculator_test.cc"],
deps = [
":counting_source_calculator",
":flow_limiter_calculator",
":flow_limiter_calculator_cc_proto",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:test_calculators",
@ -1302,7 +1302,7 @@ cc_test(
srcs = ["packet_sequencer_calculator_test.cc"],
deps = [
":packet_sequencer_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:subgraph",

View File

@ -47,7 +47,7 @@ namespace api2 {
// calculator: "Get{SpecificType}VectorItemCalculator"
// input_stream: "VECTOR:vector"
// input_stream: "INDEX:index"
// input_stream: "ITEM:item"
// output_stream: "ITEM:item"
// options {
// [mediapipe.GetVectorItemCalculatorOptions.ext] {
// item_index: 5

View File

@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
// }
// output_stream: "gated_frames"
// }
class RealTimeFlowLimiterCalculator : public CalculatorBase {
//
// Please use FlowLimiterCalculator, which replaces this calculator and
// defines a few additional configuration options.
class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.")
RealTimeFlowLimiterCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
int num_data_streams = cc->Inputs().NumEntries("");
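
For context, ABSL_DEPRECATED expands to the compiler's deprecation attribute, so any code that still instantiates RealTimeFlowLimiterCalculator now gets a build-time warning pointing at FlowLimiterCalculator. A minimal standalone sketch of the mechanism (toy class names, not MediaPipe code):

#include "absl/base/attributes.h"

class ABSL_DEPRECATED("Use NewCalculator instead.") OldCalculator {
 public:
  int Process() const { return 0; }
};

int main() {
  // On GCC/Clang this line triggers:
  //   warning: 'OldCalculator' is deprecated: Use NewCalculator instead.
  OldCalculator c;
  return c.Process();
}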

View File

@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node {
// The number of packets or timestamps we need to store to output packet[i] at
// the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
int cache_size_;
bool emit_empty_packets_before_first_packet_ = false;
};
MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator);
absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
packet_offset_ = kOffset(cc).GetOr(
cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset());
emit_empty_packets_before_first_packet_ =
cc->Options<mediapipe::SequenceShiftCalculatorOptions>()
.emit_empty_packets_before_first_packet();
cache_size_ = abs(packet_offset_);
// An offset of zero is a no-op, but someone might still request it.
if (packet_offset_ == 0) {
@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
// Ready to output oldest packet with current timestamp.
kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
packet_cache_.pop_front();
} else if (emit_empty_packets_before_first_packet_) {
LOG(FATAL) << "Not supported yet";
}
// Store current packet for later output.
packet_cache_.push_back(kIn(cc).packet());

View File

@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions {
optional SequenceShiftCalculatorOptions ext = 107633927;
}
optional int32 packet_offset = 1 [default = -1];
// Emits empty packets before the first delayed packet is emitted. Takes
// effect only when packet offset is set to positive.
optional bool emit_empty_packets_before_first_packet = 2 [default = false];
}
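
A sketch of how a graph could opt into the new field; stream names are illustrative, and note that the .cc hunk above still hits LOG(FATAL) << "Not supported yet" when the flag is actually exercised with a positive offset:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig MakeShiftConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "in"
    output_stream: "out"
    node {
      calculator: "SequenceShiftCalculator"
      input_stream: "in"
      output_stream: "out"
      options {
        [mediapipe.SequenceShiftCalculatorOptions.ext] {
          packet_offset: 1  # the new flag only takes effect for positive offsets
          emit_empty_packets_before_first_packet: true
        }
      }
    }
  )pb");
}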

View File

@ -378,8 +378,8 @@ cc_library(
name = "scale_image_calculator",
srcs = ["scale_image_calculator.cc"],
deps = [
":scale_image_calculator_cc_proto",
":scale_image_utils",
"//mediapipe/calculators/image:scale_image_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame",
@ -747,8 +747,8 @@ cc_test(
tags = ["desktop_only_test"],
deps = [
":affine_transformation",
":image_transformation_calculator",
":warp_affine_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_converter",
"//mediapipe/calculators/tensor:image_to_tensor_utils",
"//mediapipe/calculators/util:from_image_calculator",

View File

@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner
constexpr GLchar kVertShader[] = R"(
in vec4 position;
in mediump vec4 texture_coordinate;
out mediump vec2 sample_coordinate;
in highp vec4 texture_coordinate;
out highp vec2 sample_coordinate;
uniform mat4 transform_matrix;
void main() {
@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner
)";
constexpr GLchar kFragShader[] = R"(
DEFAULT_PRECISION(mediump, float)
DEFAULT_PRECISION(highp, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
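
Why the bump from mediump to highp: OpenGL ES GLSL guarantees mediump floats only about 10 bits of mantissa (relative precision roughly 1/1024), so normalized sample coordinates lose sub-texel accuracy once textures exceed roughly 1024 px, while highp typically maps to full 32-bit floats; the fragment shader's default float precision is raised for the same reason. A sketch of the declarations only (kIllustrativeVertShader is our name, and the body is trimmed relative to the real shader):

constexpr char kIllustrativeVertShader[] = R"(
  in vec4 position;
  in highp vec4 texture_coordinate;   // was mediump: ~1/1024 relative precision
  out highp vec2 sample_coordinate;   // carried at full precision to sampling
  void main() {
    gl_Position = position;
    sample_coordinate = texture_coordinate.xy;
  }
)";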

View File

@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) {
constexpr char kRgbaInTag[] = "RGBA_IN";
constexpr char kRgbInTag[] = "RGB_IN";
constexpr char kBgrInTag[] = "BGR_IN";
constexpr char kBgraInTag[] = "BGRA_IN";
constexpr char kGrayInTag[] = "GRAY_IN";
constexpr char kRgbaOutTag[] = "RGBA_OUT";
@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT";
// RGB -> RGBA
// RGBA -> BGRA
// BGRA -> RGBA
// BGR -> RGB
//
// This calculator only supports a single input stream and output stream at a
// time. If more than one input stream or output stream is present, the
@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT";
// RGB_IN: The input video stream (ImageFrame, SRGB).
// BGRA_IN: The input video stream (ImageFrame, SBGRA).
// GRAY_IN: The input video stream (ImageFrame, GRAY8).
// BGR_IN: The input video stream (ImageFrame, SBGR).
//
// Output streams:
// RGBA_OUT: The output video stream (ImageFrame, SRGBA).
@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) {
cc->Inputs().Tag(kBgraInTag).Set<ImageFrame>();
}
if (cc->Inputs().HasTag(kBgrInTag)) {
cc->Inputs().Tag(kBgrInTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag(kRgbOutTag)) {
cc->Outputs().Tag(kRgbOutTag).Set<ImageFrame>();
}
@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) {
return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA,
cv::COLOR_RGBA2BGRA, cc);
}
// BGR -> RGB
if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) {
return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB,
cv::COLOR_BGR2RGB, cc);
}
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Unsupported image format conversion.";

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
proto_library(

View File

@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter {
constexpr GLchar kExtractSubRectVertexShader[] = R"(
in vec4 position;
in mediump vec4 texture_coordinate;
out mediump vec2 sample_coordinate;
in highp vec4 texture_coordinate;
out highp vec2 sample_coordinate;
uniform mat4 transform_matrix;
void main() {
@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter {
)";
constexpr GLchar kExtractSubRectFragBody[] = R"(
DEFAULT_PRECISION(mediump, float)
DEFAULT_PRECISION(highp, float)
// Provided by kExtractSubRectVertexShader.
in vec2 sample_coordinate;

View File

@ -22,8 +22,8 @@ cc_library(
name = "alignment_points_to_rects_calculator",
srcs = ["alignment_points_to_rects_calculator.cc"],
deps = [
":detections_to_rects_calculator",
":detections_to_rects_calculator_cc_proto",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",

View File

@ -1,4 +1,3 @@
#
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -227,13 +226,13 @@ cc_library(
":mediapipe_internal",
],
deps = [
":calculator_cc_proto",
":graph_service",
":mediapipe_options_cc_proto",
":packet_generator_cc_proto",
":packet_type",
":port",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
":status_handler_cc_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
@ -329,10 +328,10 @@ cc_library(
":thread_pool_executor",
":timestamp",
":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:thread_pool_executor_cc_proto",
":calculator_cc_proto",
":packet_generator_cc_proto",
":status_handler_cc_proto",
":thread_pool_executor_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
@ -370,7 +369,7 @@ cc_library(
visibility = [":mediapipe_internal"],
deps = [
":graph_service",
"//mediapipe/framework:packet",
":packet",
"@com_google_absl//absl/status",
],
)
@ -380,7 +379,7 @@ cc_test(
srcs = ["graph_service_manager_test.cc"],
deps = [
":graph_service_manager",
"//mediapipe/framework:packet",
":packet",
"//mediapipe/framework/port:gtest_main",
],
)
@ -392,6 +391,7 @@ cc_library(
visibility = [":mediapipe_internal"],
deps = [
":calculator_base",
":calculator_cc_proto",
":calculator_context",
":calculator_context_manager",
":calculator_state",
@ -408,10 +408,9 @@ cc_library(
":packet_set",
":packet_type",
":port",
":stream_handler_cc_proto",
":timestamp",
":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
@ -467,6 +466,7 @@ cc_library(
hdrs = ["calculator_state.h"],
visibility = [":mediapipe_internal"],
deps = [
":calculator_cc_proto",
":counter",
":counter_factory",
":graph_service",
@ -476,7 +476,6 @@ cc_library(
":packet",
":packet_set",
":port",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/tool:options_map",
@ -584,7 +583,7 @@ cc_library(
hdrs = ["executor.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:mediapipe_options_cc_proto",
":mediapipe_options_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
@ -671,11 +670,11 @@ cc_library(
":collection_item_id",
":input_stream_manager",
":input_stream_shard",
":mediapipe_options_cc_proto",
":mediapipe_profiling",
":packet",
":packet_set",
":packet_type",
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -785,12 +784,12 @@ cc_library(
":calculator_context_manager",
":collection",
":collection_item_id",
":mediapipe_options_cc_proto",
":output_stream_manager",
":output_stream_shard",
":packet_set",
":packet_type",
":timestamp",
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
@ -876,10 +875,10 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":packet",
":packet_generator_cc_proto",
":packet_set",
":packet_type",
":port",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:status",
@ -897,13 +896,13 @@ cc_library(
":delegating_executor",
":executor",
":packet",
":packet_factory_cc_proto",
":packet_generator",
":packet_generator_cc_proto",
":packet_type",
":port",
":thread_pool_executor",
":validated_graph_config",
"//mediapipe/framework:packet_factory_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
@ -1020,10 +1019,10 @@ cc_library(
hdrs = ["status_handler.h"],
visibility = ["//visibility:public"],
deps = [
":mediapipe_options_cc_proto",
":packet_set",
":packet_type",
":port",
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
@ -1036,10 +1035,10 @@ cc_library(
hdrs = ["subgraph.h"],
visibility = ["//visibility:public"],
deps = [
":calculator_cc_proto",
":graph_service",
":graph_service_manager",
":port",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -1061,7 +1060,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":calculator_framework",
"//mediapipe/framework:test_calculators_cc_proto",
":test_calculators_cc_proto",
"//mediapipe/framework/deps:mathutil",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:integral_types",
@ -1098,7 +1097,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":executor",
"//mediapipe/framework:thread_pool_executor_cc_proto",
":thread_pool_executor_cc_proto",
"//mediapipe/framework/deps:thread_options",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
@ -1163,22 +1162,22 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":calculator_base",
":calculator_cc_proto",
":calculator_contract",
":graph_service_manager",
":legacy_calculator_support",
":packet",
":packet_generator",
":packet_generator_cc_proto",
":packet_set",
":packet_type",
":port",
":status_handler",
":status_handler_cc_proto",
":stream_handler_cc_proto",
":subgraph",
":thread_pool_executor_cc_proto",
":timestamp",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework:thread_pool_executor_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
@ -1203,11 +1202,11 @@ cc_test(
name = "validated_graph_config_test",
srcs = ["validated_graph_config_test.cc"],
deps = [
":calculator_cc_proto",
":calculator_framework",
":graph_service",
":graph_service_manager",
":validated_graph_config",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
@ -1234,6 +1233,7 @@ cc_test(
linkstatic = 1,
deps = [
":calculator_base",
":calculator_cc_proto",
":calculator_context",
":calculator_context_manager",
":calculator_registry",
@ -1243,7 +1243,6 @@ cc_test(
":output_stream_shard",
":packet_set",
":packet_type",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:status_util",
@ -1257,11 +1256,11 @@ cc_test(
srcs = ["calculator_contract_test.cc"],
linkstatic = 1,
deps = [
":calculator_cc_proto",
":calculator_contract",
":calculator_contract_test_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
":packet_generator_cc_proto",
":status_handler_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
@ -1369,6 +1368,7 @@ cc_test(
srcs = ["calculator_context_test.cc"],
linkstatic = 1,
deps = [
":calculator_cc_proto",
":calculator_context",
":calculator_context_manager",
":calculator_state",
@ -1377,7 +1377,6 @@ cc_test(
":output_stream_shard",
":packet_set",
":packet_type",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
@ -1404,6 +1403,7 @@ cc_test(
":executor",
":input_stream_handler",
":lifetime_tracker",
":mediapipe_options_cc_proto",
":output_stream_poller",
":packet_set",
":packet_type",
@ -1411,13 +1411,12 @@ cc_test(
":subgraph",
":test_calculators",
":thread_pool_executor",
":thread_pool_executor_cc_proto",
":timestamp",
":type_map",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:thread_pool_executor_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
@ -1482,12 +1481,12 @@ cc_test(
],
visibility = ["//visibility:public"],
deps = [
":calculator_cc_proto",
":calculator_framework",
":test_calculators",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
@ -1631,8 +1630,8 @@ cc_test(
srcs = ["packet_generator_test.cc"],
deps = [
":packet_generator",
":packet_generator_cc_proto",
":packet_type",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/strings",

View File

@ -15,12 +15,17 @@
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace api2 {
namespace test {
namespace mediapipe::api2::builder {
namespace {
using ::mediapipe::api2::test::Bar;
using ::mediapipe::api2::test::FloatAdder;
using ::mediapipe::api2::test::Foo;
using ::mediapipe::api2::test::Foo2;
using ::mediapipe::api2::test::FooBar1;
TEST(BuilderTest, BuildGraph) {
builder::Graph graph;
Graph graph;
auto& foo = graph.AddNode("Foo");
auto& bar = graph.AddNode("Bar");
graph.In("IN").SetName("base") >> foo.In("BASE");
@ -49,22 +54,19 @@ TEST(BuilderTest, BuildGraph) {
}
TEST(BuilderTest, CopyableSource) {
builder::Graph graph;
builder::Source<int> a = graph[Input<int>("A")];
a.SetName("a");
builder::Source<int> b = graph[Input<int>("B")];
b.SetName("b");
builder::SideSource<float> side_a = graph[SideInput<float>("SIDE_A")];
side_a.SetName("side_a");
builder::SideSource<float> side_b = graph[SideInput<float>("SIDE_B")];
side_b.SetName("side_b");
builder::Destination<int> out = graph[Output<int>("OUT")];
builder::SideDestination<float> side_out =
graph[SideOutput<float>("SIDE_OUT")];
Graph graph;
Source<int> a = graph.In("A").SetName("a").Cast<int>();
Source<int> b = graph.In("B").SetName("b").Cast<int>();
SideSource<float> side_a =
graph.SideIn("SIDE_A").SetName("side_a").Cast<float>();
SideSource<float> side_b =
graph.SideIn("SIDE_B").SetName("side_b").Cast<float>();
Destination<int> out = graph.Out("OUT").Cast<int>();
SideDestination<float> side_out = graph.SideOut("SIDE_OUT").Cast<float>();
builder::Source<int> input = a;
Source<int> input = a;
input = b;
builder::SideSource<float> side_input = side_b;
SideSource<float> side_input = side_b;
side_input = side_a;
input >> out;
@ -83,31 +85,27 @@ TEST(BuilderTest, CopyableSource) {
}
TEST(BuilderTest, BuildGraphWithFunctions) {
builder::Graph graph;
Graph graph;
builder::Source<int> base = graph[Input<int>("IN")];
base.SetName("base");
builder::SideSource<float> side = graph[SideInput<float>("SIDE")];
side.SetName("side");
Source<int> base = graph.In("IN").SetName("base").Cast<int>();
SideSource<float> side = graph.SideIn("SIDE").SetName("side").Cast<float>();
auto foo_fn = [](builder::Source<int> base, builder::SideSource<float> side,
builder::Graph& graph) {
auto foo_fn = [](Source<int> base, SideSource<float> side, Graph& graph) {
auto& foo = graph.AddNode("Foo");
base >> foo[Input<int>("BASE")];
side >> foo[SideInput<float>("SIDE")];
return foo[Output<double>("OUT")];
base >> foo.In("BASE");
side >> foo.SideIn("SIDE");
return foo.Out("OUT")[0].Cast<double>();
};
builder::Source<double> foo_out = foo_fn(base, side, graph);
Source<double> foo_out = foo_fn(base, side, graph);
auto bar_fn = [](builder::Source<double> in, builder::Graph& graph) {
auto bar_fn = [](Source<double> in, Graph& graph) {
auto& bar = graph.AddNode("Bar");
in >> bar[Input<double>("IN")];
return bar[Output<double>("OUT")];
in >> bar.In("IN");
return bar.Out("OUT")[0].Cast<double>();
};
builder::Source<double> bar_out = bar_fn(foo_out, graph);
bar_out.SetName("out");
Source<double> bar_out = bar_fn(foo_out, graph);
bar_out >> graph[Output<double>("OUT")];
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -131,7 +129,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
template <class FooT>
void BuildGraphTypedTest() {
builder::Graph graph;
Graph graph;
auto& foo = graph.AddNode<FooT>();
auto& bar = graph.AddNode<Bar>();
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
@ -161,12 +159,12 @@ void BuildGraphTypedTest() {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<Foo>(); }
TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<test::Foo>(); }
TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<Foo2>(); }
TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<test::Foo2>(); }
TEST(BuilderTest, FanOut) {
builder::Graph graph;
Graph graph;
auto& foo = graph.AddNode("Foo");
auto& adder = graph.AddNode("FloatAdder");
graph.In("IN").SetName("base") >> foo.In("BASE");
@ -194,9 +192,9 @@ TEST(BuilderTest, FanOut) {
}
TEST(BuilderTest, TypedMultiple) {
builder::Graph graph;
auto& foo = graph.AddNode<Foo>();
auto& adder = graph.AddNode<FloatAdder>();
Graph graph;
auto& foo = graph.AddNode<test::Foo>();
auto& adder = graph.AddNode<test::FloatAdder>();
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
@ -222,14 +220,14 @@ TEST(BuilderTest, TypedMultiple) {
}
TEST(BuilderTest, TypedByPorts) {
builder::Graph graph;
auto& foo = graph.AddNode<Foo>();
Graph graph;
auto& foo = graph.AddNode<test::Foo>();
auto& adder = graph.AddNode<FloatAdder>();
graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase];
graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase];
foo[Foo::kOut] >> adder[FloatAdder::kIn][0];
foo[Foo::kOut] >> adder[FloatAdder::kIn][1];
adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut];
adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut);
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -251,7 +249,7 @@ TEST(BuilderTest, TypedByPorts) {
}
TEST(BuilderTest, PacketGenerator) {
builder::Graph graph;
Graph graph;
auto& generator = graph.AddPacketGenerator("FloatGenerator");
graph.SideIn("IN") >> generator.SideIn("IN");
generator.SideOut("OUT") >> graph.SideOut("OUT");
@ -270,7 +268,7 @@ TEST(BuilderTest, PacketGenerator) {
}
TEST(BuilderTest, EmptyTag) {
builder::Graph graph;
Graph graph;
auto& foo = graph.AddNode("Foo");
graph.In("A").SetName("a") >> foo.In("")[0];
graph.In("C").SetName("c") >> foo.In("")[2];
@ -302,7 +300,7 @@ TEST(BuilderTest, StringLikeTags) {
const std::string kB = "B";
constexpr absl::string_view kC = "C";
builder::Graph graph;
Graph graph;
auto& foo = graph.AddNode("Foo");
graph.In(kA).SetName("a") >> foo.In(kA);
graph.In(kB).SetName("b") >> foo.In(kB);
@ -324,7 +322,7 @@ TEST(BuilderTest, StringLikeTags) {
}
TEST(BuilderTest, GraphIndexes) {
builder::Graph graph;
Graph graph;
auto& foo = graph.AddNode("Foo");
graph.In(0).SetName("a") >> foo.In("")[0];
graph.In(1).SetName("c") >> foo.In("")[2];
@ -376,28 +374,27 @@ class AnyAndSameTypeCalculator : public NodeIntf {
};
TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
builder::Graph graph;
builder::Source<AnyType> any_input = graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}];
Graph graph;
Source<AnyType> any_input = graph.In("GRAPH_ANY_INPUT");
Source<int> int_input = graph.In("GRAPH_INT_INPUT").Cast<int>();
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
int_input >> node[AnyAndSameTypeCalculator::kIntInput];
builder::Source<AnyType> any_type_output =
Source<AnyType> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput];
any_type_output.SetName("any_type_output");
builder::Source<AnyType> same_type_output =
Source<AnyType> same_type_output =
node[AnyAndSameTypeCalculator::kSameTypeOutput];
same_type_output.SetName("same_type_output");
builder::Source<AnyType> recursive_same_type_output =
Source<AnyType> recursive_same_type_output =
node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
recursive_same_type_output.SetName("recursive_same_type_output");
builder::Source<int> same_int_output =
node[AnyAndSameTypeCalculator::kSameIntOutput];
Source<int> same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput];
same_int_output.SetName("same_int_output");
builder::Source<int> recursive_same_int_type_output =
Source<int> recursive_same_int_type_output =
node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
recursive_same_int_type_output.SetName("recursive_same_int_type_output");
@ -420,15 +417,16 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
}
TEST(BuilderTest, AnyTypeCanBeCast) {
builder::Graph graph;
builder::Source<std::string> any_input =
Graph graph;
Source<std::string> any_input =
graph.In("GRAPH_ANY_INPUT").Cast<std::string>();
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
builder::Source<double> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output");
Source<double> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput]
.SetName("any_type_output")
.Cast<double>();
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
@ -446,11 +444,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
}
TEST(BuilderTest, MultiPortIsCastToMultiPort) {
builder::Graph graph;
builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
builder::MultiSource<int> int_input = any_input.Cast<int>();
builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
builder::MultiDestination<int> int_output = any_output.Cast<int>();
Graph graph;
MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
MultiSource<int> int_input = any_input.Cast<int>();
MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
MultiDestination<int> int_output = any_output.Cast<int>();
int_input >> int_output;
CalculatorGraphConfig expected =
@ -462,11 +460,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) {
}
TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
builder::Graph graph;
builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
builder::Source<AnyType> any_input = any_multi_input;
builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
builder::Destination<AnyType> any_output = any_multi_output;
Graph graph;
MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
Source<AnyType> any_input = any_multi_input;
MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
Destination<AnyType> any_output = any_multi_output;
any_input >> any_output;
CalculatorGraphConfig expected =
@ -478,11 +476,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
}
TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
builder::Graph graph;
builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
Graph graph;
Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
Source<AnyType> any_input = graph.In("ANY_OUTPUT");
Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
int_input >> int_output;
any_input >> any_output;
@ -496,6 +494,5 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace test
} // namespace api2
} // namespace mediapipe
} // namespace
} // namespace mediapipe::api2::builder
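
The recurring edit in this test is one API migration applied everywhere: the old graph[Input<int>("TAG")] indexing and explicit builder:: qualification give way to graph.In("TAG") / graph.Out("TAG") plus Cast<T>(), with the test now living inside namespace mediapipe::api2::builder. Condensed into one sketch (reusing the test's own "Foo" node name):

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe::api2::builder {

CalculatorGraphConfig BuildExample() {
  Graph graph;
  // New style: untyped graph ports via In()/Out(), narrowed with Cast<T>().
  Source<int> in = graph.In("IN").SetName("base").Cast<int>();
  auto& foo = graph.AddNode("Foo");
  in >> foo.In("BASE");
  foo.Out("OUT") >> graph.Out("OUT");
  return graph.GetConfig();
}

}  // namespace mediapipe::api2::builder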

View File

@ -557,8 +557,8 @@ class OutputSidePacketAccess {
if (output_) output_->Set(ToOldPacket(std::move(packet)));
}
void Set(const T& payload) { Set(MakePacket<T>(payload)); }
void Set(T&& payload) { Set(MakePacket<T>(std::move(payload))); }
void Set(const T& payload) { Set(api2::MakePacket<T>(payload)); }
void Set(T&& payload) { Set(api2::MakePacket<T>(std::move(payload))); }
private:
OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {}
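
The explicit api2:: qualification matters because both mediapipe::MakePacket (returning the classic Packet) and api2::MakePacket (returning the typed api2::Packet<T>) can be visible here, and unqualified lookup may bind to the wrong one. A toy, self-contained model of that lookup pitfall (namespaces and signatures simplified; this is our illustration, not the real API):

#include <string>

namespace mp {
std::string MakePacket() { return "classic packet"; }

namespace api2 {
std::string MakePacket() { return "api2 packet"; }
}  // namespace api2

struct OutputSidePacketAccess {
  // From a scope where the outer namespace is searched first, an unqualified
  // MakePacket() binds to mp::MakePacket. Qualifying with api2:: pins the
  // intended overload, as the hunk above does.
  std::string Set() { return api2::MakePacket(); }
};
}  // namespace mp

int main() {
  return mp::OutputSidePacketAccess{}.Set() == "api2 packet" ? 0 : 1;
}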

View File

@ -20,9 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
licenses(["notice"])
package(default_visibility = [
"//mediapipe:__subpackages__",
])
package_group(
name = "mediapipe_internal",
packages = [
"//mediapipe/...",
],
)
package(default_visibility = ["mediapipe_internal"])
bzl_library(
name = "expand_template_bzl",
@ -214,6 +219,9 @@ cc_library(
name = "registration",
srcs = ["registration.cc"],
hdrs = ["registration.h"],
visibility = [
"mediapipe_internal",
],
deps = [
":registration_token",
"//mediapipe/framework/port:logging",

View File

@ -26,7 +26,7 @@ licenses(["notice"])
mediapipe_proto_library(
name = "detection_proto",
srcs = ["detection.proto"],
deps = ["//mediapipe/framework/formats:location_data_proto"],
deps = [":location_data_proto"],
)
mediapipe_register_type(
@ -38,7 +38,7 @@ mediapipe_register_type(
"::std::vector<::mediapipe::Detection>",
"::std::vector<::mediapipe::DetectionList>",
],
deps = ["//mediapipe/framework/formats:detection_cc_proto"],
deps = [":detection_cc_proto"],
)
mediapipe_proto_library(
@ -105,8 +105,8 @@ cc_library(
srcs = ["matrix.cc"],
hdrs = ["matrix.h"],
deps = [
":matrix_data_cc_proto",
"//mediapipe/framework:port",
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
@ -142,7 +142,7 @@ cc_library(
srcs = ["image_frame.cc"],
hdrs = ["image_frame.h"],
deps = [
"//mediapipe/framework/formats:image_format_cc_proto",
":image_format_cc_proto",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
@ -166,8 +166,8 @@ cc_library(
srcs = ["image_frame_opencv.cc"],
hdrs = ["image_frame_opencv.h"],
deps = [
":image_format_cc_proto",
":image_frame",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/port:opencv_core",
],
)
@ -194,7 +194,7 @@ cc_library(
deps = [
"@com_google_protobuf//:protobuf",
"//mediapipe/framework/formats/annotation:locus_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
":location_data_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -245,7 +245,7 @@ cc_library(
name = "video_stream_header",
hdrs = ["video_stream_header.h"],
deps = [
"//mediapipe/framework/formats:image_format_cc_proto",
":image_format_cc_proto",
],
)
@ -263,9 +263,9 @@ cc_test(
size = "small",
srcs = ["image_frame_opencv_test.cc"],
deps = [
":image_format_cc_proto",
":image_frame",
":image_frame_opencv",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
@ -324,8 +324,8 @@ cc_library(
"//conditions:default": [],
}),
deps = [
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame",
":image_format_cc_proto",
":image_frame",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:port",
"//mediapipe/framework:type_map",
@ -354,7 +354,7 @@ cc_library(
hdrs = ["image_multi_pool.h"],
deps = [
":image",
"//mediapipe/framework/formats:image_frame_pool",
":image_frame_pool",
"//mediapipe/framework:port",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/memory",
@ -390,7 +390,7 @@ cc_library(
],
deps = [
":image",
"//mediapipe/framework/formats:image_format_cc_proto",
":image_format_cc_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:statusor",
@ -428,7 +428,10 @@ cc_library(
"tensor.cc",
"tensor_ahwb.cc",
],
hdrs = ["tensor.h"],
hdrs = [
"tensor.h",
"tensor_internal.h",
],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
@ -452,6 +455,7 @@ cc_library(
],
}),
deps = [
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:port",

View File

@ -38,11 +38,11 @@ cc_library(
srcs = ["optical_flow_field.cc"],
hdrs = ["optical_flow_field.h"],
deps = [
":optical_flow_field_data_cc_proto",
"//mediapipe/framework:type_map",
"//mediapipe/framework/deps:mathutil",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/formats:location_opencv",
"//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",

View File

@ -246,10 +246,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape,
return Tensor::OpenGlTexture2dView::Layout::kAligned;
}
}
// The best performance of a compute shader can be achived with textures'
// The best performance of a compute shader can be achieved with textures'
// width being a multiple of 256. Making the minimum fixed width 256 wastes
// memory for small tensors; a power of 2 is the optimal memory-vs-performance
// balance.
// The texture width and height are choosen to be closer to square.
// The texture width and height are chosen to be closer to square.
float power = std::log2(std::sqrt(static_cast<float>(num_pixels)));
w = 1 << static_cast<int>(power);
int h = (num_pixels + w - 1) / w;
@ -326,7 +326,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
AllocateOpenGlBuffer();
if (!(valid_ & kValidOpenGlBuffer)) {
// If the call succeds then AHWB -> SSBO are synchronized so any usage of
// If the call succeeds then AHWB -> SSBO are synchronized so any usage of
// the SSBO is correct after this call.
if (!InsertAhwbToSsboFence()) {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
@ -348,8 +348,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
};
}
Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const {
Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView(
uint64_t source_location_hash) const {
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
TrackAhwbUsage(source_location_hash);
AllocateOpenGlBuffer();
valid_ = kValidOpenGlBuffer;
return {opengl_buffer_, std::move(lock), nullptr};
@ -385,6 +387,7 @@ void Tensor::Move(Tensor* src) {
src->element_type_ = ElementType::kNone; // Mark as invalidated.
cpu_buffer_ = src->cpu_buffer_;
src->cpu_buffer_ = nullptr;
ahwb_tracking_key_ = src->ahwb_tracking_key_;
#if MEDIAPIPE_METAL_ENABLED
device_ = src->device_;
src->device_ = nil;
@ -589,8 +592,10 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
return {cpu_buffer_, std::move(lock)};
}
Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
Tensor::CpuWriteView Tensor::GetCpuWriteView(
uint64_t source_location_hash) const {
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
TrackAhwbUsage(source_location_hash);
AllocateCpuBuffer();
valid_ = kValidCpu;
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
@ -620,24 +625,4 @@ void Tensor::AllocateCpuBuffer() const {
}
}
void Tensor::SetPreferredStorageType(StorageType type) {
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
if (__builtin_available(android 26, *)) {
use_ahwb_ = type == StorageType::kAhwb;
VLOG(4) << "Tensor: use of AHardwareBuffer is "
<< (use_ahwb_ ? "allowed" : "not allowed");
}
#else
VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed";
#endif // MEDIAPIPE_TENSOR_USE_AHWB
}
Tensor::StorageType Tensor::GetPreferredStorageType() {
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
return use_ahwb_ ? StorageType::kAhwb : StorageType::kDefault;
#else
return StorageType::kDefault;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
}
} // namespace mediapipe

View File

@ -24,8 +24,9 @@
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/port.h"
#if MEDIAPIPE_METAL_ENABLED
@ -48,6 +49,22 @@
#include "mediapipe/gpu/gl_context.h"
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#if defined __has_builtin
#if __has_builtin(__builtin_LINE)
#define builtin_LINE __builtin_LINE
#endif
#if __has_builtin(__builtin_FILE)
#define builtin_FILE __builtin_FILE
#endif
#endif
#ifndef builtin_LINE
#define builtin_LINE() 0
#endif
#ifndef builtin_FILE
#define builtin_FILE() ""
#endif
namespace mediapipe {
// Tensor is a container of multi-dimensional data that supports sharing the
@ -65,7 +82,7 @@ namespace mediapipe {
// GLuint buffer = view.buffer();
// Then the buffer can be bound to the GPU command buffer.
// ...binding the buffer to the command buffer...
// ...commiting command buffer and releasing the view...
// ...committing command buffer and releasing the view...
//
// The following request for the CPU view will be blocked until the GPU view is
// released and the GPU task is finished.
@ -161,7 +178,9 @@ class Tensor {
using CpuReadView = CpuView<const void>;
CpuReadView GetCpuReadView() const;
using CpuWriteView = CpuView<void>;
CpuWriteView GetCpuWriteView() const;
CpuWriteView GetCpuWriteView(
uint64_t source_location_hash =
tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
#if MEDIAPIPE_METAL_ENABLED
// TODO: id<MTLBuffer> vs. MtlBufferView.
@ -305,7 +324,9 @@ class Tensor {
// A valid OpenGL context must be bound to the calling thread due to possible
// GPU resource allocation.
OpenGlBufferView GetOpenGlBufferReadView() const;
OpenGlBufferView GetOpenGlBufferWriteView() const;
OpenGlBufferView GetOpenGlBufferWriteView(
uint64_t source_location_hash =
tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
const Shape& shape() const { return shape_; }
@ -408,9 +429,13 @@ class Tensor {
mutable std::function<void()> release_callback_;
bool AllocateAHardwareBuffer(int size_alignment = 0) const;
void CreateEglSyncAndFd() const;
// Use Ahwb for other views: OpenGL / CPU buffer.
#endif // MEDIAPIPE_TENSOR_USE_AHWB
static inline bool use_ahwb_ = false;
// Use Ahwb for other views: OpenGL / CPU buffer.
mutable bool use_ahwb_ = false;
mutable uint64_t ahwb_tracking_key_ = 0;
// TODO: Tracks all unique tensors. Can grow to a large number. LRU
// eviction would be more predictable.
static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
// Expects the target SSBO to be already bound.
bool AllocateAhwbMapToSsbo() const;
bool InsertAhwbToSsboFence() const;
@ -419,6 +444,8 @@ class Tensor {
void* MapAhwbToCpuRead() const;
void* MapAhwbToCpuWrite() const;
void MoveCpuOrSsboToAhwb() const;
// Sets the current tracking key and enables AHWB use if the key is already marked.
void TrackAhwbUsage(uint64_t key) const;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
mutable std::shared_ptr<mediapipe::GlContext> gl_context_;
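
The builtin_FILE()/builtin_LINE() defaults work because default arguments are evaluated at the call site: every caller of GetCpuWriteView() or GetOpenGlBufferWriteView() silently passes a hash of its own source location. A toy demonstration of the mechanism, assuming a GCC/Clang-style compiler (the header falls back to ""/0 otherwise, as shown above):

#include <cstdio>

// Default arguments are evaluated where the function is CALLED, so each call
// site stamps its own file/line into the argument.
void WriteView(const char* file = __builtin_FILE(),
               int line = __builtin_LINE()) {
  std::printf("view requested from %s:%d\n", file, line);
}

int main() {
  WriteView();  // reports this line
  WriteView();  // same file, different line -> a different tracking key
  return 0;
}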

View File

@ -212,9 +212,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
CHECK(!(valid_ & kValidOpenGlTexture2d))
<< "Tensor conversion between OpenGL texture and AHardwareBuffer is not "
"supported.";
CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
<< "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
"supported on target system.";
bool transfer = !ahwb_;
CHECK(AllocateAHardwareBuffer())
<< "AHardwareBuffer is not supported on the target system.";
@ -268,6 +265,10 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
}
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
// Mark current tracking key as Ahwb-use.
ahwb_usage_track_.insert(ahwb_tracking_key_);
use_ahwb_ = true;
if (__builtin_available(android 26, *)) {
if (ahwb_ == nullptr) {
AHardwareBuffer_Desc desc = {};
@ -315,7 +316,13 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
}
if (valid_ & kValidOpenGlBuffer) {
if (valid_ & kValidCpu) {
std::memcpy(dest, cpu_buffer_, bytes());
// Free CPU memory because next time AHWB is mapped instead.
free(cpu_buffer_);
cpu_buffer_ = nullptr;
valid_ &= ~kValidCpu;
} else if (valid_ & kValidOpenGlBuffer) {
gl_context_->Run([this, dest]() {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
@ -326,11 +333,9 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
});
opengl_buffer_ = GL_INVALID_INDEX;
gl_context_ = nullptr;
} else if (valid_ & kValidCpu) {
std::memcpy(dest, cpu_buffer_, bytes());
// Free CPU memory because next time AHWB is mapped instead.
free(cpu_buffer_);
cpu_buffer_ = nullptr;
// Reset OpenGL buffer validity. The OpenGL buffer will be allocated on top
// of the Ahwb at the next OpenGlBufferView request.
valid_ &= ~kValidOpenGlBuffer;
} else {
LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB.";
}
@ -446,6 +451,16 @@ void* Tensor::MapAhwbToCpuWrite() const {
return nullptr;
}
void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
if (ahwb_tracking_key_ == 0) {
ahwb_tracking_key_ = source_location_hash;
for (int dim : shape_.dims) {
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
}
}
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
}
#else // MEDIAPIPE_TENSOR_USE_AHWB
bool Tensor::AllocateAhwbMapToSsbo() const { return false; }
@ -454,6 +469,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {}
void Tensor::ReleaseAhwbStuff() {}
void* Tensor::MapAhwbToCpuRead() const { return nullptr; }
void* Tensor::MapAhwbToCpuWrite() const { return nullptr; }
void Tensor::TrackAhwbUsage(uint64_t key) const {}
#endif // MEDIAPIPE_TENSOR_USE_AHWB
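
Putting the pieces together: the first time a tensor created at a given call site (plus shape) is actually used as an AHardwareBuffer, AllocateAHardwareBuffer records its tracking key; later tensors hashing to the same key then start out AHWB-backed, replacing the old manual SetPreferredStorageType switch. A toy model of that two-run behavior (simplified containers and keys, not the real class):

#include <cstdint>
#include <cstdio>
#include <unordered_set>

std::unordered_set<uint64_t> g_ahwb_keys;  // stands in for ahwb_usage_track_

struct ToyTensor {
  uint64_t key;    // stands in for ahwb_tracking_key_ (source location + dims)
  bool use_ahwb;
  explicit ToyTensor(uint64_t k) : key(k), use_ahwb(g_ahwb_keys.count(k) > 0) {}
  void GetAHardwareBufferReadView() {  // first AHWB use marks the key
    g_ahwb_keys.insert(key);
    use_ahwb = true;
  }
};

int main() {
  ToyTensor first(0xF00D);                      // run 1: key unseen, CPU-backed
  std::printf("run 1: %d\n", first.use_ahwb);   // 0
  first.GetAHardwareBufferReadView();           // AHWB requested, key recorded
  ToyTensor second(0xF00D);                     // run 2: same call site/shape
  std::printf("run 2: %d\n", second.use_ahwb);  // 1: AHWB from the start
  return 0;
}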

View File

@ -152,6 +152,36 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
{
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr);
std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference));
}
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
// Request the GPU view to get the ssbo allocated internally.
// Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name();
EXPECT_GT(ssbo_name, 0);
FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
tensor.element_type());
});
{
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr);

View File

@ -1,71 +0,0 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor_buffer.h"
#include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/formats/tensor_v2.h"
namespace mediapipe {
// Supports:
// - float 16 and 32 bits
// - signed / unsigned integers 8,16,32 bits
class TensorHardwareBufferView;
struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor {
using ViewT = TensorHardwareBufferView;
TensorBufferDescriptor buffer;
};
class TensorHardwareBufferView : public Tensor::View {
public:
TENSOR_UNIQUE_VIEW_TYPE_ID();
~TensorHardwareBufferView() = default;
const TensorHardwareBufferViewDescriptor& descriptor() const override {
return descriptor_;
}
AHardwareBuffer* handle() const { return ahwb_handle_; }
protected:
TensorHardwareBufferView(int access_capability, Tensor::View::Access access,
Tensor::View::State state,
const TensorHardwareBufferViewDescriptor& desc,
AHardwareBuffer* ahwb_handle)
: Tensor::View(kId, access_capability, access, state),
descriptor_(desc),
ahwb_handle_(ahwb_handle) {}
private:
bool MatchDescriptor(
uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor) const override {
if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor))
return false;
auto descriptor =
static_cast<const TensorHardwareBufferViewDescriptor&>(base_descriptor);
return descriptor.buffer.format == descriptor_.buffer.format &&
descriptor.buffer.size_alignment <=
descriptor_.buffer.size_alignment &&
descriptor_.buffer.size_alignment %
descriptor.buffer.size_alignment ==
0;
}
const TensorHardwareBufferViewDescriptor& descriptor_;
AHardwareBuffer* ahwb_handle_ = nullptr;
};
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_

View File

@ -1,216 +0,0 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <cstdint>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/framework/formats/tensor_backend.h"
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "util/task/status_macros.h"
namespace mediapipe {
namespace {
class TensorCpuViewImpl : public TensorCpuView {
public:
TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access,
Tensor::View::State state,
const TensorCpuViewDescriptor& descriptor, void* pointer,
AHardwareBuffer* ahwb_handle)
: TensorCpuView(access_capabilities, access, state, descriptor, pointer),
ahwb_handle_(ahwb_handle) {}
~TensorCpuViewImpl() {
// If handle_ is null then this view is constructed in GetViews with no
// access.
if (ahwb_handle_) {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_unlock(ahwb_handle_, nullptr);
}
}
}
private:
AHardwareBuffer* ahwb_handle_;
};
class TensorHardwareBufferViewImpl : public TensorHardwareBufferView {
public:
TensorHardwareBufferViewImpl(
int access_capability, Tensor::View::Access access,
Tensor::View::State state,
const TensorHardwareBufferViewDescriptor& descriptor,
AHardwareBuffer* handle)
: TensorHardwareBufferView(access_capability, access, state, descriptor,
handle) {}
~TensorHardwareBufferViewImpl() = default;
};
class HardwareBufferCpuStorage : public TensorStorage {
public:
~HardwareBufferCpuStorage() {
if (!ahwb_handle_) return;
if (__builtin_available(android 26, *)) {
AHardwareBuffer_release(ahwb_handle_);
}
}
static absl::Status CanProvide(
int access_capability, const Tensor::Shape& shape, uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor) {
// TODO: use AHardwareBuffer_isSupported for API >= 29.
static const bool is_ahwb_supported = [] {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {};
// Aligned to the largest possible virtual memory page size.
constexpr uint32_t kPageSize = 16384;
desc.width = kPageSize;
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
AHardwareBuffer* handle;
if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
AHardwareBuffer_release(handle);
return true;
}
return false;
}();
if (!is_ahwb_supported) {
return absl::UnavailableError(
"AHardwareBuffer is not supported on the platform.");
}
if (view_type_id != TensorCpuView::kId &&
view_type_id != TensorHardwareBufferView::kId) {
return absl::InvalidArgumentError(
"A view type is not supported by this storage.");
}
return absl::OkStatus();
}
std::vector<std::unique_ptr<Tensor::View>> GetViews(uint64_t latest_version) {
std::vector<std::unique_ptr<Tensor::View>> result;
auto update_state = latest_version == version_
? Tensor::View::State::kUpToDate
: Tensor::View::State::kOutdated;
if (ahwb_handle_) {
result.push_back(
std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
hw_descriptor_, ahwb_handle_)));
result.push_back(std::unique_ptr<Tensor::View>(new TensorCpuViewImpl(
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
cpu_descriptor_, nullptr, nullptr)));
}
return result;
}
absl::StatusOr<std::unique_ptr<Tensor::View>> GetView(
Tensor::View::Access access, const Tensor::Shape& shape,
uint64_t latest_version, uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor, int access_capability) {
MP_RETURN_IF_ERROR(
CanProvide(access_capability, shape, view_type_id, base_descriptor));
const auto& buffer_descriptor =
view_type_id == TensorHardwareBufferView::kId
? static_cast<const TensorHardwareBufferViewDescriptor&>(
base_descriptor)
.buffer
: static_cast<const TensorCpuViewDescriptor&>(base_descriptor)
.buffer;
if (!ahwb_handle_) {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {};
desc.width = TensorBufferSize(buffer_descriptor, shape);
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
// TODO: Use access capabilities to set hints.
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_);
if (error != 0) {
return absl::UnknownError(
absl::StrCat("Error allocating hardware buffer: ", error));
}
// Fill all possible views to provide it as proto views.
hw_descriptor_.buffer = buffer_descriptor;
cpu_descriptor_.buffer = buffer_descriptor;
}
}
if (buffer_descriptor.format != hw_descriptor_.buffer.format ||
buffer_descriptor.size_alignment >
hw_descriptor_.buffer.size_alignment ||
hw_descriptor_.buffer.size_alignment %
buffer_descriptor.size_alignment >
0) {
return absl::AlreadyExistsError(
"A view with different params is already allocated with this "
"storage");
}
absl::StatusOr<std::unique_ptr<Tensor::View>> result;
if (view_type_id == TensorHardwareBufferView::kId) {
result = GetAhwbView(access, shape, base_descriptor);
} else {
result = GetCpuView(access, shape, base_descriptor);
}
if (result.ok()) version_ = latest_version;
return result;
}
private:
absl::StatusOr<std::unique_ptr<Tensor::View>> GetAhwbView(
Tensor::View::Access access, const Tensor::Shape& shape,
const Tensor::ViewDescriptor& base_descriptor) {
return std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
kAccessCapability, access, Tensor::View::State::kUpToDate,
hw_descriptor_, ahwb_handle_));
}
absl::StatusOr<std::unique_ptr<Tensor::View>> GetCpuView(
Tensor::View::Access access, const Tensor::Shape& shape,
const Tensor::ViewDescriptor& base_descriptor) {
void* pointer = nullptr;
if (__builtin_available(android 26, *)) {
int error =
AHardwareBuffer_lock(ahwb_handle_,
access == Tensor::View::Access::kWriteOnly
? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN
: AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
-1, nullptr, &pointer);
if (error != 0) {
return absl::UnknownError(
absl::StrCat("Error locking hardware buffer: ", error));
}
}
return std::unique_ptr<Tensor::View>(
new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly
? Tensor::View::AccessCapability::kWrite
: Tensor::View::AccessCapability::kRead,
access, Tensor::View::State::kUpToDate,
cpu_descriptor_, pointer, ahwb_handle_));
}
static constexpr int kAccessCapability =
Tensor::View::AccessCapability::kRead |
Tensor::View::AccessCapability::kWrite;
TensorHardwareBufferViewDescriptor hw_descriptor_;
AHardwareBuffer* ahwb_handle_ = nullptr;
TensorCpuViewDescriptor cpu_descriptor_;
uint64_t version_ = 0;
};
TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage);
} // namespace
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -1,76 +0,0 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
namespace {
class TensorHardwareBufferTest : public ::testing::Test {
public:
TensorHardwareBufferTest() {}
~TensorHardwareBufferTest() override {}
};
TEST_F(TensorHardwareBufferTest, TestFloat32) {
Tensor tensor{Tensor::Shape({1})};
{
MP_ASSERT_OK_AND_ASSIGN(
auto view,
tensor.GetView<Tensor::View::Access::kWriteOnly>(
TensorHardwareBufferViewDescriptor{
.buffer = {.format =
TensorBufferDescriptor::Format::kFloat32}}));
EXPECT_NE(view->handle(), nullptr);
}
{
const auto& const_tensor = tensor;
MP_ASSERT_OK_AND_ASSIGN(
auto view,
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
TensorCpuViewDescriptor{
.buffer = {.format =
TensorBufferDescriptor::Format::kFloat32}}));
EXPECT_NE(view->data<void>(), nullptr);
}
}
TEST_F(TensorHardwareBufferTest, TestInt8Padding) {
Tensor tensor{Tensor::Shape({1})};
{
MP_ASSERT_OK_AND_ASSIGN(
auto view,
tensor.GetView<Tensor::View::Access::kWriteOnly>(
TensorHardwareBufferViewDescriptor{
.buffer = {.format = TensorBufferDescriptor::Format::kInt8,
.size_alignment = 4}}));
EXPECT_NE(view->handle(), nullptr);
}
{
const auto& const_tensor = tensor;
MP_ASSERT_OK_AND_ASSIGN(
auto view,
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
TensorCpuViewDescriptor{
.buffer = {.format = TensorBufferDescriptor::Format::kInt8}}));
EXPECT_NE(view->data<void>(), nullptr);
}
}
} // namespace
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -18,8 +18,6 @@
#include <cstdint>
#include <type_traits>
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe {
// Generates unique view id at compile-time using FILE and LINE.
@ -41,10 +39,12 @@ namespace tensor_internal {
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
constexpr uint64_t kFnvPrime = 0x00000100000001B3;
constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325;
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) {
return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime);
constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) {
return (value2 ^ value1) * kFnvPrime;
}
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) {
return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0]));
}
template <typename... Ts>
struct TypeList {
static constexpr std::size_t size{sizeof...(Ts)};
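The refactored FnvHash64 above splits the per-character step into a reusable two-argument combiner while keeping the same FNV-1a recurrence (XOR, then multiply by the prime). A plain-Python equivalent, useful for checking values offline; the C++ version runs at compile time via constexpr recursion:
```python
FNV_PRIME = 0x00000100000001B3
FNV_OFFSET_BIAS = 0xCBF29CE484222325
MASK64 = (1 << 64) - 1  # C++ uint64_t wraps modulo 2**64

def fnv_hash64(data: bytes, h: int = FNV_OFFSET_BIAS) -> int:
    # XOR-then-multiply per byte: the FNV-1a variant used above.
    for byte in data:
        h = ((h ^ byte) * FNV_PRIME) & MASK64
    return h

# A view id could be derived from FILE and LINE like this (illustrative input):
print(hex(fnv_hash64(b"mediapipe/framework/formats/tensor_v2.h:42")))
```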

View File

@ -88,8 +88,8 @@ cc_library(
srcs = ["default_input_stream_handler.cc"],
hdrs = ["default_input_stream_handler.h"],
deps = [
":default_input_stream_handler_cc_proto",
"//mediapipe/framework:input_stream_handler",
"//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
@ -110,8 +110,8 @@ cc_library(
srcs = ["fixed_size_input_stream_handler.cc"],
deps = [
":default_input_stream_handler",
":fixed_size_input_stream_handler_cc_proto",
"//mediapipe/framework:input_stream_handler",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
],
alwayslink = 1,
)
@ -159,13 +159,13 @@ cc_library(
name = "sync_set_input_stream_handler",
srcs = ["sync_set_input_stream_handler.cc"],
deps = [
":sync_set_input_stream_handler_cc_proto",
"//mediapipe/framework:collection",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_handler",
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:packet_set",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto",
"//mediapipe/framework/tool:tag_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
@ -177,10 +177,10 @@ cc_library(
name = "timestamp_align_input_stream_handler",
srcs = ["timestamp_align_input_stream_handler.cc"],
deps = [
":timestamp_align_input_stream_handler_cc_proto",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_handler",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto",
"//mediapipe/framework/tool:validate_name",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
@ -243,6 +243,7 @@ cc_test(
srcs = ["set_input_stream_handler_test.cc"],
deps = [
":fixed_size_input_stream_handler",
":fixed_size_input_stream_handler_cc_proto",
":mux_input_stream_handler",
"//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
@ -251,7 +252,6 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
],
)
@ -272,13 +272,13 @@ cc_test(
srcs = ["fixed_size_input_stream_handler_test.cc"],
deps = [
":fixed_size_input_stream_handler",
":fixed_size_input_stream_handler_cc_proto",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
@ -289,11 +289,11 @@ cc_test(
srcs = ["sync_set_input_stream_handler_test.cc"],
deps = [
":sync_set_input_stream_handler",
":sync_set_input_stream_handler_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:test_calculators",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",

View File

@ -299,6 +299,7 @@ mediapipe_cc_test(
requires_full_emulation = False,
deps = [
":node_chain_subgraph_cc_proto",
":node_chain_subgraph_options_lib",
":options_field_util",
":options_registry",
":options_syntax_util",
@ -313,7 +314,6 @@ mediapipe_cc_test(
"//mediapipe/framework/port:status",
"//mediapipe/framework/testdata:night_light_calculator_cc_proto",
"//mediapipe/framework/testdata:night_light_calculator_options_lib",
"//mediapipe/framework/tool:node_chain_subgraph_options_lib",
"//mediapipe/util:header_util",
"@com_google_absl//absl/strings",
],
@ -422,9 +422,9 @@ cc_library(
srcs = ["source.cc"],
visibility = ["//visibility:public"],
deps = [
":source_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:source_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
@ -485,13 +485,13 @@ cc_library(
hdrs = ["template_expander.h"],
visibility = ["//visibility:public"],
deps = [
":calculator_graph_template_cc_proto",
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:numbers",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:calculator_graph_template_cc_proto",
"@com_google_absl//absl/strings",
],
)
@ -506,6 +506,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
":calculator_graph_template_cc_proto",
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:proto_descriptor_cc_proto",
@ -515,7 +516,6 @@ cc_library(
"//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:calculator_graph_template_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -661,8 +661,8 @@ cc_library(
hdrs = ["simulation_clock_executor.h"],
visibility = ["//visibility:public"],
deps = [
":simulation_clock",
"//mediapipe/framework:thread_pool_executor",
"//mediapipe/framework/tool:simulation_clock",
],
)
@ -789,10 +789,10 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":name_util",
":switch_container_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container_cc_proto",
],
)
@ -805,6 +805,7 @@ cc_library(
deps = [
":container_util",
":options_util",
":switch_container_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/deps:mathutil",
@ -814,7 +815,6 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:switch_container_cc_proto",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
@ -841,6 +841,7 @@ cc_library(
],
deps = [
":container_util",
":switch_container_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_shard",
@ -850,7 +851,6 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:switch_container_cc_proto",
],
alwayslink = 1,
)
@ -893,6 +893,7 @@ cc_library(
":container_util",
":name_util",
":subgraph_expansion",
":switch_container_cc_proto",
":switch_demux_calculator",
":switch_mux_calculator",
"//mediapipe/calculators/core:packet_sequencer_calculator",
@ -904,7 +905,6 @@ cc_library(
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container_cc_proto",
"@com_google_absl//absl/strings",
],
alwayslink = 1,

View File

@ -564,6 +564,7 @@ cc_library(
name = "gpu_shared_data_internal_stub",
visibility = ["//visibility:private"],
deps = [
":gl_context_options_cc_proto",
":graph_support",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node",
@ -571,7 +572,6 @@ cc_library(
"//mediapipe/framework:port",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/gpu:gl_context_options_cc_proto",
],
)
@ -592,7 +592,7 @@ cc_library(
}),
visibility = ["//visibility:private"],
deps = [
"//mediapipe/gpu:gl_context_options_cc_proto",
":gl_context_options_cc_proto",
":graph_support",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:executor",
@ -833,10 +833,10 @@ cc_library(
deps = [
":gl_base",
":gl_simple_shaders",
":scale_mode_cc_proto",
":shader_util",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:scale_mode_cc_proto",
],
)
@ -907,8 +907,8 @@ proto_library(
srcs = ["gl_scaler_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
":scale_mode_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
],
)
@ -930,6 +930,7 @@ cc_library(
deps = [
":gl_calculator_helper",
":gl_quad_renderer",
":gl_scaler_calculator_cc_proto",
":gl_simple_shaders",
":shader_util",
"//mediapipe/framework:calculator_framework",
@ -937,7 +938,6 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util",
"//mediapipe/gpu:gl_scaler_calculator_cc_proto",
],
alwayslink = 1,
)
@ -950,13 +950,13 @@ cc_library(
":egl_surface_holder",
":gl_calculator_helper",
":gl_quad_renderer",
":gl_surface_sink_calculator_cc_proto",
":gpu_buffer",
":shader_util",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
"@com_google_absl//absl/synchronization",
],
alwayslink = 1,
@ -966,8 +966,8 @@ proto_library(
name = "gl_surface_sink_calculator_proto",
srcs = ["gl_surface_sink_calculator.proto"],
deps = [
":scale_mode_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
],
)

View File

@ -15,10 +15,13 @@
package com.google.mediapipe.framework;
import android.graphics.Bitmap;
import android.graphics.PixelFormat;
import android.media.Image;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import com.google.mediapipe.framework.image.MediaImageExtractor;
import java.nio.ByteBuffer;
// TODO: use Preconditions in this file.
@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator {
}
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
}
if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) {
Image mediaImage = MediaImageExtractor.extract(image);
if (mediaImage.getFormat() != PixelFormat.RGBA_8888) {
throw new UnsupportedOperationException("Android media image must use RGBA_8888 config.");
}
return createImage(
mediaImage.getPlanes()[0].getBuffer(),
mediaImage.getWidth(),
mediaImage.getHeight(),
/* numChannels= */ 4);
}
// Unsupported type.
throw new UnsupportedOperationException(
"Unsupported Image container type: " + properties.getStorageType());

View File

@ -14,6 +14,10 @@
package com.google.mediapipe.framework;
import com.google.common.flogger.FluentLogger;
import java.util.HashSet;
import java.util.Set;
/**
* A {@link TextureFrame} that represents a texture produced by MediaPipe.
*
@ -21,6 +25,7 @@ package com.google.mediapipe.framework;
* method.
*/
public class GraphTextureFrame implements TextureFrame {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private long nativeBufferHandle;
// We cache these to be able to get them without a JNI call.
private int textureName;
@ -30,6 +35,8 @@ public class GraphTextureFrame implements TextureFrame {
// True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait
// when calling getTextureName().
private final boolean deferredSync;
private final Set<Long> activeConsumerContextHandleSet = new HashSet<>();
private int refCount = 1;
GraphTextureFrame(long nativeHandle, long timestamp) {
this(nativeHandle, timestamp, false);
@ -54,17 +61,19 @@ public class GraphTextureFrame implements TextureFrame {
* condition if release() is called after the if-check for nativeBufferHandle is already passed.
*/
@Override
public int getTextureName() {
public synchronized int getTextureName() {
// Return special texture id 0 if handle is 0 i.e. frame is already released.
if (nativeBufferHandle == 0) {
return 0;
}
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
// PacketGetter.getTextureFrameDeferredSync().
if (deferredSync) {
// Note that, if a CPU wait has already been done, the sync point will have been
// cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
nativeGpuWait(nativeBufferHandle);
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
// PacketGetter.getTextureFrameDeferredSync().
if (deferredSync) {
// Note that, if a CPU wait has already been done, the sync point will have been
// cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
nativeGpuWait(nativeBufferHandle);
}
}
return textureName;
}
@ -86,15 +95,31 @@ public class GraphTextureFrame implements TextureFrame {
return timestamp;
}
@Override
public boolean supportsRetain() {
return true;
}
@Override
public synchronized void retain() {
// TODO: check that refCount is > 0 and handle is not 0.
refCount++;
}
/**
* Releases a reference to the underlying buffer.
*
* <p>The consumer calls this when it is done using the texture.
*/
@Override
public void release() {
GlSyncToken consumerToken =
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
public synchronized void release() {
GlSyncToken consumerToken = null;
// Note that this remove should be moved to the other overload of release when b/68808951 is
// addressed.
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) {
consumerToken =
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
}
release(consumerToken);
}
@ -108,18 +133,40 @@ public class GraphTextureFrame implements TextureFrame {
* currently cannot create a GlSyncToken, so they cannot call this method.
*/
@Override
public void release(GlSyncToken consumerSyncToken) {
if (nativeBufferHandle != 0) {
long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken();
nativeReleaseBuffer(nativeBufferHandle, token);
nativeBufferHandle = 0;
public synchronized void release(GlSyncToken consumerSyncToken) {
if (nativeBufferHandle == 0) {
if (consumerSyncToken != null) {
logger.atWarning().log("release with sync token, but handle is 0");
}
return;
}
if (consumerSyncToken != null) {
long token = consumerSyncToken.nativeToken();
nativeDidRead(nativeBufferHandle, token);
// We should remove the token's context from activeConsumerContextHandleSet here, but for now
// we do it in the release(void) overload.
consumerSyncToken.release();
}
refCount--;
if (refCount <= 0) {
nativeReleaseBuffer(nativeBufferHandle);
nativeBufferHandle = 0;
}
}
private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken);
@Override
protected void finalize() throws Throwable {
if (refCount >= 0 || nativeBufferHandle != 0) {
logger.atWarning().log("release was not called before finalize");
}
if (!activeConsumerContextHandleSet.isEmpty()) {
logger.atWarning().log("active consumers did not release with sync before finalize");
}
}
private native void nativeReleaseBuffer(long nativeHandle);
private native int nativeGetTextureName(long nativeHandle);
private native int nativeGetWidth(long nativeHandle);
@ -128,4 +175,8 @@ public class GraphTextureFrame implements TextureFrame {
private native void nativeGpuWait(long nativeHandle);
private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle);
private native long nativeGetCurrentExternalContextHandle();
private native void nativeDidRead(long nativeHandle, long consumerSyncToken);
}

View File

@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback {
*/
@Override
void release(GlSyncToken syncToken);
/**
* If this method returns true, this object supports the retain method, and can be used with
* multiple consumers. Call retain for each additional consumer beyond the first; each consumer
* should call release.
*/
default boolean supportsRetain() {
return false;
}
/** Increments the reference count. Only available with some implementations of TextureFrame. */
default void retain() {
throw new UnsupportedOperationException();
}
}
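Since the retain/release contract spans several methods, here is a minimal Python model of it (illustrative only, not MediaPipe API): creation hands the caller one reference, each additional consumer calls retain, and the underlying buffer is freed when the count reaches zero.
```python
class RefCountedFrame:
    """Illustrative model of the TextureFrame retain/release contract."""

    def __init__(self):
        self._ref_count = 1      # creation hands the caller one reference
        self.released = False

    def supports_retain(self) -> bool:
        return True

    def retain(self):
        assert self._ref_count > 0, "retain() after the final release()"
        self._ref_count += 1

    def release(self):
        self._ref_count -= 1
        if self._ref_count <= 0:
            self.released = True  # the GPU buffer would be freed here

frame = RefCountedFrame()
frame.retain()       # second consumer registers itself
frame.release()      # first consumer done; buffer still alive
frame.release()      # second consumer done; buffer freed
assert frame.released
```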

View File

@ -15,20 +15,16 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gl_texture_buffer.h"
#include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h"
using mediapipe::GlTextureBufferSharedPtr;
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) {
JNIEnv* env, jobject thiz, jlong nativeHandle) {
GlTextureBufferSharedPtr* buffer =
reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle);
if (consumerSyncToken) {
mediapipe::GlSyncToken& token =
*reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken);
(*buffer)->DidRead(token);
}
delete buffer;
}
@ -84,3 +80,18 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
}
return reinterpret_cast<jlong>(token);
}
JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) {
return reinterpret_cast<jlong>(
mediapipe::GlContext::GetCurrentNativeContext());
}
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) {
GlTextureBufferSharedPtr* buffer =
reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle);
mediapipe::GlSyncToken& token =
*reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken);
(*buffer)->DidRead(token);
}

View File

@ -26,7 +26,7 @@ extern "C" {
// Releases a native mediapipe::GpuBuffer.
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken);
JNIEnv* env, jobject thiz, jlong nativeHandle);
JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)(
JNIEnv* env, jobject thiz, jlong nativeHandle);
@ -44,6 +44,12 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz,
jlong nativeHandle);
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken);
JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.text import text_classifier
# Remove duplicated and non-public API
del python

View File

@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions
# Remove duplicated and non-public API
del hyperparameters
del dataset
del model_options
del model_spec
del preprocessor # pylint: disable=undefined-variable
del text_classifier
del text_classifier_options
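The trailing `del` statements drop the submodule names that the imports left behind in the package namespace, so only the re-exported public aliases survive. A self-contained illustration of the pattern, using a synthetic module in place of a real package:
```python
import types

# Simulate what `from pkg import dataset` leaves behind in pkg's namespace.
pkg = types.ModuleType("pkg")
pkg.dataset = types.ModuleType("pkg.dataset")  # submodule leaks into the API
pkg.Dataset = type("Dataset", (), {})          # the intended public alias

del pkg.dataset                                # hide the non-public name
assert hasattr(pkg, "Dataset") and not hasattr(pkg, "dataset")
```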

View File

@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
from official.nlp import optimization
def _validate(options: text_classifier_options.TextClassifierOptions):
@ -417,8 +416,22 @@ class _BertClassifier(TextClassifier):
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1)
initial_lr = self._hparams.learning_rate
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
warmup_steps)
# Implements linear decay of the learning rate.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr,
decay_steps=total_steps,
end_learning_rate=0.0,
power=1.0)
if warmup_steps:
lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"])
def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy(
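For reference, the schedule wiring above can be exercised standalone. This sketch uses made-up step counts and learning rate and omits the MediaPipe-internal model_util.WarmUp wrapper; power=1.0 is what makes PolynomialDecay linear:
```python
import tensorflow as tf

total_steps, initial_lr = 1000, 3e-5  # example values only
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=total_steps,
    end_learning_rate=0.0,
    power=1.0)  # power=1.0 turns polynomial decay into linear decay
optimizer = tf.keras.optimizers.experimental.AdamW(
    lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
optimizer.exclude_from_weight_decay(
    var_names=["LayerNorm", "layer_norm", "bias"])
print(float(lr_schedule(total_steps // 2)))  # halfway through: ~initial_lr / 2
```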

View File

@ -146,6 +146,8 @@ py_test(
tags = ["notsan"],
deps = [
":gesture_recognizer_import",
":hyperparameters",
":model_options",
"//mediapipe/model_maker/python/core/utils:test_util",
"//mediapipe/tasks/python/test:test_utils",
],

View File

@ -25,3 +25,12 @@ HParams = hyperparameters.HParams
Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
# Remove duplicated and non-public API
del constants # pylint: disable=undefined-variable
del dataset
del gesture_recognizer
del gesture_recognizer_options
del hyperparameters
del metadata_writer # pylint: disable=undefined-variable
del model_options

View File

@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier):
batch_size=None,
dtype=tf.float32,
name='hand_embedding')
x = tf.keras.layers.BatchNormalization()(inputs)
x = tf.keras.layers.ReLU()(x)
x = inputs
dropout_rate = self._model_options.dropout_rate
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
for i, width in enumerate(self._model_options.layer_widths):
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
outputs = tf.keras.layers.Dense(
self._num_classes,
activation='softmax',
name='custom_gesture_recognizer')(
name='custom_gesture_recognizer_out')(
x)
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
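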

View File

@ -23,6 +23,8 @@ import tensorflow as tf
from mediapipe.model_maker.python.core.utils import test_util
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata'
@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase):
self._train_data, self._validation_data = all_data.split(0.9)
def test_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions()
mo = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams)
model_options=mo, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._validation_data,
@ -60,12 +62,38 @@ class GestureRecognizerTest(tf.test.TestCase):
self._test_accuracy(model)
def test_export_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions()
@unittest_mock.patch.object(
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
layer_widths = [64, 32]
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams)
model_options=mo, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._validation_data,
options=gesture_recognizer_options)
expected_calls = [
unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}')
for i, w in enumerate(layer_widths)
]
expected_calls.append(
unittest_mock.call(
len(self._train_data.label_names),
activation='softmax',
name='custom_gesture_recognizer_out'))
self.assertLen(mock_dense.call_args_list, len(expected_calls))
mock_dense.assert_has_calls(expected_calls)
self._test_accuracy(model)
def test_export_gesture_recognizer_model(self):
mo = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=mo, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._validation_data,
@ -102,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase):
self.assertGreater(accuracy, threshold)
@unittest_mock.patch.object(
gesture_recognizer.hyperparameters,
hyperparameters,
'HParams',
autospec=True,
return_value=gesture_recognizer.HParams(epochs=1))
@unittest_mock.patch.object(
gesture_recognizer.model_options,
model_options,
'GestureRecognizerModelOptions',
autospec=True,
return_value=gesture_recognizer.ModelOptions())
@ -122,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase):
mock_model_options.assert_called_once()
def test_continual_training_by_loading_checkpoint(self):
model_options = gesture_recognizer.ModelOptions()
mo = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams)
model_options=mo, hparams=hparams)
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
model = gesture_recognizer.GestureRecognizer.create(

View File

@ -14,6 +14,7 @@
"""Configurable model options for gesture recognizer models."""
import dataclasses
from typing import List
@dataclasses.dataclass
@ -23,5 +24,10 @@ class GestureRecognizerModelOptions:
Attributes:
dropout_rate: The fraction of the input units to drop, used in the dropout
layer.
layer_widths: A list of hidden layer widths for the gesture model. Each
element in the list will create a new hidden layer with the specified
width. The hidden layers are separated with BatchNorm, Dropout, and ReLU.
Defaults to an empty list (no hidden layers).
"""
dropout_rate: float = 0.05
layer_widths: List[int] = dataclasses.field(default_factory=list)
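A hedged sketch of the architecture this option drives, mirroring the layer pattern shown in gesture_recognizer.py above; the embedding width (128) and the layer_widths/num_classes values are illustrative assumptions:
```python
import tensorflow as tf

layer_widths, dropout_rate, num_classes = [64, 32], 0.05, 4  # example values
inputs = tf.keras.Input(shape=(128,), name='hand_embedding')
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(inputs)
# Each configured width adds BatchNorm -> ReLU -> Dropout -> Dense(width).
for i, width in enumerate(layer_widths):
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
    x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
outputs = tf.keras.layers.Dense(
    num_classes, activation='softmax',
    name='custom_gesture_recognizer_out')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
```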

View File

@ -121,7 +121,9 @@ py_library(
srcs = ["image_classifier_test.py"],
data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
deps = [
":hyperparameters",
":image_classifier_import",
":model_options",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
# Remove duplicated and non-public API
del dataset
del hyperparameters
del image_classifier
del image_classifier_options
del model_options
del model_spec
del train_image_classifier_lib # pylint: disable=undefined-variable

View File

@ -24,6 +24,8 @@ import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import model_options
from mediapipe.tasks.python.test import test_utils
@ -159,15 +161,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreaterEqual(accuracy, threshold)
@unittest_mock.patch.object(
image_classifier.hyperparameters,
hyperparameters,
'HParams',
autospec=True,
return_value=image_classifier.HParams(epochs=1))
return_value=hyperparameters.HParams(epochs=1))
@unittest_mock.patch.object(
image_classifier.model_options,
model_options,
'ImageClassifierModelOptions',
autospec=True,
return_value=image_classifier.ModelOptions())
return_value=model_options.ImageClassifierModelOptions())
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
self, mock_hparams, mock_model_options):
options = image_classifier.ImageClassifierOptions(

View File

@ -1,5 +1,5 @@
absl-py
mediapipe==0.9.1
mediapipe==0.9.0.1
numpy
opencv-python
tensorflow>=2.10

View File

@ -28,6 +28,8 @@ import PIL.Image
from mediapipe.python._framework_bindings import image
from mediapipe.python._framework_bindings import image_frame
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
Image = image.Image
ImageFormat = image_frame.ImageFormat
@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase):
gc.collect()
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
def test_image_create_from_cvmat(self):
image_path = os.path.join(os.path.dirname(__file__),
'solutions/testdata/hands.jpg')
mat = cv2.imread(image_path).astype(np.uint8)
mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB)
rgb_image = Image(image_format=ImageFormat.SRGB, data=mat)
self.assertEqual(rgb_image.width, 720)
self.assertEqual(rgb_image.height, 382)
self.assertEqual(rgb_image.channels, 3)
self.assertEqual(rgb_image.image_format, ImageFormat.SRGB)
self.assertTrue(np.array_equal(mat, rgb_image.numpy_view()))
def test_image_create_from_file(self):
image_path = os.path.join(os.path.dirname(__file__),
'solutions/testdata/hands.jpg')
loaded_image = Image.create_from_file(image_path)
self.assertEqual(loaded_image.width, 720)
self.assertEqual(loaded_image.height, 382)
self.assertEqual(loaded_image.channels, 3)
self.assertEqual(loaded_image.image_format, ImageFormat.SRGB)
if __name__ == '__main__':
absltest.main()

View File

@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase):
p.timestamp = 0
self.assertAlmostEqual(packet_getter.get_float(p), 0.42)
self.assertEqual(p.timestamp, 0)
p2 = packet_creator.create_float(np.float(0.42))
p2 = packet_creator.create_float(float(0.42))
p2.timestamp = 0
self.assertAlmostEqual(packet_getter.get_float(p2), 0.42)
self.assertEqual(p2.timestamp, 0)

View File

@ -48,16 +48,20 @@ void ImageSubmodule(pybind11::module* module) {
become immutable after creation.
Creation examples:
import cv2
cv_mat = cv2.imread(input_file)[:, :, ::-1]
rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat)
gray_frame = mp.Image(
format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
from PIL import Image
pil_img = Image.new('RGB', (60, 30), color = 'red')
image = mp.Image(
format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
```python
import cv2
cv_mat = cv2.imread(input_file)[:, :, ::-1]
rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat)
gray_frame = mp.Image(
image_format=ImageFormat.GRAY,
data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
from PIL import Image
pil_img = Image.new('RGB', (60, 30), color = 'red')
image = mp.Image(
image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
```
The pixel data in an Image can be retrieved as a numpy ndarray by calling
`Image.numpy_view()`. The returned numpy ndarray is a reference to the
@ -65,15 +69,18 @@ void ImageSubmodule(pybind11::module* module) {
numpy ndarray, it's required to obtain a copy of it.
Pixel data retrieval examples:
for channel in range(num_channel):
for col in range(width):
for row in range(height):
print(image[row, col, channel])
output_ndarray = image.numpy_view()
print(output_ndarray[0, 0, 0])
copied_ndarray = np.copy(output_ndarray)
copied_ndarray[0,0,0] = 0
```python
for channel in range(num_channel):
for col in range(width):
for row in range(height):
print(image[row, col, channel])
output_ndarray = image.numpy_view()
print(output_ndarray[0, 0, 0])
copied_ndarray = np.copy(output_ndarray)
copied_ndarray[0,0,0] = 0
```
)doc",
py::dynamic_attr());
@ -156,9 +163,11 @@ void ImageSubmodule(pybind11::module* module) {
An unwritable numpy ndarray.
Examples:
```
output_ndarray = image.numpy_view()
copied_ndarray = np.copy(output_ndarray)
copied_ndarray[0,0,0] = 0
```
)doc");
image.def(
@ -191,10 +200,12 @@ void ImageSubmodule(pybind11::module* module) {
IndexError: If the index is invalid or out of bounds.
Examples:
```
for channel in range(num_channel):
for col in range(width):
for row in range(height):
print(image[row, col, channel])
```
)doc");
image
@ -224,7 +235,9 @@ void ImageSubmodule(pybind11::module* module) {
A boolean.
Examples:
```
image.is_aligned(16)
```
)doc");
image.def_static(

View File

@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) {
Creation examples:
import cv2
cv_mat = cv2.imread(input_file)[:, :, ::-1]
rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat)
rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat)
gray_frame = mp.ImageFrame(
format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
image_format=ImageFormat.GRAY,
data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
from PIL import Image
pil_img = Image.new('RGB', (60, 30), color = 'red')
image_frame = mp.ImageFrame(
format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling
`ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the
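A short usage sketch following the docstring above (the array contents are arbitrary example data):
```python
import mediapipe as mp
import numpy as np

rgb_data = np.zeros((30, 60, 3), dtype=np.uint8)
frame = mp.ImageFrame(image_format=mp.ImageFormat.SRGB, data=rgb_data)
view = frame.numpy_view()   # zero-copy, read-only reference
writable = np.copy(view)    # copy first if the pixels need to be modified
writable[0, 0, 0] = 255
```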

View File

@ -23,4 +23,3 @@ objc_library(
],
module_name = "MPPCommon",
)

View File

@ -1,25 +1,25 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/**
* @enum TFLSupportErrorCode
* This enum specifies error codes for TensorFlow Lite Task Library.
* It maintains a 1:1 mapping to TfLiteSupportErrorCode of C libray.
* @enum MPPTasksErrorCode
* This enum specifies error codes for the MediaPipe Task Library.
* It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library.
*/
typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
@ -48,16 +48,16 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
MPPTasksErrorCodeFileReadError,
// I/O error when mmap-ing file.
MPPTasksErrorCodeFileMmapError,
// ZIP I/O error when unpacMPPTasksErrorCodeing the zip file.
// ZIP I/O error when unpacking the zip file.
MPPTasksErrorCodeFileZipError,
// TensorFlow Lite metadata error codes.
// Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer.
// Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer.
MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200,
// No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed.
// No such associated file within metadata, or file has not been packed.
MPPTasksErrorCodeMetadataAssociatedFileNotFoundError,
// ZIP I/O error when unpacMPPTasksErrorCodeing an associated file.
// ZIP I/O error when unpacking an associated file.
MPPTasksErrorCodeMetadataAssociatedFileZipError,
// Inconsistency error between the metadata and actual TF Lite model.
// E.g.: number of labels and output tensor values differ.
@ -167,11 +167,10 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
// Task graph config is invalid.
MPPTasksErrorCodeInvalidTaskGraphConfigError,
// The first error code in MPPTasksErrorCode (for internal use only).
MPPTasksErrorCodeFirst = MPPTasksErrorCodeError,
/**
* The last error code in TFLSupportErrorCode (for internal use only).
*/
// The last error code in MPPTasksErrorCode (for internal use only).
MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError,
} NS_SWIFT_NAME(TasksErrorCode);

View File

@ -24,7 +24,7 @@ extern NSString *const MPPTasksErrorDomain;
@interface MPPCommonUtils : NSObject
/**
* Creates and saves an NSError in the Mediapipe task library domain, with the given code and
* Creates and saves an NSError in the MediaPipe task library domain, with the given code and
* description.
*
* @param code Error code.
@ -51,9 +51,9 @@ extern NSString *const MPPTasksErrorDomain;
description:(NSString *)description;
/**
* Converts an absl status to an NSError.
* Converts an absl::Status to an NSError.
*
* @param status absl status.
* @param status absl::Status.
* @param error Pointer to the memory location where the created error should be saved. If `nil`,
* no error will be saved.
*/
@ -61,15 +61,15 @@ extern NSString *const MPPTasksErrorDomain;
/**
* Allocates a block of memory with the specified size and returns a pointer to it. If memory
* cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
* cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it
* terminates program execution.
*
* @param memSize size of memory to be allocated
* @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
* error will be saved.
*
* @return Pointer to the allocated block of memory on successfull allocation. nil in case as
* error is encountered because of invalid memSize. If failure is due to any other reason, method
* @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
* error is encountered because of an invalid `memSize`. If failure is due to any other reason, method
* terminates program execution.
*/
+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;

View File

@ -24,7 +24,7 @@
#include "mediapipe/tasks/cc/common.h"
/** Error domain of MediaPipe task library errors. */
NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
@implementation MPPCommonUtils
@ -68,7 +68,7 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
if (status.ok()) {
return YES;
}
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of
// Payload of absl::Status created by the MediaPipe task library stores an appropriate value of
// the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum
// stored in the payload is extracted here to later map to the appropriate error code to be
// returned. In cases where the enum is not stored in (payload is NULL or the payload string

View File

@ -17,25 +17,38 @@
NS_ASSUME_NONNULL_BEGIN
/**
* Holds settings for any single iOS Mediapipe classification task.
* Holds settings for any single iOS MediaPipe classification task.
*/
NS_SWIFT_NAME(ClassifierOptions)
@interface MPPClassifierOptions : NSObject <NSCopying>
/** If set, all classes in this list will be filtered out from the results . */
@property(nonatomic, copy) NSArray<NSString *> *labelDenyList;
/** If set, all classes not in this list will be filtered out from the results . */
@property(nonatomic, copy) NSArray<NSString *> *labelAllowList;
/** Display names local for display names*/
/** The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English.
*/
@property(nonatomic, copy) NSString *displayNamesLocale;
/** Results with score threshold greater than this value are returned . */
/** The maximum number of top-scored classification results to return. If < 0,
* all available results will be returned. If 0, an invalid argument error is
* returned.
*/
@property(nonatomic) NSInteger maxResults;
/** Score threshold to override the one provided in the model metadata (if any).
* Results below this value are rejected.
*/
@property(nonatomic) float scoreThreshold;
/** Limit to the number of classes that can be returned in results. */
@property(nonatomic) NSInteger maxResults;
/** The allowlist of category names. If non-empty, classification results whose
* category name is not in this set will be filtered out. Duplicate or unknown
* category names are ignored. Mutually exclusive with categoryDenylist.
*/
@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
/** The denylist of category names. If non-empty, classification results whose
* category name is in this set will be filtered out. Duplicate or unknown
* category names are ignored. Mutually exclusive with categoryAllowlist.
*/
@property(nonatomic, copy) NSArray<NSString *> *categoryDenylist;
@end
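A plain-Python mirror of these options and their documented constraints (an illustrative dataclass, not the MediaPipe tasks API): a max_results of 0 is invalid, and the allowlist and denylist are mutually exclusive.
```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ClassifierOptions:
    display_names_locale: str = "en"
    max_results: int = -1            # < 0 means return all available results
    score_threshold: float = 0.0
    category_allowlist: List[str] = field(default_factory=list)
    category_denylist: List[str] = field(default_factory=list)

    def validate(self) -> None:
        if self.max_results == 0:
            raise ValueError("max_results must not be 0")
        if self.category_allowlist and self.category_denylist:
            raise ValueError(
                "category_allowlist and category_denylist are mutually exclusive")
```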

View File

@ -30,8 +30,8 @@
classifierOptions.scoreThreshold = self.scoreThreshold;
classifierOptions.maxResults = self.maxResults;
classifierOptions.labelDenyList = self.labelDenyList;
classifierOptions.labelAllowList = self.labelAllowList;
classifierOptions.categoryDenylist = self.categoryDenylist;
classifierOptions.categoryAllowlist = self.categoryAllowlist;
classifierOptions.displayNamesLocale = self.displayNamesLocale;
return classifierOptions;

View File

@ -20,17 +20,23 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
}
@implementation MPPClassifierOptions (Helpers)
- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
classifierOptionsProto->Clear();
if (self.displayNamesLocale) {
classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
}
classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
for (NSString *category in self.labelAllowList) {
for (NSString *category in self.categoryAllowlist) {
classifierOptionsProto->add_category_allowlist(category.cppString);
}
for (NSString *category in self.labelDenyList) {
for (NSString *category in self.categoryDenylist) {
classifierOptionsProto->add_category_denylist(category.cppString);
}
}

View File

@ -16,19 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPBaseOptions",
srcs = ["sources/MPPBaseOptions.m"],
hdrs = ["sources/MPPBaseOptions.h"],
)
objc_library(
name = "MPPTaskOptions",
srcs = ["sources/MPPTaskOptions.m"],
hdrs = ["sources/MPPTaskOptions.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
":MPPBaseOptions",
],
)
objc_library(
name = "MPPTaskResult",
srcs = ["sources/MPPTaskResult.m"],
hdrs = ["sources/MPPTaskResult.h"],
)
objc_library(
name = "MPPTaskOptionsProtocol",
hdrs = ["sources/MPPTaskOptionsProtocol.h"],
deps = [
"//mediapipe/framework:calculator_options_cc_proto",
],
)
objc_library(
name = "MPPTaskInfo",
srcs = ["sources/MPPTaskInfo.mm"],
@ -64,32 +80,12 @@ objc_library(
)
objc_library(
name = "MPPTaskResult",
srcs = ["sources/MPPTaskResult.m"],
hdrs = ["sources/MPPTaskResult.h"],
)
objc_library(
name = "MPPBaseOptions",
srcs = ["sources/MPPBaseOptions.m"],
hdrs = ["sources/MPPBaseOptions.h"],
)
objc_library(
name = "MPPTaskOptionsProtocol",
hdrs = ["sources/MPPTaskOptionsProtocol.h"],
deps = [
"//mediapipe/framework:calculator_options_cc_proto",
],
)
objc_library(
name = "MPPTaskManager",
srcs = ["sources/MPPTaskManager.mm"],
hdrs = ["sources/MPPTaskManager.h"],
name = "MPPTaskRunner",
srcs = ["sources/MPPTaskRunner.mm"],
hdrs = ["sources/MPPTaskRunner.h"],
deps = [
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
],
)

View File

@ -17,7 +17,6 @@
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
NS_ASSUME_NONNULL_BEGIN
/**
@ -55,7 +54,7 @@ NS_ASSUME_NONNULL_BEGIN
outputStreams:(NSArray<NSString *> *)outputStreams
taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions
enableFlowLimiting:(BOOL)enableFlowLimiting
error:(NSError **)error;
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/**
* Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.

View File

@ -24,9 +24,9 @@
namespace {
using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
using Node = ::mediapipe::CalculatorGraphConfig::Node;
using ::mediapipe::InputStreamInfo;
using ::mediapipe::CalculatorOptions;
using ::mediapipe::FlowLimiterCalculatorOptions;
using ::mediapipe::InputStreamInfo;
} // namespace
@implementation MPPTaskInfo
@ -82,45 +82,46 @@ using ::mediapipe::FlowLimiterCalculatorOptions;
graph_config.add_output_stream(cpp_output_stream);
}
if (self.enableFlowLimiting) {
Node *flow_limit_calculator_node = graph_config.add_node();
flow_limit_calculator_node->set_calculator("FlowLimiterCalculator");
InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info();
input_stream_info->set_tag_index("FINISHED");
input_stream_info->set_back_edge(true);
FlowLimiterCalculatorOptions *flow_limit_calculator_options =
flow_limit_calculator_node->mutable_options()->MutableExtension(
FlowLimiterCalculatorOptions::ext);
flow_limit_calculator_options->set_max_in_flight(1);
flow_limit_calculator_options->set_max_in_queue(1);
for (NSString *inputStream in self.inputStreams) {
graph_config.add_input_stream(inputStream.cppString);
NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream];
flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString);
NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream];
task_subgraph_node->add_input_stream(taskInputStream.cppString);
NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream];
flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString);
}
NSString *firstOutputStream = self.outputStreams[0];
auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString;
flow_limit_calculator_node->add_input_stream(finished_output_stream);
} else {
if (!self.enableFlowLimiting) {
for (NSString *inputStream in self.inputStreams) {
auto cpp_input_stream = inputStream.cppString;
task_subgraph_node->add_input_stream(cpp_input_stream);
graph_config.add_input_stream(cpp_input_stream);
}
return graph_config;
}
Node *flow_limit_calculator_node = graph_config.add_node();
flow_limit_calculator_node->set_calculator("FlowLimiterCalculator");
InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info();
input_stream_info->set_tag_index("FINISHED");
input_stream_info->set_back_edge(true);
FlowLimiterCalculatorOptions *flow_limit_calculator_options =
flow_limit_calculator_node->mutable_options()->MutableExtension(
FlowLimiterCalculatorOptions::ext);
flow_limit_calculator_options->set_max_in_flight(1);
flow_limit_calculator_options->set_max_in_queue(1);
for (NSString *inputStream in self.inputStreams) {
graph_config.add_input_stream(inputStream.cppString);
NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream];
flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString);
NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream];
task_subgraph_node->add_input_stream(taskInputStream.cppString);
NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream];
flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString);
}
NSString *firstOutputStream = self.outputStreams[0];
auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString;
flow_limit_calculator_node->add_input_stream(finished_output_stream);
return graph_config;
}
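For a task subgraph with a single input and output, the wiring above yields a graph shaped roughly like the text-format config below (held in a Python string for illustration; the "throttled_" stream prefix is an assumption standing in for whatever prefix addStreamNamePrefix applies):
```python
# Illustrative CalculatorGraphConfig produced for input "IMAGE:image" and
# first output "RESULT:result"; only the FlowLimiter wiring is shown.
FLOW_LIMITED_GRAPH = """
input_stream: "IMAGE:image"
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "image"
  input_stream: "FINISHED:result"
  input_stream_info: { tag_index: "FINISHED" back_edge: true }
  output_stream: "throttled_image"
  options {
    [mediapipe.FlowLimiterCalculatorOptions.ext] {
      max_in_flight: 1
      max_in_queue: 1
    }
  }
}
node {
  calculator: "SomeTaskSubgraph"
  input_stream: "IMAGE:throttled_image"
  output_stream: "RESULT:result"
}
"""
```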

View File

@ -1,14 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
@ -19,27 +22,13 @@ NS_ASSUME_NONNULL_BEGIN
* this class.
*/
NS_SWIFT_NAME(TaskOptions)
@interface MPPTaskOptions : NSObject <NSCopying>
/**
* Base options for configuring the Mediapipe task.
*/
@property(nonatomic, copy) MPPBaseOptions *baseOptions;
/**
* Initializes a new `MPPTaskOptions` with the absolute path to the model file
* stored locally on the device, set to the given the model path.
*
* @discussion The external model file must be a single standalone TFLite file. It could be packed
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
* necessary metadata and associated files might result in errors. Check the [documentation]
* (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
*
* @return An instance of `MPPTaskOptions` initialized to the given model path.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath;
@end
NS_ASSUME_NONNULL_END

View File

@ -1,17 +1,17 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
@ -25,12 +25,12 @@
return self;
}
- (instancetype)initWithModelPath:(NSString *)modelPath {
self = [self init];
if (self) {
_baseOptions.modelAssetPath = modelPath;
}
return self;
- (id)copyWithZone:(NSZone *)zone {
MPPTaskOptions *taskOptions = [[MPPTaskOptions alloc] init];
taskOptions.baseOptions = self.baseOptions;
return taskOptions;
}
@end

View File

@ -1,26 +1,29 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator_options.pb.h"
NS_ASSUME_NONNULL_BEGIN
/**
* Any mediapipe task options should confirm to this protocol.
* Any MediaPipe task options should conform to this protocol.
*/
@protocol MPPTaskOptionsProtocol
/**
* Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto.
* Copies the iOS MediaPipe task options into a mediapipe::CalculatorOptions proto.
*/
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;

View File

@ -1,30 +1,36 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/**
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
* MediaPipe Tasks result base class. Any MediaPipe task result class should extend
* this class.
*/
NS_SWIFT_NAME(TaskResult)
@interface MPPTaskResult : NSObject <NSCopying>
/**
* Base options for configuring the Mediapipe task.
* The timestamp associated with the task result object.
*/
@property(nonatomic, assign, readonly) long timeStamp;
@property(nonatomic, assign, readonly) long timestamp;
- (instancetype)initWithTimeStamp:(long)timeStamp;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER;
@end

View File

@ -1,27 +1,31 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
@implementation MPPTaskResult
- (instancetype)initWithTimeStamp:(long)timeStamp {
self = [self init];
- (instancetype)initWithTimestamp:(long)timestamp {
self = [super init];
if (self) {
_timeStamp = timeStamp;
_timestamp = timestamp;
}
return self;
}
- (id)copyWithZone:(NSZone *)zone {
return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp];
}
@end

View File

@ -0,0 +1,50 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create the C++ task runner and to invoke the appropriate methods on it.
*/
@interface MPPTaskRunner : NSObject
/**
* Initializes a new `MPPTaskRunner` with the MediaPipe task graph config proto.
*
* @param graphConfig A MediaPipe task graph config proto.
*
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
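/**
 * A synchronous method for processing a packet map through the task graph. Returns the output
 * packet map on success, or an error status if processing fails.
 */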
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)
process:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
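/** Shuts down the underlying C++ task runner. */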
- (void)close;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,56 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@interface MPPTaskRunner () {
// The underlying C++ task runner.
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
}
@end
@implementation MPPTaskRunner
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
error:(NSError **)error {
self = [super init];
if (self) {
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
return nil;
}
_cppTaskRunner = std::move(taskRunnerResult.value());
}
return self;
}
- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap
                               error:(NSError **)error {
return _cppTaskRunner->Process(packetMap);
}
- (void)close {
_cppTaskRunner->Close();
}
@end

View File

@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi {
TaskRunner.create(
context,
TaskInfo.<AudioClassifierOptions>builder()
.setTaskName(AudioClassifier.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
TaskRunner.create(
context,
TaskInfo.<AudioEmbedderOptions>builder()
.setTaskName(AudioEmbedder.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -22,6 +22,7 @@ android_library(
],
manifest = "AndroidManifest.xml",
deps = [
":logging",
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
@ -37,11 +38,22 @@ android_library(
],
)
android_library(
name = "logging",
srcs = glob(
["logging/*.java"],
),
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar")
mediapipe_tasks_core_aar(
name = "tasks_core",
srcs = glob(["*.java"]) + [
srcs = glob(["**/*.java"]) + [
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src",

View File

@ -32,6 +32,12 @@ public abstract class TaskInfo<T extends TaskOptions> {
/** Builder for {@link TaskInfo}. */
@AutoValue.Builder
public abstract static class Builder<T extends TaskOptions> {
/** Sets the MediaPipe task name. */
public abstract Builder<T> setTaskName(String value);
/** Sets the MediaPipe task running mode name. */
public abstract Builder<T> setTaskRunningModeName(String value);
/** Sets the MediaPipe task graph name. */
public abstract Builder<T> setTaskGraphName(String value);
@ -71,6 +77,10 @@ public abstract class TaskInfo<T extends TaskOptions> {
}
}
abstract String taskName();
abstract String taskRunningModeName();
abstract String taskGraphName();
abstract T taskOptions();
@ -82,7 +92,7 @@ public abstract class TaskInfo<T extends TaskOptions> {
abstract Boolean enableFlowLimiting();
public static <T extends TaskOptions> Builder<T> builder() {
return new AutoValue_TaskInfo.Builder<T>();
return new AutoValue_TaskInfo.Builder<T>().setTaskName("").setTaskRunningModeName("");
}
/* Returns a list of the output stream names without the stream tags. */

View File

@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable {
private final Graph graph;
private final ModelResourcesCache modelResourcesCache;
private final AndroidPacketCreator packetCreator;
private final TasksStatsLogger statsLogger;
private long lastSeenTimestamp = Long.MIN_VALUE;
private ErrorListener errorListener;
@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable {
Context context,
TaskInfo<? extends TaskOptions> taskInfo,
OutputHandler<? extends TaskResult, ?> outputHandler) {
TasksStatsLogger statsLogger =
TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName());
AndroidAssetUtil.initializeNativeAssetManager(context);
Graph mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable {
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
mediapipeGraph.addMultiStreamCallback(
taskInfo.outputStreamNames(),
outputHandler::run,
/*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
packets -> {
outputHandler.run(packets);
statsLogger.recordInvocationEnd(packets.get(0).getTimestamp());
},
/* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges());
mediapipeGraph.startRunningGraph();
// Waits until all calculators are opened and the graph is fully started.
mediapipeGraph.waitUntilGraphIdle();
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger);
}
/**
@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable {
* @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs.
*/
public synchronized TaskResult process(Map<String, Packet> inputs) {
addPackets(inputs, generateSyntheticTimestamp());
long syntheticInputTimestamp = generateSyntheticTimestamp();
// TODO: Support recording GPU input arrival.
statsLogger.recordCpuInputArrival(syntheticInputTimestamp);
addPackets(inputs, syntheticInputTimestamp);
graph.waitUntilGraphIdle();
lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
return outputHandler.retrieveCachedTaskResult();
@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable {
*/
public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimstamp(inputTimestamp);
statsLogger.recordCpuInputArrival(inputTimestamp);
addPackets(inputs, inputTimestamp);
graph.waitUntilGraphIdle();
return outputHandler.retrieveCachedTaskResult();
@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable {
*/
public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimstamp(inputTimestamp);
statsLogger.recordCpuInputArrival(inputTimestamp);
addPackets(inputs, inputTimestamp);
}
@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
statsLogger.logSessionEnd();
} catch (MediaPipeException e) {
reportError(e);
}
@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable {
// Waits until all calculators are opened and the graph is fully restarted.
graph.waitUntilGraphIdle();
graphStarted.set(true);
statsLogger.logSessionStart();
} catch (MediaPipeException e) {
reportError(e);
}
@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
statsLogger.logSessionEnd();
if (modelResourcesCache != null) {
modelResourcesCache.release();
}
@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable {
private TaskRunner(
Graph graph,
ModelResourcesCache modelResourcesCache,
OutputHandler<? extends TaskResult, ?> outputHandler) {
OutputHandler<? extends TaskResult, ?> outputHandler,
TasksStatsLogger statsLogger) {
this.outputHandler = outputHandler;
this.graph = graph;
this.modelResourcesCache = modelResourcesCache;
this.packetCreator = new AndroidPacketCreator(graph);
this.statsLogger = statsLogger;
graphStarted.set(true);
this.statsLogger.logSessionStart();
}
/** Reports error. */

View File

@ -0,0 +1,78 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core.logging;
import android.content.Context;
/** A dummy MediaPipe Tasks stats logger whose methods are all no-ops. */
public class TasksStatsDummyLogger implements TasksStatsLogger {
/**
* Creates the MediaPipe Tasks stats dummy logger.
*
* @param context a {@link Context}.
* @param taskNameStr the task API name.
* @param taskRunningModeStr the task running mode string representation.
*/
public static TasksStatsDummyLogger create(
Context context, String taskNameStr, String taskRunningModeStr) {
return new TasksStatsDummyLogger();
}
private TasksStatsDummyLogger() {}
/** Logs the start of a MediaPipe Tasks API session. */
@Override
public void logSessionStart() {}
/**
* Records MediaPipe Tasks API receiving CPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the API
* invocation.
*/
@Override
public void recordCpuInputArrival(long packetTimestamp) {}
/**
* Records MediaPipe Tasks API receiving GPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the API
* invocation.
*/
@Override
public void recordGpuInputArrival(long packetTimestamp) {}
/**
* Records the end of a MediaPipe Tasks API invocation.
*
* @param packetTimestamp the output packet timestamp that acts as the identifier of the API
* invocation.
*/
@Override
public void recordInvocationEnd(long packetTimestamp) {}
/** Logs the MediaPipe Tasks API periodic invocation report. */
@Override
public void logInvocationReport(StatsSnapshot stats) {}
/** Logs the MediaPipe Tasks API session end event. */
@Override
public void logSessionEnd() {}
/** Logs the MediaPipe Tasks API initialization error. */
@Override
public void logInitError() {}
}
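For reference, a hedged usage sketch; the task name and running-mode strings are illustrative placeholders matching how TaskRunner passes taskInfo.taskName() and taskInfo.taskRunningModeName():

TasksStatsLogger statsLogger =
    TasksStatsDummyLogger.create(context, "ImageClassifier", "IMAGE");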

View File

@ -0,0 +1,98 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core.logging;
import com.google.auto.value.AutoValue;
/** The stats logger interface that defines what MediaPipe Tasks events to log. */
public interface TasksStatsLogger {
/** Task stats snapshot. */
@AutoValue
abstract static class StatsSnapshot {
static StatsSnapshot create(
int cpuInputCount,
int gpuInputCount,
int finishedCount,
int droppedCount,
long totalLatencyMs,
long peakLatencyMs,
long elapsedTimeMs) {
return new AutoValue_TasksStatsLogger_StatsSnapshot(
cpuInputCount,
gpuInputCount,
finishedCount,
droppedCount,
totalLatencyMs,
peakLatencyMs,
elapsedTimeMs);
}
static StatsSnapshot createDefault() {
return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0);
}
abstract int cpuInputCount();
abstract int gpuInputCount();
abstract int finishedCount();
abstract int droppedCount();
abstract long totalLatencyMs();
abstract long peakLatencyMs();
abstract long elapsedTimeMs();
}
/** Logs the start of a MediaPipe Tasks API session. */
public void logSessionStart();
/**
* Records MediaPipe Tasks API receiving CPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the API
* invocation.
*/
public void recordCpuInputArrival(long packetTimestamp);
/**
* Records MediaPipe Tasks API receiving GPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the API
* invocation.
*/
public void recordGpuInputArrival(long packetTimestamp);
/**
* Records the end of a MediaPipe Tasks API invocation.
*
* @param packetTimestamp the output packet timestamp that acts as the identifier of the API
* invocation.
*/
public void recordInvocationEnd(long packetTimestamp);
/** Logs the MediaPipe Tasks API periodic invocation report. */
public void logInvocationReport(StatsSnapshot stats);
/** Logs the MediaPipe Tasks API session end event. */
public void logSessionEnd();
/** Logs the MediaPipe Tasks API initialization error. */
public void logInitError();
// TODO: Log more error types.
}
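To make the contract concrete, here is a minimal sketch of a concrete logger that derives the StatsSnapshot fields from the callbacks above. The interface and StatsSnapshot come from this file; the class itself, its latency bookkeeping, and the use of SystemClock are illustrative assumptions, and a production logger would also need thread safety, since inputs and outputs are recorded from different threads:

package com.google.mediapipe.tasks.core.logging;

import android.os.SystemClock;
import java.util.HashMap;
import java.util.Map;

/** An illustrative, non-thread-safe stats logger; not part of this change. */
final class SimpleTasksStatsLogger implements TasksStatsLogger {
  // Maps input packet timestamps to their arrival wall times.
  private final Map<Long, Long> arrivalTimes = new HashMap<>();
  private int cpuInputCount;
  private int gpuInputCount;
  private int finishedCount;
  private long totalLatencyMs;
  private long peakLatencyMs;
  private long sessionStartMs;

  @Override
  public void logSessionStart() {
    sessionStartMs = SystemClock.elapsedRealtime();
  }

  @Override
  public void recordCpuInputArrival(long packetTimestamp) {
    cpuInputCount++;
    arrivalTimes.put(packetTimestamp, SystemClock.elapsedRealtime());
  }

  @Override
  public void recordGpuInputArrival(long packetTimestamp) {
    gpuInputCount++;
    arrivalTimes.put(packetTimestamp, SystemClock.elapsedRealtime());
  }

  @Override
  public void recordInvocationEnd(long packetTimestamp) {
    Long arrival = arrivalTimes.remove(packetTimestamp);
    if (arrival == null) {
      return; // No matching input; counted as dropped at session end.
    }
    long latencyMs = SystemClock.elapsedRealtime() - arrival;
    finishedCount++;
    totalLatencyMs += latencyMs;
    peakLatencyMs = Math.max(peakLatencyMs, latencyMs);
  }

  @Override
  public void logInvocationReport(StatsSnapshot stats) {
    // A real logger would flush the snapshot to its logging backend here.
  }

  @Override
  public void logSessionEnd() {
    int droppedCount = arrivalTimes.size();
    logInvocationReport(
        StatsSnapshot.create(
            cpuInputCount,
            gpuInputCount,
            finishedCount,
            droppedCount,
            totalLatencyMs,
            peakLatencyMs,
            SystemClock.elapsedRealtime() - sessionStartMs));
  }

  @Override
  public void logInitError() {}
}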

View File

@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable {
TaskRunner.create(
context,
TaskInfo.<TextClassifierOptions>builder()
.setTaskName(TextClassifier.class.getSimpleName())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable {
TaskRunner.create(
context,
TaskInfo.<TextEmbedderOptions>builder()
.setTaskName(TextEmbedder.class.getSimpleName())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<GestureRecognizerOptions>builder()
.setTaskName(GestureRecognizer.class.getSimpleName())
.setTaskRunningModeName(recognizerOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<HandLandmarkerOptions>builder()
.setTaskName(HandLandmarker.class.getSimpleName())
.setTaskRunningModeName(landmarkerOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<ImageClassifierOptions>builder()
.setTaskName(ImageClassifier.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<ImageEmbedderOptions>builder()
.setTaskName(ImageEmbedder.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<ObjectDetectorOptions>builder()
.setTaskName(ObjectDetector.class.getSimpleName())
.setTaskRunningModeName(detectorOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -86,7 +86,7 @@ describe('convertBaseOptionsToProto()', () => {
it('can enable CPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'cpu',
delegate: 'CPU',
});
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
@ -94,7 +94,7 @@ describe('convertBaseOptionsToProto()', () => {
it('can enable GPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'gpu',
delegate: 'GPU',
});
expect(baseOptionsProto.toObject()).toEqual({
...mockBytesResult,
@ -117,7 +117,7 @@ describe('convertBaseOptionsToProto()', () => {
it('can reset delegate', async () => {
let baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'gpu',
delegate: 'GPU',
});
// Clear backend
baseOptionsProto =

View File

@ -71,7 +71,7 @@ async function configureExternalFile(
/** Configures the `acceleration` option. */
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
const acceleration = proto.getAcceleration() ?? new Acceleration();
if (options.delegate === 'gpu') {
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());

View File

@ -44,22 +44,14 @@ async function isSimdSupported(): Promise<boolean> {
}
async function createFileset(
taskName: string, basePath: string = '.'): Promise<WasmFileset> {
if (await isSimdSupported()) {
return {
wasmLoaderPath:
`${basePath}/${taskName}_wasm_internal.js`,
wasmBinaryPath:
`${basePath}/${taskName}_wasm_internal.wasm`,
};
} else {
return {
wasmLoaderPath:
`${basePath}/${taskName}_wasm_nosimd_internal.js`,
wasmBinaryPath:
`${basePath}/${taskName}_wasm_nosimd_internal.wasm`,
};
}
taskName: string, basePath: string = ''): Promise<WasmFileset> {
const suffix =
await isSimdSupported() ? 'wasm_internal' : 'wasm_nosimd_internal';
return {
wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`,
wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`,
};
}
// tslint:disable:class-as-namespace

View File

@ -31,7 +31,7 @@ export declare interface BaseOptions {
modelAssetBuffer?: Uint8Array|undefined;
/** Overrides the default backend to use for the provided model. */
delegate?: 'cpu'|'gpu'|undefined;
delegate?: 'CPU'|'GPU'|undefined;
}
/** Options to configure MediaPipe Tasks in general. */

View File

@ -1028,7 +1028,9 @@ export class GraphRunner {
// Set up our TS listener to receive any packets for this stream, and
// additionally reformat our Uint8Array into a Float32Array for the user.
this.setListener(outputStreamName, (data: Uint8Array) => {
const floatArray = new Float32Array(data.buffer); // Should be very fast
// Creating a typed-array view is fast (no copy). Pass byteOffset and the
// element count so the view covers only this packet's bytes in the buffer.
const floatArray =
new Float32Array(data.buffer, data.byteOffset, data.length / 4);
callbackFcn(floatArray);
});

View File

@ -490,10 +490,10 @@ setuptools.setup(
'Operating System :: MacOS :: MacOS X',
'Operating System :: Microsoft :: Windows',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3 :: Only',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',