Merge branch 'ios-task-files' into ios-task

Prianka Liz Kariat 2022-12-23 16:10:57 +05:30
commit 7f7776ef80
89 changed files with 1189 additions and 891 deletions

View File

@@ -1,6 +1,6 @@
 ---
 name: "Solution (legacy) Issue"
-about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
+about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc.
 labels: type:support
 ---

View File

@@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic
 # For static images:
 IMAGE_FILES = []
+BG_COLOR = (192, 192, 192) # gray
 with mp_holistic.Holistic(
     static_image_mode=True,
     model_complexity=2,

View File

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+
 licenses(["notice"])
 
 package(default_visibility = ["//visibility:private"])
 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
-
 proto_library(
     name = "mfcc_mel_calculators_proto",
     srcs = ["mfcc_mel_calculators.proto"],

View File

@@ -567,7 +567,7 @@ cc_library(
     name = "packet_thinner_calculator",
     srcs = ["packet_thinner_calculator.cc"],
     deps = [
-        "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
+        ":packet_thinner_calculator_cc_proto",
         "//mediapipe/framework:calculator_context",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:video_stream_header",
@@ -584,7 +584,7 @@ cc_test(
     srcs = ["packet_thinner_calculator_test.cc"],
     deps = [
         ":packet_thinner_calculator",
-        "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
+        ":packet_thinner_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/formats:video_stream_header",
@@ -762,7 +762,7 @@ cc_library(
     srcs = ["packet_resampler_calculator.cc"],
     hdrs = ["packet_resampler_calculator.h"],
     deps = [
-        "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
+        ":packet_resampler_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:collection_item_id",
         "//mediapipe/framework/deps:mathutil",
@@ -786,7 +786,7 @@ cc_test(
     ],
     deps = [
         ":packet_resampler_calculator",
-        "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
+        ":packet_resampler_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/formats:video_stream_header",
@@ -852,10 +852,10 @@ cc_test(
     name = "flow_limiter_calculator_test",
     srcs = ["flow_limiter_calculator_test.cc"],
     deps = [
+        ":counting_source_calculator",
         ":flow_limiter_calculator",
         ":flow_limiter_calculator_cc_proto",
-        "//mediapipe/calculators/core:counting_source_calculator",
-        "//mediapipe/calculators/core:pass_through_calculator",
+        ":pass_through_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework:test_calculators",
@@ -1302,7 +1302,7 @@ cc_test(
     srcs = ["packet_sequencer_calculator_test.cc"],
     deps = [
         ":packet_sequencer_calculator",
-        "//mediapipe/calculators/core:pass_through_calculator",
+        ":pass_through_calculator",
         "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:subgraph",

View File

@@ -47,7 +47,7 @@ namespace api2 {
 //   calculator: "Get{SpecificType}VectorItemCalculator"
 //   input_stream: "VECTOR:vector"
 //   input_stream: "INDEX:index"
-//   input_stream: "ITEM:item"
+//   output_stream: "ITEM:item"
 //   options {
 //     [mediapipe.GetVectorItemCalculatorOptions.ext] {
 //       item_index: 5

View File

@@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
 //   }
 //   output_stream: "gated_frames"
 // }
-class RealTimeFlowLimiterCalculator : public CalculatorBase {
+//
+// Please use FlowLimiterCalculator, which replaces this calculator and
+// defines a few additional configuration options.
+class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.")
+    RealTimeFlowLimiterCalculator : public CalculatorBase {
  public:
   static absl::Status GetContract(CalculatorContract* cc) {
     int num_data_streams = cc->Inputs().NumEntries("");

View File

@@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node {
   // The number of packets or timestamps we need to store to output packet[i] at
   // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
   int cache_size_;
+  bool emit_empty_packets_before_first_packet_ = false;
 };
 MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator);
 
 absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
   packet_offset_ = kOffset(cc).GetOr(
       cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset());
+  emit_empty_packets_before_first_packet_ =
+      cc->Options<mediapipe::SequenceShiftCalculatorOptions>()
+          .emit_empty_packets_before_first_packet();
   cache_size_ = abs(packet_offset_);
   // An offset of zero is a no-op, but someone might still request it.
   if (packet_offset_ == 0) {
@@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
     // Ready to output oldest packet with current timestamp.
     kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
     packet_cache_.pop_front();
+  } else if (emit_empty_packets_before_first_packet_) {
+    LOG(FATAL) << "Not supported yet";
   }
   // Store current packet for later output.
   packet_cache_.push_back(kIn(cc).packet());

View File

@@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions {
     optional SequenceShiftCalculatorOptions ext = 107633927;
   }
   optional int32 packet_offset = 1 [default = -1];
+
+  // Emits empty packets before the first delayed packet is emitted. Takes
+  // effect only when packet offset is set to positive.
+  optional bool emit_empty_packets_before_first_packet = 2 [default = false];
 }
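
For reference, a hypothetical node config showing where the new field sits (stream names are invented for illustration). Note that, per the sequence_shift_calculator.cc hunk above, enabling it on the positive-offset path currently aborts with LOG(FATAL) "Not supported yet":

node {
  calculator: "SequenceShiftCalculator"
  input_stream: "input_packets"
  output_stream: "shifted_packets"
  options {
    [mediapipe.SequenceShiftCalculatorOptions.ext] {
      packet_offset: 2
      emit_empty_packets_before_first_packet: true
    }
  }
}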

View File

@@ -378,8 +378,8 @@ cc_library(
     name = "scale_image_calculator",
     srcs = ["scale_image_calculator.cc"],
     deps = [
+        ":scale_image_calculator_cc_proto",
         ":scale_image_utils",
-        "//mediapipe/calculators/image:scale_image_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/formats:image_frame",
@@ -747,8 +747,8 @@ cc_test(
     tags = ["desktop_only_test"],
     deps = [
         ":affine_transformation",
+        ":image_transformation_calculator",
         ":warp_affine_calculator",
-        "//mediapipe/calculators/image:image_transformation_calculator",
         "//mediapipe/calculators/tensor:image_to_tensor_converter",
         "//mediapipe/calculators/tensor:image_to_tensor_utils",
         "//mediapipe/calculators/util:from_image_calculator",

View File

@@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner
   constexpr GLchar kVertShader[] = R"(
     in vec4 position;
-    in mediump vec4 texture_coordinate;
-    out mediump vec2 sample_coordinate;
+    in highp vec4 texture_coordinate;
+    out highp vec2 sample_coordinate;
     uniform mat4 transform_matrix;
 
     void main() {
@@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner
   )";
 
   constexpr GLchar kFragShader[] = R"(
-    DEFAULT_PRECISION(mediump, float)
+    DEFAULT_PRECISION(highp, float)
     in vec2 sample_coordinate;
     uniform sampler2D input_texture;

View File

@@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) {
 constexpr char kRgbaInTag[] = "RGBA_IN";
 constexpr char kRgbInTag[] = "RGB_IN";
+constexpr char kBgrInTag[] = "BGR_IN";
 constexpr char kBgraInTag[] = "BGRA_IN";
 constexpr char kGrayInTag[] = "GRAY_IN";
 constexpr char kRgbaOutTag[] = "RGBA_OUT";
@@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT";
 //   RGB -> RGBA
 //   RGBA -> BGRA
 //   BGRA -> RGBA
+//   BGR -> RGB
 //
 // This calculator only supports a single input stream and output stream at a
 // time. If more than one input stream or output stream is present, the
@@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT";
 //   RGB_IN: The input video stream (ImageFrame, SRGB).
 //   BGRA_IN: The input video stream (ImageFrame, SBGRA).
 //   GRAY_IN: The input video stream (ImageFrame, GRAY8).
+//   BGR_IN: The input video stream (ImageFrame, SBGR).
 //
 // Output streams:
 //   RGBA_OUT: The output video stream (ImageFrame, SRGBA).
@@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) {
     cc->Inputs().Tag(kBgraInTag).Set<ImageFrame>();
   }
 
+  if (cc->Inputs().HasTag(kBgrInTag)) {
+    cc->Inputs().Tag(kBgrInTag).Set<ImageFrame>();
+  }
+
   if (cc->Outputs().HasTag(kRgbOutTag)) {
     cc->Outputs().Tag(kRgbOutTag).Set<ImageFrame>();
   }
@@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) {
     return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA,
                             cv::COLOR_RGBA2BGRA, cc);
   }
+  // BGR -> RGB
+  if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) {
+    return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB,
+                            cv::COLOR_BGR2RGB, cc);
+  }
 
   return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
          << "Unsupported image format conversion.";

View File

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-licenses(["notice"])
-
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
 
+licenses(["notice"])
+
 package(default_visibility = ["//visibility:private"])
 
 proto_library(

View File

@@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter {
 constexpr GLchar kExtractSubRectVertexShader[] = R"(
     in vec4 position;
-    in mediump vec4 texture_coordinate;
-    out mediump vec2 sample_coordinate;
+    in highp vec4 texture_coordinate;
+    out highp vec2 sample_coordinate;
     uniform mat4 transform_matrix;
 
     void main() {
@@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter {
 )";
 
 constexpr GLchar kExtractSubRectFragBody[] = R"(
-    DEFAULT_PRECISION(mediump, float)
+    DEFAULT_PRECISION(highp, float)
 
     // Provided by kExtractSubRectVertexShader.
     in vec2 sample_coordinate;

View File

@@ -22,8 +22,8 @@ cc_library(
     name = "alignment_points_to_rects_calculator",
     srcs = ["alignment_points_to_rects_calculator.cc"],
     deps = [
+        ":detections_to_rects_calculator",
         ":detections_to_rects_calculator_cc_proto",
-        "//mediapipe/calculators/util:detections_to_rects_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_options_cc_proto",
         "//mediapipe/framework/formats:detection_cc_proto",

View File

@@ -1,4 +1,3 @@
-#
 # Copyright 2019 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -227,13 +226,13 @@ cc_library(
         ":mediapipe_internal",
     ],
     deps = [
+        ":calculator_cc_proto",
         ":graph_service",
+        ":mediapipe_options_cc_proto",
+        ":packet_generator_cc_proto",
         ":packet_type",
         ":port",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
+        ":status_handler_cc_proto",
         "//mediapipe/framework/port:any_proto",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:options_map",
@@ -329,10 +328,10 @@ cc_library(
         ":thread_pool_executor",
         ":timestamp",
         ":validated_graph_config",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
+        ":calculator_cc_proto",
+        ":packet_generator_cc_proto",
+        ":status_handler_cc_proto",
+        ":thread_pool_executor_cc_proto",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:fixed_array",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -370,7 +369,7 @@ cc_library(
     visibility = [":mediapipe_internal"],
     deps = [
         ":graph_service",
-        "//mediapipe/framework:packet",
+        ":packet",
         "@com_google_absl//absl/status",
     ],
 )
@@ -380,7 +379,7 @@ cc_test(
     srcs = ["graph_service_manager_test.cc"],
     deps = [
         ":graph_service_manager",
-        "//mediapipe/framework:packet",
+        ":packet",
         "//mediapipe/framework/port:gtest_main",
     ],
 )
@@ -392,6 +391,7 @@ cc_library(
     visibility = [":mediapipe_internal"],
     deps = [
         ":calculator_base",
+        ":calculator_cc_proto",
         ":calculator_context",
         ":calculator_context_manager",
         ":calculator_state",
@@ -408,10 +408,9 @@ cc_library(
         ":packet_set",
         ":packet_type",
         ":port",
+        ":stream_handler_cc_proto",
         ":timestamp",
         ":validated_graph_config",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:stream_handler_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
@@ -467,6 +466,7 @@ cc_library(
     hdrs = ["calculator_state.h"],
     visibility = [":mediapipe_internal"],
    deps = [
+        ":calculator_cc_proto",
         ":counter",
         ":counter_factory",
         ":graph_service",
@@ -476,7 +476,6 @@ cc_library(
         ":packet",
         ":packet_set",
         ":port",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:any_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/tool:options_map",
@@ -584,7 +583,7 @@ cc_library(
     hdrs = ["executor.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "//mediapipe/framework:mediapipe_options_cc_proto",
+        ":mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:statusor",
@@ -671,11 +670,11 @@ cc_library(
         ":collection_item_id",
         ":input_stream_manager",
         ":input_stream_shard",
+        ":mediapipe_options_cc_proto",
         ":mediapipe_profiling",
         ":packet",
         ":packet_set",
         ":packet_type",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
@@ -785,12 +784,12 @@ cc_library(
         ":calculator_context_manager",
         ":collection",
         ":collection_item_id",
+        ":mediapipe_options_cc_proto",
         ":output_stream_manager",
         ":output_stream_shard",
         ":packet_set",
         ":packet_type",
         ":timestamp",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:status",
@@ -876,10 +875,10 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":packet",
+        ":packet_generator_cc_proto",
         ":packet_set",
         ":packet_type",
         ":port",
-        "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:status",
@@ -897,13 +896,13 @@ cc_library(
         ":delegating_executor",
         ":executor",
         ":packet",
+        ":packet_factory_cc_proto",
         ":packet_generator",
+        ":packet_generator_cc_proto",
         ":packet_type",
         ":port",
         ":thread_pool_executor",
         ":validated_graph_config",
-        "//mediapipe/framework:packet_factory_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
@@ -1020,10 +1019,10 @@ cc_library(
     hdrs = ["status_handler.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":mediapipe_options_cc_proto",
         ":packet_set",
         ":packet_type",
         ":port",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:status",
         "@com_google_absl//absl/memory",
@@ -1036,10 +1035,10 @@ cc_library(
     hdrs = ["subgraph.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":calculator_cc_proto",
         ":graph_service",
         ":graph_service_manager",
         ":port",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
@@ -1061,7 +1060,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":calculator_framework",
-        "//mediapipe/framework:test_calculators_cc_proto",
+        ":test_calculators_cc_proto",
         "//mediapipe/framework/deps:mathutil",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/framework/port:integral_types",
@@ -1098,7 +1097,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":executor",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
+        ":thread_pool_executor_cc_proto",
         "//mediapipe/framework/deps:thread_options",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:status",
@@ -1163,22 +1162,22 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":calculator_base",
+        ":calculator_cc_proto",
         ":calculator_contract",
         ":graph_service_manager",
         ":legacy_calculator_support",
         ":packet",
         ":packet_generator",
+        ":packet_generator_cc_proto",
         ":packet_set",
         ":packet_type",
         ":port",
         ":status_handler",
+        ":status_handler_cc_proto",
+        ":stream_handler_cc_proto",
         ":subgraph",
+        ":thread_pool_executor_cc_proto",
         ":timestamp",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
-        "//mediapipe/framework:stream_handler_cc_proto",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
@@ -1203,11 +1202,11 @@ cc_test(
     name = "validated_graph_config_test",
     srcs = ["validated_graph_config_test.cc"],
     deps = [
+        ":calculator_cc_proto",
         ":calculator_framework",
         ":graph_service",
         ":graph_service_manager",
         ":validated_graph_config",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/api2:node",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/port:gtest_main",
@@ -1234,6 +1233,7 @@ cc_test(
     linkstatic = 1,
     deps = [
         ":calculator_base",
+        ":calculator_cc_proto",
         ":calculator_context",
         ":calculator_context_manager",
         ":calculator_registry",
@@ -1243,7 +1243,6 @@ cc_test(
         ":output_stream_shard",
         ":packet_set",
         ":packet_type",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:status_util",
@@ -1257,11 +1256,11 @@ cc_test(
     srcs = ["calculator_contract_test.cc"],
     linkstatic = 1,
     deps = [
+        ":calculator_cc_proto",
         ":calculator_contract",
         ":calculator_contract_test_cc_proto",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
+        ":packet_generator_cc_proto",
+        ":status_handler_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
     ],
@@ -1369,6 +1368,7 @@ cc_test(
     srcs = ["calculator_context_test.cc"],
     linkstatic = 1,
     deps = [
+        ":calculator_cc_proto",
         ":calculator_context",
         ":calculator_context_manager",
         ":calculator_state",
@@ -1377,7 +1377,6 @@ cc_test(
         ":output_stream_shard",
         ":packet_set",
         ":packet_type",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
         "//mediapipe/framework/port:status",
@@ -1404,6 +1403,7 @@ cc_test(
         ":executor",
         ":input_stream_handler",
         ":lifetime_tracker",
+        ":mediapipe_options_cc_proto",
         ":output_stream_poller",
         ":packet_set",
         ":packet_type",
@@ -1411,13 +1411,12 @@ cc_test(
         ":subgraph",
         ":test_calculators",
         ":thread_pool_executor",
+        ":thread_pool_executor_cc_proto",
         ":timestamp",
         ":type_map",
         "//mediapipe/calculators/core:counting_source_calculator",
         "//mediapipe/calculators/core:mux_calculator",
         "//mediapipe/calculators/core:pass_through_calculator",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:parse_text_proto",
@@ -1482,12 +1481,12 @@ cc_test(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":calculator_cc_proto",
         ":calculator_framework",
         ":test_calculators",
         "//mediapipe/calculators/core:counting_source_calculator",
         "//mediapipe/calculators/core:mux_calculator",
         "//mediapipe/calculators/core:pass_through_calculator",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:parse_text_proto",
@@ -1631,8 +1630,8 @@ cc_test(
     srcs = ["packet_generator_test.cc"],
     deps = [
         ":packet_generator",
+        ":packet_generator_cc_proto",
         ":packet_type",
-        "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/tool:validate_type",
         "@com_google_absl//absl/strings",

View File

@@ -15,12 +15,17 @@
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
 
-namespace mediapipe {
-namespace api2 {
-namespace test {
+namespace mediapipe::api2::builder {
+namespace {
+
+using ::mediapipe::api2::test::Bar;
+using ::mediapipe::api2::test::FloatAdder;
+using ::mediapipe::api2::test::Foo;
+using ::mediapipe::api2::test::Foo2;
+using ::mediapipe::api2::test::FooBar1;
 
 TEST(BuilderTest, BuildGraph) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   auto& bar = graph.AddNode("Bar");
   graph.In("IN").SetName("base") >> foo.In("BASE");
@@ -49,22 +54,19 @@ TEST(BuilderTest, BuildGraph) {
 }
 
 TEST(BuilderTest, CopyableSource) {
-  builder::Graph graph;
-  builder::Source<int> a = graph[Input<int>("A")];
-  a.SetName("a");
-  builder::Source<int> b = graph[Input<int>("B")];
-  b.SetName("b");
-  builder::SideSource<float> side_a = graph[SideInput<float>("SIDE_A")];
-  side_a.SetName("side_a");
-  builder::SideSource<float> side_b = graph[SideInput<float>("SIDE_B")];
-  side_b.SetName("side_b");
-  builder::Destination<int> out = graph[Output<int>("OUT")];
-  builder::SideDestination<float> side_out =
-      graph[SideOutput<float>("SIDE_OUT")];
+  Graph graph;
+  Source<int> a = graph.In("A").SetName("a").Cast<int>();
+  Source<int> b = graph.In("B").SetName("b").Cast<int>();
+  SideSource<float> side_a =
+      graph.SideIn("SIDE_A").SetName("side_a").Cast<float>();
+  SideSource<float> side_b =
+      graph.SideIn("SIDE_B").SetName("side_b").Cast<float>();
+  Destination<int> out = graph.Out("OUT").Cast<int>();
+  SideDestination<float> side_out = graph.SideOut("SIDE_OUT").Cast<float>();
 
-  builder::Source<int> input = a;
+  Source<int> input = a;
   input = b;
-  builder::SideSource<float> side_input = side_b;
+  SideSource<float> side_input = side_b;
   side_input = side_a;
 
   input >> out;
@@ -83,31 +85,27 @@ TEST(BuilderTest, CopyableSource) {
 }
 
 TEST(BuilderTest, BuildGraphWithFunctions) {
-  builder::Graph graph;
-  builder::Source<int> base = graph[Input<int>("IN")];
-  base.SetName("base");
-  builder::SideSource<float> side = graph[SideInput<float>("SIDE")];
-  side.SetName("side");
+  Graph graph;
+  Source<int> base = graph.In("IN").SetName("base").Cast<int>();
+  SideSource<float> side = graph.SideIn("SIDE").SetName("side").Cast<float>();
 
-  auto foo_fn = [](builder::Source<int> base, builder::SideSource<float> side,
-                   builder::Graph& graph) {
+  auto foo_fn = [](Source<int> base, SideSource<float> side, Graph& graph) {
     auto& foo = graph.AddNode("Foo");
-    base >> foo[Input<int>("BASE")];
-    side >> foo[SideInput<float>("SIDE")];
-    return foo[Output<double>("OUT")];
+    base >> foo.In("BASE");
+    side >> foo.SideIn("SIDE");
+    return foo.Out("OUT")[0].Cast<double>();
   };
-  builder::Source<double> foo_out = foo_fn(base, side, graph);
+  Source<double> foo_out = foo_fn(base, side, graph);
 
-  auto bar_fn = [](builder::Source<double> in, builder::Graph& graph) {
+  auto bar_fn = [](Source<double> in, Graph& graph) {
     auto& bar = graph.AddNode("Bar");
-    in >> bar[Input<double>("IN")];
-    return bar[Output<double>("OUT")];
+    in >> bar.In("IN");
+    return bar.Out("OUT")[0].Cast<double>();
   };
-  builder::Source<double> bar_out = bar_fn(foo_out, graph);
-  bar_out.SetName("out");
-  bar_out >> graph[Output<double>("OUT")];
+  Source<double> bar_out = bar_fn(foo_out, graph);
+  bar_out.SetName("out") >> graph.Out("OUT");
 
   CalculatorGraphConfig expected =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@@ -131,7 +129,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
 
 template <class FooT>
 void BuildGraphTypedTest() {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode<FooT>();
   auto& bar = graph.AddNode<Bar>();
   graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
@@ -161,12 +159,12 @@ void BuildGraphTypedTest() {
   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
 }
 
-TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<Foo>(); }
+TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<test::Foo>(); }
 
-TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<Foo2>(); }
+TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<test::Foo2>(); }
 
 TEST(BuilderTest, FanOut) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   auto& adder = graph.AddNode("FloatAdder");
   graph.In("IN").SetName("base") >> foo.In("BASE");
@@ -194,9 +192,9 @@ TEST(BuilderTest, FanOut) {
 }
 
 TEST(BuilderTest, TypedMultiple) {
-  builder::Graph graph;
-  auto& foo = graph.AddNode<Foo>();
-  auto& adder = graph.AddNode<FloatAdder>();
+  Graph graph;
+  auto& foo = graph.AddNode<test::Foo>();
+  auto& adder = graph.AddNode<test::FloatAdder>();
   graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
   foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
   foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
@@ -222,14 +220,14 @@ TEST(BuilderTest, TypedMultiple) {
 }
 
 TEST(BuilderTest, TypedByPorts) {
-  builder::Graph graph;
-  auto& foo = graph.AddNode<Foo>();
+  Graph graph;
+  auto& foo = graph.AddNode<test::Foo>();
   auto& adder = graph.AddNode<FloatAdder>();
 
-  graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase];
+  graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase];
   foo[Foo::kOut] >> adder[FloatAdder::kIn][0];
   foo[Foo::kOut] >> adder[FloatAdder::kIn][1];
-  adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut];
+  adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut);
 
   CalculatorGraphConfig expected =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@@ -251,7 +249,7 @@ TEST(BuilderTest, TypedByPorts) {
 }
 
 TEST(BuilderTest, PacketGenerator) {
-  builder::Graph graph;
+  Graph graph;
   auto& generator = graph.AddPacketGenerator("FloatGenerator");
   graph.SideIn("IN") >> generator.SideIn("IN");
   generator.SideOut("OUT") >> graph.SideOut("OUT");
@@ -270,7 +268,7 @@ TEST(BuilderTest, PacketGenerator) {
 }
 
 TEST(BuilderTest, EmptyTag) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   graph.In("A").SetName("a") >> foo.In("")[0];
   graph.In("C").SetName("c") >> foo.In("")[2];
@@ -302,7 +300,7 @@ TEST(BuilderTest, StringLikeTags) {
   const std::string kB = "B";
   constexpr absl::string_view kC = "C";
 
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   graph.In(kA).SetName("a") >> foo.In(kA);
   graph.In(kB).SetName("b") >> foo.In(kB);
@@ -324,7 +322,7 @@ TEST(BuilderTest, StringLikeTags) {
 }
 
 TEST(BuilderTest, GraphIndexes) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   graph.In(0).SetName("a") >> foo.In("")[0];
   graph.In(1).SetName("c") >> foo.In("")[2];
@@ -376,28 +374,27 @@ class AnyAndSameTypeCalculator : public NodeIntf {
 };
 
 TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
-  builder::Graph graph;
-  builder::Source<AnyType> any_input = graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
-  builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}];
+  Graph graph;
+  Source<AnyType> any_input = graph.In("GRAPH_ANY_INPUT");
+  Source<int> int_input = graph.In("GRAPH_INT_INPUT").Cast<int>();
 
   auto& node = graph.AddNode("AnyAndSameTypeCalculator");
   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
   int_input >> node[AnyAndSameTypeCalculator::kIntInput];
 
-  builder::Source<AnyType> any_type_output =
+  Source<AnyType> any_type_output =
       node[AnyAndSameTypeCalculator::kAnyTypeOutput];
   any_type_output.SetName("any_type_output");
 
-  builder::Source<AnyType> same_type_output =
+  Source<AnyType> same_type_output =
       node[AnyAndSameTypeCalculator::kSameTypeOutput];
   same_type_output.SetName("same_type_output");
 
-  builder::Source<AnyType> recursive_same_type_output =
+  Source<AnyType> recursive_same_type_output =
       node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
   recursive_same_type_output.SetName("recursive_same_type_output");
 
-  builder::Source<int> same_int_output =
-      node[AnyAndSameTypeCalculator::kSameIntOutput];
+  Source<int> same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput];
   same_int_output.SetName("same_int_output");
 
-  builder::Source<int> recursive_same_int_type_output =
+  Source<int> recursive_same_int_type_output =
       node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
   recursive_same_int_type_output.SetName("recursive_same_int_type_output");
@@ -420,15 +417,16 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
 }
 
 TEST(BuilderTest, AnyTypeCanBeCast) {
-  builder::Graph graph;
-  builder::Source<std::string> any_input =
+  Graph graph;
+  Source<std::string> any_input =
       graph.In("GRAPH_ANY_INPUT").Cast<std::string>();
 
   auto& node = graph.AddNode("AnyAndSameTypeCalculator");
   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
-  builder::Source<double> any_type_output =
-      node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
-  any_type_output.SetName("any_type_output");
+  Source<double> any_type_output =
+      node[AnyAndSameTypeCalculator::kAnyTypeOutput]
+          .SetName("any_type_output")
+          .Cast<double>();
 
   any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
@@ -446,11 +444,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
 }
 
 TEST(BuilderTest, MultiPortIsCastToMultiPort) {
-  builder::Graph graph;
-  builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
-  builder::MultiSource<int> int_input = any_input.Cast<int>();
-  builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
-  builder::MultiDestination<int> int_output = any_output.Cast<int>();
+  Graph graph;
+  MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
+  MultiSource<int> int_input = any_input.Cast<int>();
+  MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
+  MultiDestination<int> int_output = any_output.Cast<int>();
   int_input >> int_output;
 
   CalculatorGraphConfig expected =
@@ -462,11 +460,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) {
 }
 
 TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
-  builder::Graph graph;
-  builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
-  builder::Source<AnyType> any_input = any_multi_input;
-  builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
-  builder::Destination<AnyType> any_output = any_multi_output;
+  Graph graph;
+  MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
+  Source<AnyType> any_input = any_multi_input;
+  MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
+  Destination<AnyType> any_output = any_multi_output;
   any_input >> any_output;
 
   CalculatorGraphConfig expected =
@@ -478,11 +476,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
 }
 
 TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
-  builder::Graph graph;
-  builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
-  builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
-  builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
-  builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
+  Graph graph;
+  Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
+  Source<AnyType> any_input = graph.In("ANY_OUTPUT");
+  Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
+  Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
   int_input >> int_output;
   any_input >> any_output;
@@ -496,6 +494,5 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
 }
 
-}  // namespace test
-}  // namespace api2
-}  // namespace mediapipe
+}  // namespace
+}  // namespace mediapipe::api2::builder

View File

@@ -557,8 +557,8 @@ class OutputSidePacketAccess {
     if (output_) output_->Set(ToOldPacket(std::move(packet)));
   }
 
-  void Set(const T& payload) { Set(MakePacket<T>(payload)); }
-  void Set(T&& payload) { Set(MakePacket<T>(std::move(payload))); }
+  void Set(const T& payload) { Set(api2::MakePacket<T>(payload)); }
+  void Set(T&& payload) { Set(api2::MakePacket<T>(std::move(payload))); }
 
  private:
   OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {}

View File

@@ -20,9 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
 licenses(["notice"])
 
-package(default_visibility = [
-    "//mediapipe:__subpackages__",
-])
+package_group(
+    name = "mediapipe_internal",
+    packages = [
+        "//mediapipe/...",
+    ],
+)
+
+package(default_visibility = ["mediapipe_internal"])
 
 bzl_library(
     name = "expand_template_bzl",
@@ -214,6 +219,9 @@ cc_library(
     name = "registration",
     srcs = ["registration.cc"],
     hdrs = ["registration.h"],
+    visibility = [
+        "mediapipe_internal",
+    ],
     deps = [
         ":registration_token",
         "//mediapipe/framework/port:logging",

View File

@@ -26,7 +26,7 @@ licenses(["notice"])
 mediapipe_proto_library(
     name = "detection_proto",
     srcs = ["detection.proto"],
-    deps = ["//mediapipe/framework/formats:location_data_proto"],
+    deps = [":location_data_proto"],
 )
 
 mediapipe_register_type(
@@ -38,7 +38,7 @@ mediapipe_register_type(
         "::std::vector<::mediapipe::Detection>",
         "::std::vector<::mediapipe::DetectionList>",
     ],
-    deps = ["//mediapipe/framework/formats:detection_cc_proto"],
+    deps = [":detection_cc_proto"],
 )
 
 mediapipe_proto_library(
@@ -105,8 +105,8 @@ cc_library(
     srcs = ["matrix.cc"],
     hdrs = ["matrix.h"],
     deps = [
+        ":matrix_data_cc_proto",
         "//mediapipe/framework:port",
-        "//mediapipe/framework/formats:matrix_data_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
@@ -142,7 +142,7 @@ cc_library(
     srcs = ["image_frame.cc"],
     hdrs = ["image_frame.h"],
     deps = [
-        "//mediapipe/framework/formats:image_format_cc_proto",
+        ":image_format_cc_proto",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
@@ -166,8 +166,8 @@ cc_library(
     srcs = ["image_frame_opencv.cc"],
     hdrs = ["image_frame_opencv.h"],
     deps = [
+        ":image_format_cc_proto",
         ":image_frame",
-        "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/port:opencv_core",
     ],
 )
@@ -194,7 +194,7 @@ cc_library(
     deps = [
         "@com_google_protobuf//:protobuf",
         "//mediapipe/framework/formats/annotation:locus_cc_proto",
-        "//mediapipe/framework/formats:location_data_cc_proto",
+        ":location_data_cc_proto",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -245,7 +245,7 @@ cc_library(
     name = "video_stream_header",
     hdrs = ["video_stream_header.h"],
     deps = [
-        "//mediapipe/framework/formats:image_format_cc_proto",
+        ":image_format_cc_proto",
     ],
 )
 
@@ -263,9 +263,9 @@ cc_test(
     size = "small",
     srcs = ["image_frame_opencv_test.cc"],
     deps = [
+        ":image_format_cc_proto",
         ":image_frame",
         ":image_frame_opencv",
-        "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
@@ -324,8 +324,8 @@ cc_library(
         "//conditions:default": [],
     }),
     deps = [
-        "//mediapipe/framework/formats:image_format_cc_proto",
-        "//mediapipe/framework/formats:image_frame",
+        ":image_format_cc_proto",
+        ":image_frame",
         "@com_google_absl//absl/synchronization",
         "//mediapipe/framework:port",
         "//mediapipe/framework:type_map",
@@ -354,7 +354,7 @@ cc_library(
     hdrs = ["image_multi_pool.h"],
     deps = [
         ":image",
-        "//mediapipe/framework/formats:image_frame_pool",
+        ":image_frame_pool",
         "//mediapipe/framework:port",
         "//mediapipe/framework/port:logging",
         "@com_google_absl//absl/memory",
@@ -390,7 +390,7 @@ cc_library(
     ],
     deps = [
         ":image",
-        "//mediapipe/framework/formats:image_format_cc_proto",
+        ":image_format_cc_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:statusor",
@@ -428,7 +428,10 @@ cc_library(
         "tensor.cc",
         "tensor_ahwb.cc",
     ],
-    hdrs = ["tensor.h"],
+    hdrs = [
+        "tensor.h",
+        "tensor_internal.h",
+    ],
     copts = select({
         "//mediapipe:apple": [
             "-x objective-c++",
@@ -452,6 +455,7 @@ cc_library(
         ],
     }),
     deps = [
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/synchronization",
         "//mediapipe/framework:port",

View File

@@ -38,11 +38,11 @@ cc_library(
     srcs = ["optical_flow_field.cc"],
     hdrs = ["optical_flow_field.h"],
     deps = [
+        ":optical_flow_field_data_cc_proto",
         "//mediapipe/framework:type_map",
         "//mediapipe/framework/deps:mathutil",
         "//mediapipe/framework/formats:location",
         "//mediapipe/framework/formats:location_opencv",
-        "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto",
         "//mediapipe/framework/port:file_helpers",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",

View File

@@ -246,10 +246,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape,
       return Tensor::OpenGlTexture2dView::Layout::kAligned;
     }
   }
-  // The best performance of a compute shader can be achived with textures'
+  // The best performance of a compute shader can be achieved with textures'
   // width multiple of 256. Making minimum fixed width of 256 waste memory for
   // small tensors. The optimal balance memory-vs-performance is power of 2.
-  // The texture width and height are choosen to be closer to square.
+  // The texture width and height are chosen to be closer to square.
   float power = std::log2(std::sqrt(static_cast<float>(num_pixels)));
   w = 1 << static_cast<int>(power);
   int h = (num_pixels + w - 1) / w;
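
As a worked example of the formula above: num_pixels = 10000 gives sqrt(10000) = 100 and log2(100) ≈ 6.64, so w = 1 << 6 = 64 and h = (10000 + 63) / 64 = 157; the resulting 64x157 texture stays near square, whereas a fixed width of 256 would give a 256x40 strip.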
@@ -326,7 +326,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
   auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
   AllocateOpenGlBuffer();
   if (!(valid_ & kValidOpenGlBuffer)) {
-    // If the call succeds then AHWB -> SSBO are synchronized so any usage of
+    // If the call succeeds then AHWB -> SSBO are synchronized so any usage of
     // the SSBO is correct after this call.
     if (!InsertAhwbToSsboFence()) {
       glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
@@ -348,8 +348,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
   };
 }
 
-Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const {
+Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView(
+    uint64_t source_location_hash) const {
   auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
+  TrackAhwbUsage(source_location_hash);
   AllocateOpenGlBuffer();
   valid_ = kValidOpenGlBuffer;
   return {opengl_buffer_, std::move(lock), nullptr};
@@ -385,6 +387,7 @@ void Tensor::Move(Tensor* src) {
   src->element_type_ = ElementType::kNone;  // Mark as invalidated.
   cpu_buffer_ = src->cpu_buffer_;
   src->cpu_buffer_ = nullptr;
+  ahwb_tracking_key_ = src->ahwb_tracking_key_;
 #if MEDIAPIPE_METAL_ENABLED
   device_ = src->device_;
   src->device_ = nil;
@@ -589,8 +592,10 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
   return {cpu_buffer_, std::move(lock)};
 }
 
-Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
+Tensor::CpuWriteView Tensor::GetCpuWriteView(
+    uint64_t source_location_hash) const {
   auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
+  TrackAhwbUsage(source_location_hash);
   AllocateCpuBuffer();
   valid_ = kValidCpu;
 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
@@ -620,24 +625,4 @@ void Tensor::AllocateCpuBuffer() const {
   }
 }
 
-void Tensor::SetPreferredStorageType(StorageType type) {
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-  if (__builtin_available(android 26, *)) {
-    use_ahwb_ = type == StorageType::kAhwb;
-    VLOG(4) << "Tensor: use of AHardwareBuffer is "
-            << (use_ahwb_ ? "allowed" : "not allowed");
-  }
-#else
-  VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed";
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB
-}
-
-Tensor::StorageType Tensor::GetPreferredStorageType() {
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-  return use_ahwb_ ? StorageType::kAhwb : StorageType::kDefault;
-#else
-  return StorageType::kDefault;
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB
-}
-
 }  // namespace mediapipe

View File

@ -24,8 +24,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
@ -48,6 +49,22 @@
#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_context.h"
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#if defined __has_builtin
#if __has_builtin(__builtin_LINE)
#define builtin_LINE __builtin_LINE
#endif
#if __has_builtin(__builtin_FILE)
#define builtin_FILE __builtin_FILE
#endif
#endif
#ifndef builtin_LINE
#define builtin_LINE() 0
#endif
#ifndef builtin_FILE
#define builtin_FILE() ""
#endif
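These fallback macros feed the default arguments of the view getters declared further down. A minimal standalone sketch of the idiom (hypothetical names; assumes a GCC/Clang toolchain where the builtins exist): a default argument is evaluated at the call site, so every caller gets a stable, distinct identifier without passing anything explicitly.

    #include <cstdint>
    #include <cstdio>

    constexpr uint64_t kPrime = 0x00000100000001B3;
    constexpr uint64_t kBias = 0xcbf29ce484222325;
    // FNV-1a over a string, seeded with an arbitrary value (here: the line).
    constexpr uint64_t Fnv(const char* s, uint64_t h = kBias) {
      return s[0] == 0 ? h : Fnv(s + 1, (h ^ s[0]) * kPrime);
    }

    void GetWriteView(uint64_t loc = Fnv(__builtin_FILE(), __builtin_LINE())) {
      std::printf("call-site id: %llu\n", static_cast<unsigned long long>(loc));
    }

    int main() {
      GetWriteView();  // id derived from this file and this line
      GetWriteView();  // next line, so a different id
      return 0;
    }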
namespace mediapipe { namespace mediapipe {
// Tensor is a container of multi-dimensional data that supports sharing the // Tensor is a container of multi-dimensional data that supports sharing the
@ -65,7 +82,7 @@ namespace mediapipe {
// GLuint buffer = view.buffer(); // GLuint buffer = view.buffer();
// Then the buffer can be bound to the GPU command buffer. // Then the buffer can be bound to the GPU command buffer.
// ...binding the buffer to the command buffer... // ...binding the buffer to the command buffer...
// ...commiting command buffer and releasing the view... // ...committing command buffer and releasing the view...
// //
// The following request for the CPU view will be blocked until the GPU view is // The following request for the CPU view will be blocked until the GPU view is
// released and the GPU task is finished. // released and the GPU task is finished.
@ -161,7 +178,9 @@ class Tensor {
using CpuReadView = CpuView<const void>; using CpuReadView = CpuView<const void>;
CpuReadView GetCpuReadView() const; CpuReadView GetCpuReadView() const;
using CpuWriteView = CpuView<void>; using CpuWriteView = CpuView<void>;
CpuWriteView GetCpuWriteView() const; CpuWriteView GetCpuWriteView(
uint64_t source_location_hash =
tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
// TODO: id<MTLBuffer> vs. MtlBufferView. // TODO: id<MTLBuffer> vs. MtlBufferView.
@ -305,7 +324,9 @@ class Tensor {
// A valid OpenGL context must be bound to the calling thread due to possible // A valid OpenGL context must be bound to the calling thread due to possible
// GPU resource allocation. // GPU resource allocation.
OpenGlBufferView GetOpenGlBufferReadView() const; OpenGlBufferView GetOpenGlBufferReadView() const;
OpenGlBufferView GetOpenGlBufferWriteView() const; OpenGlBufferView GetOpenGlBufferWriteView(
uint64_t source_location_hash =
tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
const Shape& shape() const { return shape_; } const Shape& shape() const { return shape_; }
@ -408,9 +429,13 @@ class Tensor {
mutable std::function<void()> release_callback_; mutable std::function<void()> release_callback_;
bool AllocateAHardwareBuffer(int size_alignment = 0) const; bool AllocateAHardwareBuffer(int size_alignment = 0) const;
void CreateEglSyncAndFd() const; void CreateEglSyncAndFd() const;
// Use Ahwb for other views: OpenGL / CPU buffer.
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
static inline bool use_ahwb_ = false; // Use Ahwb for other views: OpenGL / CPU buffer.
mutable bool use_ahwb_ = false;
mutable uint64_t ahwb_tracking_key_ = 0;
// TODO: Tracks all unique tensors. Can grow to a large number. An LRU
// eviction policy would make growth more predictable.
static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
// Expects the target SSBO to be already bound. // Expects the target SSBO to be already bound.
bool AllocateAhwbMapToSsbo() const; bool AllocateAhwbMapToSsbo() const;
bool InsertAhwbToSsboFence() const; bool InsertAhwbToSsboFence() const;
@ -419,6 +444,8 @@ class Tensor {
void* MapAhwbToCpuRead() const; void* MapAhwbToCpuRead() const;
void* MapAhwbToCpuWrite() const; void* MapAhwbToCpuWrite() const;
void MoveCpuOrSsboToAhwb() const; void MoveCpuOrSsboToAhwb() const;
// Sets the current tracking key and enables "use ahwb" if the key is already marked.
void TrackAhwbUsage(uint64_t key) const;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
mutable std::shared_ptr<mediapipe::GlContext> gl_context_; mutable std::shared_ptr<mediapipe::GlContext> gl_context_;

View File

@ -212,9 +212,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
CHECK(!(valid_ & kValidOpenGlTexture2d)) CHECK(!(valid_ & kValidOpenGlTexture2d))
<< "Tensor conversion between OpenGL texture and AHardwareBuffer is not " << "Tensor conversion between OpenGL texture and AHardwareBuffer is not "
"supported."; "supported.";
CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
<< "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
"supported on target system.";
bool transfer = !ahwb_; bool transfer = !ahwb_;
CHECK(AllocateAHardwareBuffer()) CHECK(AllocateAHardwareBuffer())
<< "AHardwareBuffer is not supported on the target system."; << "AHardwareBuffer is not supported on the target system.";
@ -268,6 +265,10 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
} }
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
// Mark the current tracking key as an Ahwb use site.
ahwb_usage_track_.insert(ahwb_tracking_key_);
use_ahwb_ = true;
if (__builtin_available(android 26, *)) { if (__builtin_available(android 26, *)) {
if (ahwb_ == nullptr) { if (ahwb_ == nullptr) {
AHardwareBuffer_Desc desc = {}; AHardwareBuffer_Desc desc = {};
@ -315,7 +316,13 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
CHECK(error == 0) << "AHardwareBuffer_lock " << error; CHECK(error == 0) << "AHardwareBuffer_lock " << error;
} }
if (valid_ & kValidOpenGlBuffer) { if (valid_ & kValidCpu) {
std::memcpy(dest, cpu_buffer_, bytes());
// Free CPU memory because next time AHWB is mapped instead.
free(cpu_buffer_);
cpu_buffer_ = nullptr;
valid_ &= ~kValidCpu;
} else if (valid_ & kValidOpenGlBuffer) {
gl_context_->Run([this, dest]() { gl_context_->Run([this, dest]() {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
@ -326,11 +333,9 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
}); });
opengl_buffer_ = GL_INVALID_INDEX; opengl_buffer_ = GL_INVALID_INDEX;
gl_context_ = nullptr; gl_context_ = nullptr;
} else if (valid_ & kValidCpu) { // Reset OpenGL Buffer validness. The OpenGL buffer will be allocated on top
std::memcpy(dest, cpu_buffer_, bytes()); // of the Ahwb at the next request to the OpenGlBufferView.
// Free CPU memory because next time AHWB is mapped instead. valid_ &= ~kValidOpenGlBuffer;
free(cpu_buffer_);
cpu_buffer_ = nullptr;
} else { } else {
LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB.";
} }
@ -446,6 +451,16 @@ void* Tensor::MapAhwbToCpuWrite() const {
return nullptr; return nullptr;
} }
void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
if (ahwb_tracking_key_ == 0) {
ahwb_tracking_key_ = source_location_hash;
for (int dim : shape_.dims) {
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
}
}
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
}
#else // MEDIAPIPE_TENSOR_USE_AHWB #else // MEDIAPIPE_TENSOR_USE_AHWB
bool Tensor::AllocateAhwbMapToSsbo() const { return false; } bool Tensor::AllocateAhwbMapToSsbo() const { return false; }
@ -454,6 +469,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {}
void Tensor::ReleaseAhwbStuff() {} void Tensor::ReleaseAhwbStuff() {}
void* Tensor::MapAhwbToCpuRead() const { return nullptr; } void* Tensor::MapAhwbToCpuRead() const { return nullptr; }
void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } void* Tensor::MapAhwbToCpuWrite() const { return nullptr; }
void Tensor::TrackAhwbUsage(uint64_t key) const {}
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
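Taken together, TrackAhwbUsage and the marking in AllocateAHardwareBuffer implement a per-call-site heuristic: the first write view fixes a tracking key from the caller's file/line hash plus the tensor's dimensions, and once a tensor with that key actually allocates an AHardwareBuffer the key is remembered, so subsequent tensors created at the same site start in AHWB storage directly instead of migrating later. A simplified standalone sketch of the mechanism (hypothetical names; std::set standing in for absl::flat_hash_set):

    #include <cstdint>
    #include <set>

    // Global registry of call sites observed to require AHWB storage.
    static std::set<uint64_t>& AhwbSites() {
      static std::set<uint64_t> sites;
      return sites;
    }

    struct TrackedTensor {
      uint64_t key = 0;
      bool use_ahwb = false;

      // Invoked from every write-view getter with the caller's location hash.
      void Track(uint64_t site_key) {
        if (key == 0) key = site_key;           // first view fixes the key
        use_ahwb = AhwbSites().count(key) > 0;  // replay an earlier decision
      }

      // Invoked when an AHardwareBuffer view is actually requested.
      void MarkAhwb() {
        AhwbSites().insert(key);  // tensors from this site now start as AHWB
        use_ahwb = true;
      }
    };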

View File

@ -152,6 +152,36 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
{ {
auto view = tensor.GetAHardwareBufferReadView(); auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr); EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr);
std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference));
}
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
// Request the GPU view to get the ssbo allocated internally.
// Then request the Ahwb view to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name();
EXPECT_GT(ssbo_name, 0);
FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
tensor.element_type());
});
{
auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
} }
auto ptr = tensor.GetCpuReadView().buffer<float>(); auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr); EXPECT_NE(ptr, nullptr);

View File

@ -1,71 +0,0 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor_buffer.h"
#include "mediapipe/framework/formats/tensor_internal.h"
#include "mediapipe/framework/formats/tensor_v2.h"
namespace mediapipe {
// Supports:
// - float 16 and 32 bits
// - signed / unsigned integers 8,16,32 bits
class TensorHardwareBufferView;
struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor {
using ViewT = TensorHardwareBufferView;
TensorBufferDescriptor buffer;
};
class TensorHardwareBufferView : public Tensor::View {
public:
TENSOR_UNIQUE_VIEW_TYPE_ID();
~TensorHardwareBufferView() = default;
const TensorHardwareBufferViewDescriptor& descriptor() const override {
return descriptor_;
}
AHardwareBuffer* handle() const { return ahwb_handle_; }
protected:
TensorHardwareBufferView(int access_capability, Tensor::View::Access access,
Tensor::View::State state,
const TensorHardwareBufferViewDescriptor& desc,
AHardwareBuffer* ahwb_handle)
: Tensor::View(kId, access_capability, access, state),
descriptor_(desc),
ahwb_handle_(ahwb_handle) {}
private:
bool MatchDescriptor(
uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor) const override {
if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor))
return false;
auto descriptor =
static_cast<const TensorHardwareBufferViewDescriptor&>(base_descriptor);
return descriptor.buffer.format == descriptor_.buffer.format &&
descriptor.buffer.size_alignment <=
descriptor_.buffer.size_alignment &&
descriptor_.buffer.size_alignment %
descriptor.buffer.size_alignment ==
0;
}
const TensorHardwareBufferViewDescriptor& descriptor_;
AHardwareBuffer* ahwb_handle_ = nullptr;
};
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_

View File

@ -1,216 +0,0 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <cstdint>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/framework/formats/tensor_backend.h"
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "util/task/status_macros.h"
namespace mediapipe {
namespace {
class TensorCpuViewImpl : public TensorCpuView {
public:
TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access,
Tensor::View::State state,
const TensorCpuViewDescriptor& descriptor, void* pointer,
AHardwareBuffer* ahwb_handle)
: TensorCpuView(access_capabilities, access, state, descriptor, pointer),
ahwb_handle_(ahwb_handle) {}
~TensorCpuViewImpl() {
// If handle_ is null then this view is constructed in GetViews with no
// access.
if (ahwb_handle_) {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_unlock(ahwb_handle_, nullptr);
}
}
}
private:
AHardwareBuffer* ahwb_handle_;
};
class TensorHardwareBufferViewImpl : public TensorHardwareBufferView {
public:
TensorHardwareBufferViewImpl(
int access_capability, Tensor::View::Access access,
Tensor::View::State state,
const TensorHardwareBufferViewDescriptor& descriptor,
AHardwareBuffer* handle)
: TensorHardwareBufferView(access_capability, access, state, descriptor,
handle) {}
~TensorHardwareBufferViewImpl() = default;
};
class HardwareBufferCpuStorage : public TensorStorage {
public:
~HardwareBufferCpuStorage() {
if (!ahwb_handle_) return;
if (__builtin_available(android 26, *)) {
AHardwareBuffer_release(ahwb_handle_);
}
}
static absl::Status CanProvide(
int access_capability, const Tensor::Shape& shape, uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor) {
// TODO: use AHardwareBuffer_isSupported for API >= 29.
static const bool is_ahwb_supported = [] {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {};
// Aligned to the largest possible virtual memory page size.
constexpr uint32_t kPageSize = 16384;
desc.width = kPageSize;
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
AHardwareBuffer* handle;
if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
AHardwareBuffer_release(handle);
return true;
}
return false;
}();
if (!is_ahwb_supported) {
return absl::UnavailableError(
"AHardwareBuffer is not supported on the platform.");
}
if (view_type_id != TensorCpuView::kId &&
view_type_id != TensorHardwareBufferView::kId) {
return absl::InvalidArgumentError(
"A view type is not supported by this storage.");
}
return absl::OkStatus();
}
std::vector<std::unique_ptr<Tensor::View>> GetViews(uint64_t latest_version) {
std::vector<std::unique_ptr<Tensor::View>> result;
auto update_state = latest_version == version_
? Tensor::View::State::kUpToDate
: Tensor::View::State::kOutdated;
if (ahwb_handle_) {
result.push_back(
std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
hw_descriptor_, ahwb_handle_)));
result.push_back(std::unique_ptr<Tensor::View>(new TensorCpuViewImpl(
kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
cpu_descriptor_, nullptr, nullptr)));
}
return result;
}
absl::StatusOr<std::unique_ptr<Tensor::View>> GetView(
Tensor::View::Access access, const Tensor::Shape& shape,
uint64_t latest_version, uint64_t view_type_id,
const Tensor::ViewDescriptor& base_descriptor, int access_capability) {
MP_RETURN_IF_ERROR(
CanProvide(access_capability, shape, view_type_id, base_descriptor));
const auto& buffer_descriptor =
view_type_id == TensorHardwareBufferView::kId
? static_cast<const TensorHardwareBufferViewDescriptor&>(
base_descriptor)
.buffer
: static_cast<const TensorCpuViewDescriptor&>(base_descriptor)
.buffer;
if (!ahwb_handle_) {
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {};
desc.width = TensorBufferSize(buffer_descriptor, shape);
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
// TODO: Use access capabilities to set hints.
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_);
if (error != 0) {
return absl::UnknownError(
absl::StrCat("Error allocating hardware buffer: ", error));
}
// Fill all possible views to provide it as proto views.
hw_descriptor_.buffer = buffer_descriptor;
cpu_descriptor_.buffer = buffer_descriptor;
}
}
if (buffer_descriptor.format != hw_descriptor_.buffer.format ||
buffer_descriptor.size_alignment >
hw_descriptor_.buffer.size_alignment ||
hw_descriptor_.buffer.size_alignment %
buffer_descriptor.size_alignment >
0) {
return absl::AlreadyExistsError(
"A view with different params is already allocated with this "
"storage");
}
absl::StatusOr<std::unique_ptr<Tensor::View>> result;
if (view_type_id == TensorHardwareBufferView::kId) {
result = GetAhwbView(access, shape, base_descriptor);
} else {
result = GetCpuView(access, shape, base_descriptor);
}
if (result.ok()) version_ = latest_version;
return result;
}
private:
absl::StatusOr<std::unique_ptr<Tensor::View>> GetAhwbView(
Tensor::View::Access access, const Tensor::Shape& shape,
const Tensor::ViewDescriptor& base_descriptor) {
return std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
kAccessCapability, access, Tensor::View::State::kUpToDate,
hw_descriptor_, ahwb_handle_));
}
absl::StatusOr<std::unique_ptr<Tensor::View>> GetCpuView(
Tensor::View::Access access, const Tensor::Shape& shape,
const Tensor::ViewDescriptor& base_descriptor) {
void* pointer = nullptr;
if (__builtin_available(android 26, *)) {
int error =
AHardwareBuffer_lock(ahwb_handle_,
access == Tensor::View::Access::kWriteOnly
? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN
: AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
-1, nullptr, &pointer);
if (error != 0) {
return absl::UnknownError(
absl::StrCat("Error locking hardware buffer: ", error));
}
}
return std::unique_ptr<Tensor::View>(
new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly
? Tensor::View::AccessCapability::kWrite
: Tensor::View::AccessCapability::kRead,
access, Tensor::View::State::kUpToDate,
cpu_descriptor_, pointer, ahwb_handle_));
}
static constexpr int kAccessCapability =
Tensor::View::AccessCapability::kRead |
Tensor::View::AccessCapability::kWrite;
TensorHardwareBufferViewDescriptor hw_descriptor_;
AHardwareBuffer* ahwb_handle_ = nullptr;
TensorCpuViewDescriptor cpu_descriptor_;
uint64_t version_ = 0;
};
TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage);
} // namespace
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -1,76 +0,0 @@
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <cstdint>
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
namespace {
class TensorHardwareBufferTest : public ::testing::Test {
public:
TensorHardwareBufferTest() {}
~TensorHardwareBufferTest() override {}
};
TEST_F(TensorHardwareBufferTest, TestFloat32) {
Tensor tensor{Tensor::Shape({1})};
{
MP_ASSERT_OK_AND_ASSIGN(
auto view,
tensor.GetView<Tensor::View::Access::kWriteOnly>(
TensorHardwareBufferViewDescriptor{
.buffer = {.format =
TensorBufferDescriptor::Format::kFloat32}}));
EXPECT_NE(view->handle(), nullptr);
}
{
const auto& const_tensor = tensor;
MP_ASSERT_OK_AND_ASSIGN(
auto view,
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
TensorCpuViewDescriptor{
.buffer = {.format =
TensorBufferDescriptor::Format::kFloat32}}));
EXPECT_NE(view->data<void>(), nullptr);
}
}
TEST_F(TensorHardwareBufferTest, TestInt8Padding) {
Tensor tensor{Tensor::Shape({1})};
{
MP_ASSERT_OK_AND_ASSIGN(
auto view,
tensor.GetView<Tensor::View::Access::kWriteOnly>(
TensorHardwareBufferViewDescriptor{
.buffer = {.format = TensorBufferDescriptor::Format::kInt8,
.size_alignment = 4}}));
EXPECT_NE(view->handle(), nullptr);
}
{
const auto& const_tensor = tensor;
MP_ASSERT_OK_AND_ASSIGN(
auto view,
const_tensor.GetView<Tensor::View::Access::kReadOnly>(
TensorCpuViewDescriptor{
.buffer = {.format = TensorBufferDescriptor::Format::kInt8}}));
EXPECT_NE(view->data<void>(), nullptr);
}
}
} // namespace
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -18,8 +18,6 @@
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe { namespace mediapipe {
// Generates unique view id at compile-time using FILE and LINE. // Generates unique view id at compile-time using FILE and LINE.
@ -41,10 +39,12 @@ namespace tensor_internal {
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
constexpr uint64_t kFnvPrime = 0x00000100000001B3; constexpr uint64_t kFnvPrime = 0x00000100000001B3;
constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325; constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325;
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) {
return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime); return (value2 ^ value1) * kFnvPrime;
}
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) {
return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0]));
} }
template <typename... Ts> template <typename... Ts>
struct TypeList { struct TypeList {
static constexpr std::size_t size{sizeof...(Ts)}; static constexpr std::size_t size{sizeof...(Ts)};
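A quick compile-time sanity check of the refactored hash (illustrative only; calls assume namespace tensor_internal is open or qualified). The string overload now chains through the new two-argument overload, which is also what lets TrackAhwbUsage fold shape dimensions into a location key:

    static_assert(FnvHash64("tensor") == FnvHash64("tensor"),
                  "same input, same hash");
    static_assert(FnvHash64("tensor") != FnvHash64("Tensor"),
                  "hash is byte-sensitive");
    // Chaining extra values, as TrackAhwbUsage does with shape dimensions:
    static_assert(FnvHash64(FnvHash64("file.cc"), 42) !=
                      FnvHash64(FnvHash64("file.cc"), 43),
                  "different dims produce different keys");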

View File

@ -88,8 +88,8 @@ cc_library(
srcs = ["default_input_stream_handler.cc"], srcs = ["default_input_stream_handler.cc"],
hdrs = ["default_input_stream_handler.h"], hdrs = ["default_input_stream_handler.h"],
deps = [ deps = [
":default_input_stream_handler_cc_proto",
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework:input_stream_handler",
"//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
@ -110,8 +110,8 @@ cc_library(
srcs = ["fixed_size_input_stream_handler.cc"], srcs = ["fixed_size_input_stream_handler.cc"],
deps = [ deps = [
":default_input_stream_handler", ":default_input_stream_handler",
":fixed_size_input_stream_handler_cc_proto",
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework:input_stream_handler",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -159,13 +159,13 @@ cc_library(
name = "sync_set_input_stream_handler", name = "sync_set_input_stream_handler",
srcs = ["sync_set_input_stream_handler.cc"], srcs = ["sync_set_input_stream_handler.cc"],
deps = [ deps = [
":sync_set_input_stream_handler_cc_proto",
"//mediapipe/framework:collection", "//mediapipe/framework:collection",
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework:input_stream_handler",
"//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:packet_set", "//mediapipe/framework:packet_set",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto",
"//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
@ -177,10 +177,10 @@ cc_library(
name = "timestamp_align_input_stream_handler", name = "timestamp_align_input_stream_handler",
srcs = ["timestamp_align_input_stream_handler.cc"], srcs = ["timestamp_align_input_stream_handler.cc"],
deps = [ deps = [
":timestamp_align_input_stream_handler_cc_proto",
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework:input_stream_handler",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto",
"//mediapipe/framework/tool:validate_name", "//mediapipe/framework/tool:validate_name",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
@ -243,6 +243,7 @@ cc_test(
srcs = ["set_input_stream_handler_test.cc"], srcs = ["set_input_stream_handler_test.cc"],
deps = [ deps = [
":fixed_size_input_stream_handler", ":fixed_size_input_stream_handler",
":fixed_size_input_stream_handler_cc_proto",
":mux_input_stream_handler", ":mux_input_stream_handler",
"//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
@ -251,7 +252,6 @@ cc_test(
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
], ],
) )
@ -272,13 +272,13 @@ cc_test(
srcs = ["fixed_size_input_stream_handler_test.cc"], srcs = ["fixed_size_input_stream_handler_test.cc"],
deps = [ deps = [
":fixed_size_input_stream_handler", ":fixed_size_input_stream_handler",
":fixed_size_input_stream_handler_cc_proto",
"//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
], ],
@ -289,11 +289,11 @@ cc_test(
srcs = ["sync_set_input_stream_handler_test.cc"], srcs = ["sync_set_input_stream_handler_test.cc"],
deps = [ deps = [
":sync_set_input_stream_handler", ":sync_set_input_stream_handler",
":sync_set_input_stream_handler_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:test_calculators", "//mediapipe/framework:test_calculators",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -299,6 +299,7 @@ mediapipe_cc_test(
requires_full_emulation = False, requires_full_emulation = False,
deps = [ deps = [
":node_chain_subgraph_cc_proto", ":node_chain_subgraph_cc_proto",
":node_chain_subgraph_options_lib",
":options_field_util", ":options_field_util",
":options_registry", ":options_registry",
":options_syntax_util", ":options_syntax_util",
@ -313,7 +314,6 @@ mediapipe_cc_test(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_cc_proto",
"//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/testdata:night_light_calculator_options_lib",
"//mediapipe/framework/tool:node_chain_subgraph_options_lib",
"//mediapipe/util:header_util", "//mediapipe/util:header_util",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -422,9 +422,9 @@ cc_library(
srcs = ["source.cc"], srcs = ["source.cc"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":source_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:source_cc_proto",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -485,13 +485,13 @@ cc_library(
hdrs = ["template_expander.h"], hdrs = ["template_expander.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":calculator_graph_template_cc_proto",
":proto_util_lite", ":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:numbers", "//mediapipe/framework/port:numbers",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:calculator_graph_template_cc_proto",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -506,6 +506,7 @@ cc_library(
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":calculator_graph_template_cc_proto",
":proto_util_lite", ":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto",
@ -515,7 +516,6 @@ cc_library(
"//mediapipe/framework/port:map_util", "//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:calculator_graph_template_cc_proto",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
@ -661,8 +661,8 @@ cc_library(
hdrs = ["simulation_clock_executor.h"], hdrs = ["simulation_clock_executor.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":simulation_clock",
"//mediapipe/framework:thread_pool_executor", "//mediapipe/framework:thread_pool_executor",
"//mediapipe/framework/tool:simulation_clock",
], ],
) )
@ -789,10 +789,10 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":name_util", ":name_util",
":switch_container_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container_cc_proto",
], ],
) )
@ -805,6 +805,7 @@ cc_library(
deps = [ deps = [
":container_util", ":container_util",
":options_util", ":options_util",
":switch_container_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework/deps:mathutil", "//mediapipe/framework/deps:mathutil",
@ -814,7 +815,6 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler", "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:switch_container_cc_proto",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
@ -841,6 +841,7 @@ cc_library(
], ],
deps = [ deps = [
":container_util", ":container_util",
":switch_container_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework:input_stream_shard", "//mediapipe/framework:input_stream_shard",
@ -850,7 +851,6 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler", "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:switch_container_cc_proto",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -893,6 +893,7 @@ cc_library(
":container_util", ":container_util",
":name_util", ":name_util",
":subgraph_expansion", ":subgraph_expansion",
":switch_container_cc_proto",
":switch_demux_calculator", ":switch_demux_calculator",
":switch_mux_calculator", ":switch_mux_calculator",
"//mediapipe/calculators/core:packet_sequencer_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator",
@ -904,7 +905,6 @@ cc_library(
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container_cc_proto",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,

View File

@ -564,6 +564,7 @@ cc_library(
name = "gpu_shared_data_internal_stub", name = "gpu_shared_data_internal_stub",
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
":gl_context_options_cc_proto",
":graph_support", ":graph_support",
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_node",
@ -571,7 +572,6 @@ cc_library(
"//mediapipe/framework:port", "//mediapipe/framework:port",
"//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/gpu:gl_context_options_cc_proto",
], ],
) )
@ -592,7 +592,7 @@ cc_library(
}), }),
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
"//mediapipe/gpu:gl_context_options_cc_proto", ":gl_context_options_cc_proto",
":graph_support", ":graph_support",
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:executor", "//mediapipe/framework:executor",
@ -833,10 +833,10 @@ cc_library(
deps = [ deps = [
":gl_base", ":gl_base",
":gl_simple_shaders", ":gl_simple_shaders",
":scale_mode_cc_proto",
":shader_util", ":shader_util",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/gpu:scale_mode_cc_proto",
], ],
) )
@ -907,8 +907,8 @@ proto_library(
srcs = ["gl_scaler_calculator.proto"], srcs = ["gl_scaler_calculator.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":scale_mode_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
], ],
) )
@ -930,6 +930,7 @@ cc_library(
deps = [ deps = [
":gl_calculator_helper", ":gl_calculator_helper",
":gl_quad_renderer", ":gl_quad_renderer",
":gl_scaler_calculator_cc_proto",
":gl_simple_shaders", ":gl_simple_shaders",
":shader_util", ":shader_util",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
@ -937,7 +938,6 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util", "//mediapipe/framework/tool:options_util",
"//mediapipe/gpu:gl_scaler_calculator_cc_proto",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -950,13 +950,13 @@ cc_library(
":egl_surface_holder", ":egl_surface_holder",
":gl_calculator_helper", ":gl_calculator_helper",
":gl_quad_renderer", ":gl_quad_renderer",
":gl_surface_sink_calculator_cc_proto",
":gpu_buffer", ":gpu_buffer",
":shader_util", ":shader_util",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
], ],
alwayslink = 1, alwayslink = 1,
@ -966,8 +966,8 @@ proto_library(
name = "gl_surface_sink_calculator_proto", name = "gl_surface_sink_calculator_proto",
srcs = ["gl_surface_sink_calculator.proto"], srcs = ["gl_surface_sink_calculator.proto"],
deps = [ deps = [
":scale_mode_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
], ],
) )

View File

@ -15,10 +15,13 @@
package com.google.mediapipe.framework; package com.google.mediapipe.framework;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import android.graphics.PixelFormat;
import android.media.Image;
import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties; import com.google.mediapipe.framework.image.MPImageProperties;
import com.google.mediapipe.framework.image.MediaImageExtractor;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
// TODO: use Preconditions in this file. // TODO: use Preconditions in this file.
@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator {
} }
return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
} }
if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) {
Image mediaImage = MediaImageExtractor.extract(image);
if (mediaImage.getFormat() != PixelFormat.RGBA_8888) {
throw new UnsupportedOperationException("Android media image must use RGBA_8888 config.");
}
return createImage(
mediaImage.getPlanes()[0].getBuffer(),
mediaImage.getWidth(),
mediaImage.getHeight(),
/* numChannels= */ 4);
}
// Unsupported type. // Unsupported type.
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Unsupported Image container type: " + properties.getStorageType()); "Unsupported Image container type: " + properties.getStorageType());

View File

@ -14,6 +14,10 @@
package com.google.mediapipe.framework; package com.google.mediapipe.framework;
import com.google.common.flogger.FluentLogger;
import java.util.HashSet;
import java.util.Set;
/** /**
* A {@link TextureFrame} that represents a texture produced by MediaPipe. * A {@link TextureFrame} that represents a texture produced by MediaPipe.
* *
@ -21,6 +25,7 @@ package com.google.mediapipe.framework;
* method. * method.
*/ */
public class GraphTextureFrame implements TextureFrame { public class GraphTextureFrame implements TextureFrame {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private long nativeBufferHandle; private long nativeBufferHandle;
// We cache these to be able to get them without a JNI call. // We cache these to be able to get them without a JNI call.
private int textureName; private int textureName;
@ -30,6 +35,8 @@ public class GraphTextureFrame implements TextureFrame {
// True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait // True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait
// when calling getTextureName(). // when calling getTextureName().
private final boolean deferredSync; private final boolean deferredSync;
private final Set<Long> activeConsumerContextHandleSet = new HashSet<>();
private int refCount = 1;
GraphTextureFrame(long nativeHandle, long timestamp) { GraphTextureFrame(long nativeHandle, long timestamp) {
this(nativeHandle, timestamp, false); this(nativeHandle, timestamp, false);
@ -54,11 +61,12 @@ public class GraphTextureFrame implements TextureFrame {
* condition if release() is called after the if-check for nativeBufferHandle is already passed. * condition if release() is called after the if-check for nativeBufferHandle is already passed.
*/ */
@Override @Override
public int getTextureName() { public synchronized int getTextureName() {
// Return special texture id 0 if handle is 0 i.e. frame is already released. // Return special texture id 0 if handle is 0 i.e. frame is already released.
if (nativeBufferHandle == 0) { if (nativeBufferHandle == 0) {
return 0; return 0;
} }
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
// PacketGetter.getTextureFrameDeferredSync(). // PacketGetter.getTextureFrameDeferredSync().
if (deferredSync) { if (deferredSync) {
@ -66,6 +74,7 @@ public class GraphTextureFrame implements TextureFrame {
// cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
nativeGpuWait(nativeBufferHandle); nativeGpuWait(nativeBufferHandle);
} }
}
return textureName; return textureName;
} }
@ -86,15 +95,31 @@ public class GraphTextureFrame implements TextureFrame {
return timestamp; return timestamp;
} }
@Override
public boolean supportsRetain() {
return true;
}
@Override
public synchronized void retain() {
// TODO: check that refCount is > 0 and handle is not 0.
refCount++;
}
/** /**
* Releases a reference to the underlying buffer. * Releases a reference to the underlying buffer.
* *
* <p>The consumer calls this when it is done using the texture. * <p>The consumer calls this when it is done using the texture.
*/ */
@Override @Override
public void release() { public synchronized void release() {
GlSyncToken consumerToken = GlSyncToken consumerToken = null;
// Note that this remove should be moved to the other overload of release when b/68808951 is
// addressed.
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) {
consumerToken =
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
}
release(consumerToken); release(consumerToken);
} }
@ -108,18 +133,40 @@ public class GraphTextureFrame implements TextureFrame {
* currently cannot create a GlSyncToken, so they cannot call this method. * currently cannot create a GlSyncToken, so they cannot call this method.
*/ */
@Override @Override
public void release(GlSyncToken consumerSyncToken) { public synchronized void release(GlSyncToken consumerSyncToken) {
if (nativeBufferHandle != 0) { if (nativeBufferHandle == 0) {
long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken();
nativeReleaseBuffer(nativeBufferHandle, token);
nativeBufferHandle = 0;
}
if (consumerSyncToken != null) { if (consumerSyncToken != null) {
logger.atWarning().log("release with sync token, but handle is 0");
}
return;
}
if (consumerSyncToken != null) {
long token = consumerSyncToken.nativeToken();
nativeDidRead(nativeBufferHandle, token);
// We should remove the token's context from activeConsumerContextHandleSet here, but for now
// we do it in the release(void) overload.
consumerSyncToken.release(); consumerSyncToken.release();
} }
refCount--;
if (refCount <= 0) {
nativeReleaseBuffer(nativeBufferHandle);
nativeBufferHandle = 0;
}
} }
private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); @Override
protected void finalize() throws Throwable {
if (refCount >= 0 || nativeBufferHandle != 0) {
logger.atWarning().log("release was not called before finalize");
}
if (!activeConsumerContextHandleSet.isEmpty()) {
logger.atWarning().log("active consumers did not release with sync before finalize");
}
}
private native void nativeReleaseBuffer(long nativeHandle);
private native int nativeGetTextureName(long nativeHandle); private native int nativeGetTextureName(long nativeHandle);
private native int nativeGetWidth(long nativeHandle); private native int nativeGetWidth(long nativeHandle);
@ -128,4 +175,8 @@ public class GraphTextureFrame implements TextureFrame {
private native void nativeGpuWait(long nativeHandle); private native void nativeGpuWait(long nativeHandle);
private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle);
private native long nativeGetCurrentExternalContextHandle();
private native void nativeDidRead(long nativeHandle, long consumerSyncToken);
} }

View File

@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback {
*/ */
@Override @Override
void release(GlSyncToken syncToken); void release(GlSyncToken syncToken);
/**
* If this method returns true, this object supports the retain method, and can be used with
* multiple consumers. Call retain for each additional consumer beyond the first; each consumer
* should call release.
*/
default boolean supportsRetain() {
return false;
}
/** Increments the reference count. Only available with some implementations of TextureFrame. */
default void retain() {
throw new UnsupportedOperationException();
}
} }

View File

@ -15,20 +15,16 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h"
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_buffer.h"
#include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h"
using mediapipe::GlTextureBufferSharedPtr; using mediapipe::GlTextureBufferSharedPtr;
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { JNIEnv* env, jobject thiz, jlong nativeHandle) {
GlTextureBufferSharedPtr* buffer = GlTextureBufferSharedPtr* buffer =
reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle); reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle);
if (consumerSyncToken) {
mediapipe::GlSyncToken& token =
*reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken);
(*buffer)->DidRead(token);
}
delete buffer; delete buffer;
} }
@ -84,3 +80,18 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
} }
return reinterpret_cast<jlong>(token); return reinterpret_cast<jlong>(token);
} }
JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) {
return reinterpret_cast<jlong>(
mediapipe::GlContext::GetCurrentNativeContext());
}
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) {
GlTextureBufferSharedPtr* buffer =
reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle);
mediapipe::GlSyncToken& token =
*reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken);
(*buffer)->DidRead(token);
}

View File

@ -26,7 +26,7 @@ extern "C" {
// Releases a native mediapipe::GpuBuffer. // Releases a native mediapipe::GpuBuffer.
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); JNIEnv* env, jobject thiz, jlong nativeHandle);
JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)(
JNIEnv* env, jobject thiz, jlong nativeHandle); JNIEnv* env, jobject thiz, jlong nativeHandle);
@ -44,6 +44,12 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz,
jlong nativeHandle); jlong nativeHandle);
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)(
JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken);
JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif // __cplusplus #endif // __cplusplus

View File

@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.text import text_classifier from mediapipe.model_maker.python.text import text_classifier
# Remove duplicated and non-public API
del python

View File

@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions TextClassifierOptions = text_classifier_options.TextClassifierOptions
# Remove duplicated and non-public API
del hyperparameters
del dataset
del model_options
del model_spec
del preprocessor # pylint: disable=undefined-variable
del text_classifier
del text_classifier_options

View File

@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
from official.nlp import optimization
def _validate(options: text_classifier_options.TextClassifierOptions): def _validate(options: text_classifier_options.TextClassifierOptions):
@ -417,8 +416,22 @@ class _BertClassifier(TextClassifier):
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1) warmup_steps = int(total_steps * 0.1)
initial_lr = self._hparams.learning_rate initial_lr = self._hparams.learning_rate
self._optimizer = optimization.create_optimizer(initial_lr, total_steps, # Implements linear decay of the learning rate.
warmup_steps) lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=initial_lr,
decay_steps=total_steps,
end_learning_rate=0.0,
power=1.0)
if warmup_steps:
lr_schedule = model_util.WarmUp(
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"])
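The replacement reproduces the usual BERT fine-tuning recipe without the official.nlp dependency: linear warmup over the first 10% of steps, then linear decay to zero, under AdamW with weight decay excluded from LayerNorm and bias variables. As a formula (assuming model_util.WarmUp matches the standard BERT warmup behavior):

    \eta(t) =
    \begin{cases}
      \eta_0 \, t / t_w,     & t < t_w \\
      \eta_0 \, (1 - t / T), & t_w \le t \le T
    \end{cases}
    \qquad T = \text{steps\_per\_epoch} \times \text{epochs}, \quad t_w = \lfloor 0.1\,T \rfloor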
def _save_vocab(self, vocab_filepath: str): def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy( tf.io.gfile.copy(

View File

@ -146,6 +146,8 @@ py_test(
tags = ["notsan"], tags = ["notsan"],
deps = [ deps = [
":gesture_recognizer_import", ":gesture_recognizer_import",
":hyperparameters",
":model_options",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],

View File

@ -25,3 +25,12 @@ HParams = hyperparameters.HParams
Dataset = dataset.Dataset Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
# Remove duplicated and non-public API
del constants # pylint: disable=undefined-variable
del dataset
del gesture_recognizer
del gesture_recognizer_options
del hyperparameters
del metadata_writer # pylint: disable=undefined-variable
del model_options

View File

@@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier):
         batch_size=None,
         dtype=tf.float32,
         name='hand_embedding')
-    x = tf.keras.layers.BatchNormalization()(inputs)
-    x = tf.keras.layers.ReLU()(x)
+    x = inputs
     dropout_rate = self._model_options.dropout_rate
-    x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
+    for i, width in enumerate(self._model_options.layer_widths):
+      x = tf.keras.layers.BatchNormalization()(x)
+      x = tf.keras.layers.ReLU()(x)
+      x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
+      x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
+    x = tf.keras.layers.BatchNormalization()(x)
+    x = tf.keras.layers.ReLU()(x)
+    x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
     outputs = tf.keras.layers.Dense(
         self._num_classes,
         activation='softmax',
-        name='custom_gesture_recognizer')(
+        name='custom_gesture_recognizer_out')(
             x)
     self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
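Read as plain Keras outside the diff, the loop above expands each entry of `layer_widths` into a BatchNorm-ReLU-Dropout-Dense block, with one final block feeding the softmax head. A standalone sketch with made-up sizes (embedding width 128, 4 classes, dropout 0.05):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=[128], dtype=tf.float32)
x = inputs
for i, width in enumerate([64, 32]):  # hypothetical layer_widths
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.ReLU()(x)
  x = tf.keras.layers.Dropout(rate=0.05)(x)
  x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Dropout(rate=0.05)(x)
outputs = tf.keras.layers.Dense(
    4, activation='softmax', name='custom_gesture_recognizer_out')(x)
tf.keras.Model(inputs=inputs, outputs=outputs).summary()
```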

View File

@@ -23,6 +23,8 @@ import tensorflow as tf
 from mediapipe.model_maker.python.core.utils import test_util
 from mediapipe.model_maker.python.vision import gesture_recognizer
+from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
+from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
 from mediapipe.tasks.python.test import test_utils

 _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata'
@@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase):
     self._train_data, self._validation_data = all_data.split(0.9)

   def test_gesture_recognizer_model(self):
-    model_options = gesture_recognizer.ModelOptions()
+    mo = gesture_recognizer.ModelOptions()
     hparams = gesture_recognizer.HParams(
         export_dir=tempfile.mkdtemp(), epochs=2)
     gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
-        model_options=model_options, hparams=hparams)
+        model_options=mo, hparams=hparams)
     model = gesture_recognizer.GestureRecognizer.create(
         train_data=self._train_data,
         validation_data=self._validation_data,
@@ -60,12 +62,38 @@ class GestureRecognizerTest(tf.test.TestCase):
     self._test_accuracy(model)

-  def test_export_gesture_recognizer_model(self):
-    model_options = gesture_recognizer.ModelOptions()
+  @unittest_mock.patch.object(
+      tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
+  def test_gesture_recognizer_model_layer_widths(self, mock_dense):
+    layer_widths = [64, 32]
+    mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
     hparams = gesture_recognizer.HParams(
         export_dir=tempfile.mkdtemp(), epochs=2)
     gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
-        model_options=model_options, hparams=hparams)
+        model_options=mo, hparams=hparams)
+    model = gesture_recognizer.GestureRecognizer.create(
+        train_data=self._train_data,
+        validation_data=self._validation_data,
+        options=gesture_recognizer_options)
+    expected_calls = [
+        unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}')
+        for i, w in enumerate(layer_widths)
+    ]
+    expected_calls.append(
+        unittest_mock.call(
+            len(self._train_data.label_names),
+            activation='softmax',
+            name='custom_gesture_recognizer_out'))
+    self.assertLen(mock_dense.call_args_list, len(expected_calls))
+    mock_dense.assert_has_calls(expected_calls)
+    self._test_accuracy(model)
+
+  def test_export_gesture_recognizer_model(self):
+    mo = gesture_recognizer.ModelOptions()
+    hparams = gesture_recognizer.HParams(
+        export_dir=tempfile.mkdtemp(), epochs=2)
+    gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
+        model_options=mo, hparams=hparams)
     model = gesture_recognizer.GestureRecognizer.create(
         train_data=self._train_data,
         validation_data=self._validation_data,
@@ -102,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase):
     self.assertGreater(accuracy, threshold)

   @unittest_mock.patch.object(
-      gesture_recognizer.hyperparameters,
+      hyperparameters,
       'HParams',
       autospec=True,
       return_value=gesture_recognizer.HParams(epochs=1))
   @unittest_mock.patch.object(
-      gesture_recognizer.model_options,
+      model_options,
       'GestureRecognizerModelOptions',
       autospec=True,
       return_value=gesture_recognizer.ModelOptions())
@@ -122,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase):
     mock_model_options.assert_called_once()

   def test_continual_training_by_loading_checkpoint(self):
-    model_options = gesture_recognizer.ModelOptions()
+    mo = gesture_recognizer.ModelOptions()
     hparams = gesture_recognizer.HParams(
         export_dir=tempfile.mkdtemp(), epochs=2)
     gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
-        model_options=model_options, hparams=hparams)
+        model_options=mo, hparams=hparams)
     mock_stdout = io.StringIO()
     with mock.patch('sys.stdout', mock_stdout):
       model = gesture_recognizer.GestureRecognizer.create(
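The `wraps=` idiom in `test_gesture_recognizer_model_layer_widths` above is worth a note: patching `Dense` with `wraps=tf.keras.layers.Dense` keeps the layer fully functional while recording every constructor call for later assertions. A minimal sketch of the pattern:

```python
import unittest.mock as unittest_mock

import tensorflow as tf

with unittest_mock.patch.object(
    tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) as mock_dense:
  layer = tf.keras.layers.Dense(8, name='probe')  # still builds a real layer
  mock_dense.assert_called_once_with(8, name='probe')
```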

View File

@@ -14,6 +14,7 @@
 """Configurable model options for gesture recognizer models."""

 import dataclasses
+from typing import List


 @dataclasses.dataclass
@@ -23,5 +24,10 @@ class GestureRecognizerModelOptions:
   Attributes:
     dropout_rate: The fraction of the input units to drop, used in dropout
       layer.
+    layer_widths: A list of hidden layer widths for the gesture model. Each
+      element in the list will create a new hidden layer with the specified
+      width. The hidden layers are separated with BatchNorm, Dropout, and
+      ReLU. Defaults to an empty list (no hidden layers).
   """
   dropout_rate: float = 0.05
+  layer_widths: List[int] = dataclasses.field(default_factory=list)
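Hypothetical usage of the new option (values are illustrative; `gesture_recognizer.ModelOptions` is the public alias exercised in the tests above):

```python
from mediapipe.model_maker.python.vision import gesture_recognizer

options = gesture_recognizer.ModelOptions(layer_widths=[64, 32])
print(options.dropout_rate, options.layer_widths)  # 0.05 [64, 32]
# The default (an empty list) keeps only the final softmax layer.
```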

View File

@@ -121,7 +121,9 @@ py_library(
     srcs = ["image_classifier_test.py"],
     data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
     deps = [
+        ":hyperparameters",
         ":image_classifier_import",
+        ":model_options",
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )

View File

@@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions
 ModelSpec = model_spec.ModelSpec
 SupportedModels = model_spec.SupportedModels
 ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
+
+# Remove duplicated and non-public API
+del dataset
+del hyperparameters
+del image_classifier
+del image_classifier_options
+del model_options
+del model_spec
+del train_image_classifier_lib  # pylint: disable=undefined-variable

View File

@@ -24,6 +24,8 @@ import numpy as np
 import tensorflow as tf

 from mediapipe.model_maker.python.vision import image_classifier
+from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
+from mediapipe.model_maker.python.vision.image_classifier import model_options
 from mediapipe.tasks.python.test import test_utils
@@ -159,15 +161,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
     self.assertGreaterEqual(accuracy, threshold)

   @unittest_mock.patch.object(
-      image_classifier.hyperparameters,
+      hyperparameters,
       'HParams',
       autospec=True,
-      return_value=image_classifier.HParams(epochs=1))
+      return_value=hyperparameters.HParams(epochs=1))
   @unittest_mock.patch.object(
-      image_classifier.model_options,
+      model_options,
       'ImageClassifierModelOptions',
       autospec=True,
-      return_value=image_classifier.ModelOptions())
+      return_value=model_options.ImageClassifierModelOptions())
   def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
       self, mock_hparams, mock_model_options):
     options = image_classifier.ImageClassifierOptions(

View File

@@ -1,5 +1,5 @@
 absl-py
-mediapipe==0.9.1
+mediapipe==0.9.0.1
 numpy
 opencv-python
 tensorflow>=2.10

View File

@@ -28,6 +28,8 @@ import PIL.Image
 from mediapipe.python._framework_bindings import image
 from mediapipe.python._framework_bindings import image_frame

+TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
+
 Image = image.Image
 ImageFormat = image_frame.ImageFormat
@@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase):
     gc.collect()
     self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)

+  def test_image_create_from_cvmat(self):
+    image_path = os.path.join(os.path.dirname(__file__),
+                              'solutions/testdata/hands.jpg')
+    mat = cv2.imread(image_path).astype(np.uint8)
+    mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB)
+    rgb_image = Image(image_format=ImageFormat.SRGB, data=mat)
+    self.assertEqual(rgb_image.width, 720)
+    self.assertEqual(rgb_image.height, 382)
+    self.assertEqual(rgb_image.channels, 3)
+    self.assertEqual(rgb_image.image_format, ImageFormat.SRGB)
+    self.assertTrue(np.array_equal(mat, rgb_image.numpy_view()))
+
+  def test_image_create_from_file(self):
+    image_path = os.path.join(os.path.dirname(__file__),
+                              'solutions/testdata/hands.jpg')
+    loaded_image = Image.create_from_file(image_path)
+    self.assertEqual(loaded_image.width, 720)
+    self.assertEqual(loaded_image.height, 382)
+    self.assertEqual(loaded_image.channels, 3)
+    self.assertEqual(loaded_image.image_format, ImageFormat.SRGB)
+
 if __name__ == '__main__':
   absltest.main()
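Outside the test harness, the two creation paths exercised above look like this (a sketch assuming a MediaPipe Python build that exposes `mp.Image`; 'hands.jpg' is a placeholder path):

```python
import cv2
import mediapipe as mp
import numpy as np

mat = cv2.cvtColor(cv2.imread('hands.jpg'), cv2.COLOR_BGR2RGB)
rgb_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mat)
assert np.array_equal(mat, rgb_image.numpy_view())

loaded = mp.Image.create_from_file('hands.jpg')  # format inferred from file
print(loaded.width, loaded.height, loaded.channels)
```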

View File

@@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase):
     p.timestamp = 0
     self.assertAlmostEqual(packet_getter.get_float(p), 0.42)
     self.assertEqual(p.timestamp, 0)
-    p2 = packet_creator.create_float(np.float(0.42))
+    p2 = packet_creator.create_float(float(0.42))
     p2.timestamp = 0
     self.assertAlmostEqual(packet_getter.get_float(p2), 0.42)
     self.assertEqual(p2.timestamp, 0)

View File

@@ -48,16 +48,20 @@ void ImageSubmodule(pybind11::module* module) {
   become immutable after creation.

   Creation examples:
+
+  ```python
     import cv2
     cv_mat = cv2.imread(input_file)[:, :, ::-1]
-    rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat)
+    rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat)
     gray_frame = mp.Image(
-        format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
+        image_format=ImageFormat.GRAY,
+        data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))

     from PIL import Image
     pil_img = Image.new('RGB', (60, 30), color = 'red')
     image = mp.Image(
-        format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
+        image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
+  ```

   The pixel data in an Image can be retrieved as a numpy ndarray by calling
   `Image.numpy_view()`. The returned numpy ndarray is a reference to the
@@ -65,6 +69,8 @@ void ImageSubmodule(pybind11::module* module) {
   numpy ndarray, it's required to obtain a copy of it.

   Pixel data retrieval examples:
+
+  ```python
     for channel in range(num_channel):
       for col in range(width):
         for row in range(height):
@@ -74,6 +80,7 @@ void ImageSubmodule(pybind11::module* module) {
     print(output_ndarray[0, 0, 0])
     copied_ndarray = np.copy(output_ndarray)
     copied_ndarray[0,0,0] = 0
+  ```
   )doc",
           py::dynamic_attr());
@@ -156,9 +163,11 @@ void ImageSubmodule(pybind11::module* module) {
     An unwritable numpy ndarray.

   Examples:
+    ```
     output_ndarray = image.numpy_view()
     copied_ndarray = np.copy(output_ndarray)
     copied_ndarray[0,0,0] = 0
+    ```
   )doc");

   image.def(
@@ -191,10 +200,12 @@ void ImageSubmodule(pybind11::module* module) {
     IndexError: If the index is invalid or out of bounds.

   Examples:
+    ```
     for channel in range(num_channel):
       for col in range(width):
         for row in range(height):
           print(image[row, col, channel])
+    ```
   )doc");

   image
@@ -224,7 +235,9 @@ void ImageSubmodule(pybind11::module* module) {
     A boolean.

   Examples:
+    ```
     image.is_aligned(16)
+    ```
   )doc");

   image.def_static(
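A runnable distillation of the retrieval docs above (a sketch; the 4x4 gray image is made up, assuming a build where `mp.Image` accepts a 2-D uint8 array for GRAY8): `numpy_view()` returns a read-only reference, so mutation requires an explicit copy.

```python
import mediapipe as mp
import numpy as np

img = mp.Image(image_format=mp.ImageFormat.GRAY8,
               data=np.zeros((4, 4), dtype=np.uint8))
view = img.numpy_view()
# view[0, 0] = 1  # raises ValueError: assignment destination is read-only
snapshot = np.copy(view)
snapshot[0, 0] = 1  # fine: the copy is writable
```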

View File

@@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) {
   Creation examples:
     import cv2
     cv_mat = cv2.imread(input_file)[:, :, ::-1]
-    rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat)
+    rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat)
     gray_frame = mp.ImageFrame(
-        format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
+        image_format=ImageFormat.GRAY,
+        data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))

     from PIL import Image
     pil_img = Image.new('RGB', (60, 30), color = 'red')
     image_frame = mp.ImageFrame(
-        format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
+        image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))

   The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling
   `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the

View File

@@ -23,4 +23,3 @@ objc_library(
     ],
     module_name = "MPPCommon",
 )
-

View File

@@ -1,25 +1,25 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 #import <Foundation/Foundation.h>

 NS_ASSUME_NONNULL_BEGIN

 /**
- * @enum TFLSupportErrorCode
- * This enum specifies error codes for TensorFlow Lite Task Library.
- * It maintains a 1:1 mapping to TfLiteSupportErrorCode of C libray.
+ * @enum MPPTasksErrorCode
+ * This enum specifies error codes for the MediaPipe Task Library.
+ * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library.
  */
 typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
@@ -48,16 +48,16 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
   MPPTasksErrorCodeFileReadError,
   // I/O error when mmap-ing file.
   MPPTasksErrorCodeFileMmapError,
-  // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file.
+  // ZIP I/O error when unpacking the zip file.
   MPPTasksErrorCodeFileZipError,

   // TensorFlow Lite metadata error codes.
-  // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer.
+  // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer.
   MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200,
-  // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed.
+  // No such associated file within metadata, or file has not been packed.
   MPPTasksErrorCodeMetadataAssociatedFileNotFoundError,
-  // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file.
+  // ZIP I/O error when unpacking an associated file.
   MPPTasksErrorCodeMetadataAssociatedFileZipError,
   // Inconsistency error between the metadata and actual TF Lite model.
   // E.g.: number of labels and output tensor values differ.
@@ -167,11 +167,10 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
   // Task graph config is invalid.
   MPPTasksErrorCodeInvalidTaskGraphConfigError,

+  // The first error code in MPPTasksErrorCode (for internal use only).
   MPPTasksErrorCodeFirst = MPPTasksErrorCodeError,

-  /**
-   * The last error code in TFLSupportErrorCode (for internal use only).
-   */
+  // The last error code in MPPTasksErrorCode (for internal use only).
   MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError,

 } NS_SWIFT_NAME(TasksErrorCode);

View File

@@ -24,7 +24,7 @@ extern NSString *const MPPTasksErrorDomain;
 @interface MPPCommonUtils : NSObject

 /**
- * Creates and saves an NSError in the Mediapipe task library domain, with the given code and
+ * Creates and saves an NSError in the MediaPipe task library domain, with the given code and
  * description.
  *
  * @param code Error code.
@@ -51,9 +51,9 @@ extern NSString *const MPPTasksErrorDomain;
             description:(NSString *)description;

 /**
- * Converts an absl status to an NSError.
+ * Converts an absl::Status to an NSError.
  *
- * @param status absl status.
+ * @param status absl::Status.
  * @param error Pointer to the memory location where the created error should be saved. If `nil`,
  * no error will be saved.
  */
@@ -61,15 +61,15 @@ extern NSString *const MPPTasksErrorDomain;

 /**
  * Allocates a block of memory with the specified size and returns a pointer to it. If memory
- * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
+ * cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it
  * terminates program execution.
  *
  * @param memSize size of memory to be allocated
  * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no
  * error will be saved.
  *
- * @return Pointer to the allocated block of memory on successfull allocation. nil in case as
- * error is encountered because of invalid memSize. If failure is due to any other reason, method
+ * @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
+ * error is encountered because of an invalid `memSize`. If failure is due to any other reason, the method
  * terminates program execution.
  */
 + (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;

View File

@@ -24,7 +24,7 @@
 #include "mediapipe/tasks/cc/common.h"

 /** Error domain of MediaPipe task library errors. */
-NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
+NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";

 @implementation MPPCommonUtils

View File

@@ -17,25 +17,38 @@
 NS_ASSUME_NONNULL_BEGIN

 /**
- * Holds settings for any single iOS Mediapipe classification task.
+ * Holds settings for any single iOS MediaPipe classification task.
  */
 NS_SWIFT_NAME(ClassifierOptions)
 @interface MPPClassifierOptions : NSObject <NSCopying>

-/** If set, all classes in this list will be filtered out from the results . */
-@property(nonatomic, copy) NSArray<NSString *> *labelDenyList;
-
-/** If set, all classes not in this list will be filtered out from the results . */
-@property(nonatomic, copy) NSArray<NSString *> *labelAllowList;
-
-/** Display names local for display names*/
+/** The locale to use for display names specified through the TFLite Model
+ * Metadata, if any. Defaults to English.
+ */
 @property(nonatomic, copy) NSString *displayNamesLocale;

-/** Results with score threshold greater than this value are returned . */
+/** The maximum number of top-scored classification results to return. If < 0,
+ * all available results will be returned. If 0, an invalid argument error is
+ * returned.
+ */
+@property(nonatomic) NSInteger maxResults;
+
+/** Score threshold to override the one provided in the model metadata (if any).
+ * Results below this value are rejected.
+ */
 @property(nonatomic) float scoreThreshold;

-/** Limit to the number of classes that can be returned in results. */
-@property(nonatomic) NSInteger maxResults;
+/** The allowlist of category names. If non-empty, classification results whose
+ * category name is not in this set will be filtered out. Duplicate or unknown
+ * category names are ignored. Mutually exclusive with categoryDenylist.
+ */
+@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
+
+/** The denylist of category names. If non-empty, classification results whose
+ * category name is in this set will be filtered out. Duplicate or unknown
+ * category names are ignored. Mutually exclusive with categoryAllowlist.
+ */
+@property(nonatomic, copy) NSArray<NSString *> *categoryDenylist;

 @end
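The filtering semantics documented above are language-neutral; a small Python sketch with hypothetical scores makes the interaction concrete (an empty allowlist means no allowlist filtering; allowlist and denylist are mutually exclusive in the real API and are shown together here only for illustration):

```python
results = [('cat', 0.9), ('dog', 0.6), ('bird', 0.2)]
score_threshold, max_results = 0.5, 1
category_allowlist = []        # empty: no allowlist filtering
category_denylist = ['dog']

filtered = [
    (name, score) for name, score in results
    if score >= score_threshold
    and (not category_allowlist or name in category_allowlist)
    and name not in category_denylist
]
print(sorted(filtered, key=lambda r: -r[1])[:max_results])  # [('cat', 0.9)]
```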

View File

@@ -30,8 +30,8 @@
   classifierOptions.scoreThreshold = self.scoreThreshold;
   classifierOptions.maxResults = self.maxResults;
-  classifierOptions.labelDenyList = self.labelDenyList;
-  classifierOptions.labelAllowList = self.labelAllowList;
+  classifierOptions.categoryDenylist = self.categoryDenylist;
+  classifierOptions.categoryAllowlist = self.categoryAllowlist;
   classifierOptions.displayNamesLocale = self.displayNamesLocale;

   return classifierOptions;

View File

@@ -20,17 +20,23 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
 }

 @implementation MPPClassifierOptions (Helpers)

 - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
+  classifierOptionsProto->Clear();
+
   if (self.displayNamesLocale) {
     classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
   }

   classifierOptionsProto->set_max_results((int)self.maxResults);
   classifierOptionsProto->set_score_threshold(self.scoreThreshold);

-  for (NSString *category in self.labelAllowList) {
+  for (NSString *category in self.categoryAllowlist) {
     classifierOptionsProto->add_category_allowlist(category.cppString);
   }

-  for (NSString *category in self.labelDenyList) {
+  for (NSString *category in self.categoryDenylist) {
     classifierOptionsProto->add_category_denylist(category.cppString);
   }
 }

View File

@@ -16,19 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"])

 licenses(["notice"])

+objc_library(
+    name = "MPPBaseOptions",
+    srcs = ["sources/MPPBaseOptions.m"],
+    hdrs = ["sources/MPPBaseOptions.h"],
+)
+
 objc_library(
     name = "MPPTaskOptions",
     srcs = ["sources/MPPTaskOptions.m"],
     hdrs = ["sources/MPPTaskOptions.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
     deps = [
         ":MPPBaseOptions",
     ],
 )

+objc_library(
+    name = "MPPTaskResult",
+    srcs = ["sources/MPPTaskResult.m"],
+    hdrs = ["sources/MPPTaskResult.h"],
+)
+
+objc_library(
+    name = "MPPTaskOptionsProtocol",
+    hdrs = ["sources/MPPTaskOptionsProtocol.h"],
+    deps = [
+        "//mediapipe/framework:calculator_options_cc_proto",
+    ],
+)
+
 objc_library(
     name = "MPPTaskInfo",
     srcs = ["sources/MPPTaskInfo.mm"],
@@ -64,29 +80,9 @@ objc_library(
 )

 objc_library(
-    name = "MPPTaskResult",
-    srcs = ["sources/MPPTaskResult.m"],
-    hdrs = ["sources/MPPTaskResult.h"],
-)
-
-objc_library(
-    name = "MPPBaseOptions",
-    srcs = ["sources/MPPBaseOptions.m"],
-    hdrs = ["sources/MPPBaseOptions.h"],
-)
-
-objc_library(
-    name = "MPPTaskOptionsProtocol",
-    hdrs = ["sources/MPPTaskOptionsProtocol.h"],
-    deps = [
-        "//mediapipe/framework:calculator_options_cc_proto",
-    ],
-)
-
-objc_library(
-    name = "MPPTaskManager",
-    srcs = ["sources/MPPTaskManager.mm"],
-    hdrs = ["sources/MPPTaskManager.h"],
+    name = "MPPTaskRunner",
+    srcs = ["sources/MPPTaskRunner.mm"],
+    hdrs = ["sources/MPPTaskRunner.h"],
     deps = [
         "//mediapipe/tasks/cc/core:task_runner",
         "//mediapipe/framework:calculator_cc_proto",

View File

@@ -17,7 +17,6 @@
 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
-
 NS_ASSUME_NONNULL_BEGIN

 /**
@@ -55,7 +54,7 @@ NS_ASSUME_NONNULL_BEGIN
                 outputStreams:(NSArray<NSString *> *)outputStreams
                   taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions
            enableFlowLimiting:(BOOL)enableFlowLimiting
-                        error:(NSError **)error;
+                        error:(NSError **)error NS_DESIGNATED_INITIALIZER;

 /**
  * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.

View File

@@ -24,9 +24,9 @@
 namespace {
 using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
 using Node = ::mediapipe::CalculatorGraphConfig::Node;
-using ::mediapipe::InputStreamInfo;
 using ::mediapipe::CalculatorOptions;
 using ::mediapipe::FlowLimiterCalculatorOptions;
+using ::mediapipe::InputStreamInfo;
 }  // namespace

 @implementation MPPTaskInfo
@@ -82,7 +82,15 @@ using ::mediapipe::FlowLimiterCalculatorOptions;
     graph_config.add_output_stream(cpp_output_stream);
   }

-  if (self.enableFlowLimiting) {
+  if (!self.enableFlowLimiting) {
+    for (NSString *inputStream in self.inputStreams) {
+      auto cpp_input_stream = inputStream.cppString;
+      task_subgraph_node->add_input_stream(cpp_input_stream);
+      graph_config.add_input_stream(cpp_input_stream);
+    }
+    return graph_config;
+  }

   Node *flow_limit_calculator_node = graph_config.add_node();
   flow_limit_calculator_node->set_calculator("FlowLimiterCalculator");
@@ -113,13 +121,6 @@ using ::mediapipe::FlowLimiterCalculatorOptions;
   NSString *firstOutputStream = self.outputStreams[0];
   auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString;
   flow_limit_calculator_node->add_input_stream(finished_output_stream);
-  } else {
-    for (NSString *inputStream in self.inputStreams) {
-      auto cpp_input_stream = inputStream.cppString;
-      task_subgraph_node->add_input_stream(cpp_input_stream);
-      graph_config.add_input_stream(cpp_input_stream);
-    }
-  }

   return graph_config;
 }

View File

@@ -1,14 +1,17 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 #import <Foundation/Foundation.h>

 #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
@@ -19,27 +22,13 @@ NS_ASSUME_NONNULL_BEGIN
  * this class.
  */
 NS_SWIFT_NAME(TaskOptions)
 @interface MPPTaskOptions : NSObject <NSCopying>
 /**
  * Base options for configuring the Mediapipe task.
  */
 @property(nonatomic, copy) MPPBaseOptions *baseOptions;

-/**
- * Initializes a new `MPPTaskOptions` with the absolute path to the model file
- * stored locally on the device, set to the given the model path.
- *
- * @discussion The external model file must be a single standalone TFLite file. It could be packed
- * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
- * necessary metadata and associated files might result in errors. Check the [documentation]
- * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
- *
- * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
- *
- * @return An instance of `MPPTaskOptions` initialized to the given model path.
- */
-- (instancetype)initWithModelPath:(NSString *)modelPath;
-
 @end

 NS_ASSUME_NONNULL_END

View File

@@ -1,17 +1,17 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
 #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
@@ -25,12 +25,12 @@
   return self;
 }

-- (instancetype)initWithModelPath:(NSString *)modelPath {
-  self = [self init];
-  if (self) {
-    _baseOptions.modelAssetPath = modelPath;
-  }
-  return self;
+- (id)copyWithZone:(NSZone *)zone {
+  MPPTaskOptions *taskOptions = [[MPPTaskOptions alloc] init];
+
+  taskOptions.baseOptions = self.baseOptions;
+
+  return taskOptions;
 }

 @end

View File

@@ -1,26 +1,29 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 #import <Foundation/Foundation.h>

 #include "mediapipe/framework/calculator_options.pb.h"

 NS_ASSUME_NONNULL_BEGIN

 /**
- * Any mediapipe task options should confirm to this protocol.
+ * Any MediaPipe task options should conform to this protocol.
  */
 @protocol MPPTaskOptionsProtocol

 /**
- * Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto.
+ * Copies the iOS MediaPipe task options to an object of the mediapipe::CalculatorOptions proto.
  */
 - (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;

View File

@@ -1,30 +1,36 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 #import <Foundation/Foundation.h>

 NS_ASSUME_NONNULL_BEGIN

 /**
- * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
+ * MediaPipe Tasks result base class. Any MediaPipe task result class should extend
  * this class.
  */
 NS_SWIFT_NAME(TaskResult)
 @interface MPPTaskResult : NSObject <NSCopying>

 /**
- * Base options for configuring the Mediapipe task.
+ * Timestamp that is associated with the task result object.
  */
-@property(nonatomic, assign, readonly) long timeStamp;
+@property(nonatomic, assign, readonly) long timestamp;

-- (instancetype)initWithTimeStamp:(long)timeStamp;
+- (instancetype)init NS_UNAVAILABLE;
+
+- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER;

 @end

View File

@@ -1,27 +1,31 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.

 #import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"

 @implementation MPPTaskResult

-- (instancetype)initWithTimeStamp:(long)timeStamp {
-  self = [self init];
+- (instancetype)initWithTimestamp:(long)timestamp {
+  self = [super init];
   if (self) {
-    _timeStamp = timeStamp;
+    _timestamp = timestamp;
   }
   return self;
 }

+- (id)copyWithZone:(NSZone *)zone {
+  return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp];
+}
+
 @end

View File

@@ -0,0 +1,50 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
NS_ASSUME_NONNULL_BEGIN
/**
* This class is used to create and call appropriate methods on the C++ Task Runner.
*/
@interface MPPTaskRunner : NSObject
/**
* Initializes a new `MPPTaskRunner` with the mediapipe task graph config proto.
*
* @param graphConfig A mediapipe task graph config proto.
*
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
*/
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)
process:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
- (void)close;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@@ -0,0 +1,56 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
namespace {
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
} // namespace
@interface MPPTaskRunner () {
// Cpp Task Runner
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
}
@end
@implementation MPPTaskRunner
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
error:(NSError **)error {
self = [super init];
if (self) {
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
return nil;
}
_cppTaskRunner = std::move(taskRunnerResult.value());
}
return self;
}
- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap {
return _cppTaskRunner->Process(packetMap);
}
- (void)close {
_cppTaskRunner->Close();
}
@end

View File

@@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<AudioClassifierOptions>builder()
+                .setTaskName(AudioClassifier.class.getSimpleName())
+                .setTaskRunningModeName(options.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<AudioEmbedderOptions>builder()
+                .setTaskName(AudioEmbedder.class.getSimpleName())
+                .setTaskRunningModeName(options.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -22,6 +22,7 @@ android_library(
     ],
     manifest = "AndroidManifest.xml",
     deps = [
+        ":logging",
        "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
        "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
        "//mediapipe/framework:calculator_java_proto_lite",
@@ -37,11 +38,22 @@ android_library(
     ],
 )

+android_library(
+    name = "logging",
+    srcs = glob(
+        ["logging/*.java"],
+    ),
+    deps = [
+        "//third_party:autovalue",
+        "@maven//:com_google_guava_guava",
+    ],
+)
+
 load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar")

 mediapipe_tasks_core_aar(
     name = "tasks_core",
-    srcs = glob(["*.java"]) + [
+    srcs = glob(["**/*.java"]) + [
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src",

View File

@@ -32,6 +32,12 @@ public abstract class TaskInfo<T extends TaskOptions> {
   /** Builder for {@link TaskInfo}. */
   @AutoValue.Builder
   public abstract static class Builder<T extends TaskOptions> {
+    /** Sets the MediaPipe task name. */
+    public abstract Builder<T> setTaskName(String value);
+
+    /** Sets the MediaPipe task running mode name. */
+    public abstract Builder<T> setTaskRunningModeName(String value);
+
     /** Sets the MediaPipe task graph name. */
     public abstract Builder<T> setTaskGraphName(String value);
@@ -71,6 +77,10 @@ public abstract class TaskInfo<T extends TaskOptions> {
     }
   }

+  abstract String taskName();
+
+  abstract String taskRunningModeName();
+
   abstract String taskGraphName();

   abstract T taskOptions();
@@ -82,7 +92,7 @@ public abstract class TaskInfo<T extends TaskOptions> {
   abstract Boolean enableFlowLimiting();

   public static <T extends TaskOptions> Builder<T> builder() {
-    return new AutoValue_TaskInfo.Builder<T>();
+    return new AutoValue_TaskInfo.Builder<T>().setTaskName("").setTaskRunningModeName("");
   }

   /* Returns a list of the output stream names without the stream tags. */

View File

@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable {
private final Graph graph; private final Graph graph;
private final ModelResourcesCache modelResourcesCache; private final ModelResourcesCache modelResourcesCache;
private final AndroidPacketCreator packetCreator; private final AndroidPacketCreator packetCreator;
private final TasksStatsLogger statsLogger;
private long lastSeenTimestamp = Long.MIN_VALUE; private long lastSeenTimestamp = Long.MIN_VALUE;
private ErrorListener errorListener; private ErrorListener errorListener;
@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable {
Context context, Context context,
TaskInfo<? extends TaskOptions> taskInfo, TaskInfo<? extends TaskOptions> taskInfo,
OutputHandler<? extends TaskResult, ?> outputHandler) { OutputHandler<? extends TaskResult, ?> outputHandler) {
TasksStatsLogger statsLogger =
TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName());
AndroidAssetUtil.initializeNativeAssetManager(context); AndroidAssetUtil.initializeNativeAssetManager(context);
Graph mediapipeGraph = new Graph(); Graph mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig()); mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable {
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache); mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
mediapipeGraph.addMultiStreamCallback( mediapipeGraph.addMultiStreamCallback(
taskInfo.outputStreamNames(), taskInfo.outputStreamNames(),
outputHandler::run, packets -> {
outputHandler.run(packets);
statsLogger.recordInvocationEnd(packets.get(0).getTimestamp());
},
         /* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges());
     mediapipeGraph.startRunningGraph();
     // Waits until all calculators are opened and the graph is fully started.
     mediapipeGraph.waitUntilGraphIdle();
-    return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
+    return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger);
   }
 
   /**
@@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable {
    * @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs.
    */
   public synchronized TaskResult process(Map<String, Packet> inputs) {
-    addPackets(inputs, generateSyntheticTimestamp());
+    long syntheticInputTimestamp = generateSyntheticTimestamp();
+    // TODO: Support recording GPU input arrival.
+    statsLogger.recordCpuInputArrival(syntheticInputTimestamp);
+    addPackets(inputs, syntheticInputTimestamp);
     graph.waitUntilGraphIdle();
     lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
     return outputHandler.retrieveCachedTaskResult();
@@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable {
    */
   public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
     validateInputTimstamp(inputTimestamp);
+    statsLogger.recordCpuInputArrival(inputTimestamp);
     addPackets(inputs, inputTimestamp);
     graph.waitUntilGraphIdle();
     return outputHandler.retrieveCachedTaskResult();
@@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable {
    */
   public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
     validateInputTimstamp(inputTimestamp);
+    statsLogger.recordCpuInputArrival(inputTimestamp);
     addPackets(inputs, inputTimestamp);
   }
@@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable {
       graphStarted.set(false);
       graph.closeAllPacketSources();
       graph.waitUntilGraphDone();
+      statsLogger.logSessionEnd();
     } catch (MediaPipeException e) {
       reportError(e);
     }
@@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable {
       // Waits until all calculators are opened and the graph is fully restarted.
       graph.waitUntilGraphIdle();
       graphStarted.set(true);
+      statsLogger.logSessionStart();
     } catch (MediaPipeException e) {
       reportError(e);
     }
@@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable {
       graphStarted.set(false);
       graph.closeAllPacketSources();
       graph.waitUntilGraphDone();
+      statsLogger.logSessionEnd();
       if (modelResourcesCache != null) {
         modelResourcesCache.release();
       }
@@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable {
   private TaskRunner(
       Graph graph,
       ModelResourcesCache modelResourcesCache,
-      OutputHandler<? extends TaskResult, ?> outputHandler) {
+      OutputHandler<? extends TaskResult, ?> outputHandler,
+      TasksStatsLogger statsLogger) {
     this.outputHandler = outputHandler;
     this.graph = graph;
     this.modelResourcesCache = modelResourcesCache;
     this.packetCreator = new AndroidPacketCreator(graph);
+    this.statsLogger = statsLogger;
     graphStarted.set(true);
+    this.statsLogger.logSessionStart();
   }
 
   /** Reports error. */
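
Taken together, the TaskRunner changes above pin down a per-session event order for the stats logger: logSessionStart() when the graph starts or restarts, recordCpuInputArrival() once per process()/send() call, and logSessionEnd() on close or restart teardown. A minimal sketch of that order in plain Java; ConsoleStatsLogger is a hypothetical stand-in for a TasksStatsLogger, not a class from this commit:

// Hypothetical stand-in that prints each lifecycle event; real loggers
// implement TasksStatsLogger instead.
public final class LoggingOrderDemo {
  static final class ConsoleStatsLogger {
    void logSessionStart() { System.out.println("session start"); }
    void recordCpuInputArrival(long ts) { System.out.println("cpu input @ " + ts); }
    void logSessionEnd() { System.out.println("session end"); }
  }

  public static void main(String[] args) {
    ConsoleStatsLogger logger = new ConsoleStatsLogger();
    logger.logSessionStart();            // TaskRunner constructor / restart()
    logger.recordCpuInputArrival(1000);  // each process()/send() invocation
    logger.recordCpuInputArrival(2000);
    logger.logSessionEnd();              // close() / restart() teardown
  }
}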

View File

@@ -0,0 +1,78 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core.logging;

import android.content.Context;

/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. */
public class TasksStatsDummyLogger implements TasksStatsLogger {

  /**
   * Creates the MediaPipe Tasks stats dummy logger.
   *
   * @param context a {@link Context}.
   * @param taskNameStr the task API name.
   * @param taskRunningModeStr the task running mode string representation.
   */
  public static TasksStatsDummyLogger create(
      Context context, String taskNameStr, String taskRunningModeStr) {
    return new TasksStatsDummyLogger();
  }

  private TasksStatsDummyLogger() {}

  /** Logs the start of a MediaPipe Tasks API session. */
  @Override
  public void logSessionStart() {}

  /**
   * Records MediaPipe Tasks API receiving CPU input data.
   *
   * @param packetTimestamp the input packet timestamp that acts as the identifier of the API
   *     invocation.
   */
  @Override
  public void recordCpuInputArrival(long packetTimestamp) {}

  /**
   * Records MediaPipe Tasks API receiving GPU input data.
   *
   * @param packetTimestamp the input packet timestamp that acts as the identifier of the API
   *     invocation.
   */
  @Override
  public void recordGpuInputArrival(long packetTimestamp) {}

  /**
   * Records the end of a MediaPipe Tasks API invocation.
   *
   * @param packetTimestamp the output packet timestamp that acts as the identifier of the API
   *     invocation.
   */
  @Override
  public void recordInvocationEnd(long packetTimestamp) {}

  /** Logs the MediaPipe Tasks API periodic invocation report. */
  @Override
  public void logInvocationReport(StatsSnapshot stats) {}

  /** Logs the Tasks API session end event. */
  @Override
  public void logSessionEnd() {}

  /** Logs the MediaPipe Tasks API initialization error. */
  @Override
  public void logInitError() {}
}
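
Since every method is a no-op, the dummy logger lets builds without a telemetry backend still satisfy the TasksStatsLogger contract, so TaskRunner can invoke logSessionStart()/logSessionEnd() unconditionally. A usage sketch; the wrapper class and variable names here are illustrative, and only TasksStatsDummyLogger.create() comes from the file above:

import android.content.Context;
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;

/** Illustrative only: obtain a logger that accepts all callbacks but records nothing. */
final class StatsLoggerHolder {
  private StatsLoggerHolder() {}

  static TasksStatsLogger createNoOpLogger(
      Context context, String taskName, String runningModeName) {
    return TasksStatsDummyLogger.create(context, taskName, runningModeName);
  }
}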

View File

@@ -0,0 +1,98 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core.logging;

import com.google.auto.value.AutoValue;

/** The stats logger interface that defines what MediaPipe Tasks events to log. */
public interface TasksStatsLogger {

  /** Task stats snapshot. */
  @AutoValue
  abstract static class StatsSnapshot {
    static StatsSnapshot create(
        int cpuInputCount,
        int gpuInputCount,
        int finishedCount,
        int droppedCount,
        long totalLatencyMs,
        long peakLatencyMs,
        long elapsedTimeMs) {
      return new AutoValue_TasksStatsLogger_StatsSnapshot(
          cpuInputCount,
          gpuInputCount,
          finishedCount,
          droppedCount,
          totalLatencyMs,
          peakLatencyMs,
          elapsedTimeMs);
    }

    static StatsSnapshot createDefault() {
      return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0);
    }

    abstract int cpuInputCount();

    abstract int gpuInputCount();

    abstract int finishedCount();

    abstract int droppedCount();

    abstract long totalLatencyMs();

    abstract long peakLatencyMs();

    abstract long elapsedTimeMs();
  }

  /** Logs the start of a MediaPipe Tasks API session. */
  public void logSessionStart();

  /**
   * Records MediaPipe Tasks API receiving CPU input data.
   *
   * @param packetTimestamp the input packet timestamp that acts as the identifier of the API
   *     invocation.
   */
  public void recordCpuInputArrival(long packetTimestamp);

  /**
   * Records MediaPipe Tasks API receiving GPU input data.
   *
   * @param packetTimestamp the input packet timestamp that acts as the identifier of the API
   *     invocation.
   */
  public void recordGpuInputArrival(long packetTimestamp);

  /**
   * Records the end of a MediaPipe Tasks API invocation.
   *
   * @param packetTimestamp the output packet timestamp that acts as the identifier of the API
   *     invocation.
   */
  public void recordInvocationEnd(long packetTimestamp);

  /** Logs the MediaPipe Tasks API periodic invocation report. */
  public void logInvocationReport(StatsSnapshot stats);

  /** Logs the Tasks API session end event. */
  public void logSessionEnd();

  /** Logs the MediaPipe Tasks API initialization error. */
  public void logInitError();

  // TODO: Log more error types.
}
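
The snapshot fields suggest how a concrete logger is meant to aggregate events: pair each recorded input arrival with its invocation end by packet timestamp, then fold the latencies into the counters StatsSnapshot carries. A sketch of one such implementation; it is illustrative, not part of this commit, and assumes AutoValue generates AutoValue_TasksStatsLogger_StatsSnapshot as usual:

package com.google.mediapipe.tasks.core.logging;

import java.util.HashMap;
import java.util.Map;

/** Illustrative in-memory logger; not part of this commit. */
public class InMemoryTasksStatsLogger implements TasksStatsLogger {
  // Packet timestamp of a pending invocation -> wall-clock arrival time in ms.
  private final Map<Long, Long> pendingInvocations = new HashMap<>();
  private int cpuInputCount;
  private int gpuInputCount;
  private int finishedCount;
  private long totalLatencyMs;
  private long peakLatencyMs;
  private long sessionStartMs;

  @Override
  public void logSessionStart() {
    sessionStartMs = System.currentTimeMillis();
  }

  @Override
  public void recordCpuInputArrival(long packetTimestamp) {
    ++cpuInputCount;
    pendingInvocations.put(packetTimestamp, System.currentTimeMillis());
  }

  @Override
  public void recordGpuInputArrival(long packetTimestamp) {
    ++gpuInputCount;
    pendingInvocations.put(packetTimestamp, System.currentTimeMillis());
  }

  @Override
  public void recordInvocationEnd(long packetTimestamp) {
    Long arrivalMs = pendingInvocations.remove(packetTimestamp);
    if (arrivalMs == null) {
      return; // No matching arrival; counted as dropped at session end.
    }
    long latencyMs = System.currentTimeMillis() - arrivalMs;
    ++finishedCount;
    totalLatencyMs += latencyMs;
    peakLatencyMs = Math.max(peakLatencyMs, latencyMs);
  }

  @Override
  public void logInvocationReport(StatsSnapshot stats) {
    System.out.printf(
        "finished=%d dropped=%d avgLatencyMs=%d peakLatencyMs=%d%n",
        stats.finishedCount(),
        stats.droppedCount(),
        stats.finishedCount() == 0 ? 0 : stats.totalLatencyMs() / stats.finishedCount(),
        stats.peakLatencyMs());
  }

  @Override
  public void logSessionEnd() {
    logInvocationReport(
        StatsSnapshot.create(
            cpuInputCount,
            gpuInputCount,
            finishedCount,
            /* droppedCount= */ pendingInvocations.size(),
            totalLatencyMs,
            peakLatencyMs,
            /* elapsedTimeMs= */ System.currentTimeMillis() - sessionStartMs));
  }

  @Override
  public void logInitError() {}
}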

View File

@@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable {
         TaskRunner.create(
             context,
             TaskInfo.<TextClassifierOptions>builder()
+                .setTaskName(TextClassifier.class.getSimpleName())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable {
         TaskRunner.create(
             context,
             TaskInfo.<TextEmbedderOptions>builder()
+                .setTaskName(TextEmbedder.class.getSimpleName())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<GestureRecognizerOptions>builder()
+                .setTaskName(GestureRecognizer.class.getSimpleName())
+                .setTaskRunningModeName(recognizerOptions.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<HandLandmarkerOptions>builder()
+                .setTaskName(HandLandmarker.class.getSimpleName())
+                .setTaskRunningModeName(landmarkerOptions.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<ImageClassifierOptions>builder()
+                .setTaskName(ImageClassifier.class.getSimpleName())
+                .setTaskRunningModeName(options.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<ImageEmbedderOptions>builder()
+                .setTaskName(ImageEmbedder.class.getSimpleName())
+                .setTaskRunningModeName(options.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)

View File

@@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.<ObjectDetectorOptions>builder()
+                .setTaskName(ObjectDetector.class.getSimpleName())
+                .setTaskRunningModeName(detectorOptions.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
                 .setOutputStreams(OUTPUT_STREAMS)
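
The same two builder calls recur across all five vision tasks, while the text tasks above get only setTaskName() since they have no running mode. A condensed sketch of the pattern, with hedges: the setTaskOptions()/build() tail and the TaskRunner.create() argument list are assumed from context rather than shown in these hunks.

// Condensed sketch of the wiring each vision task now performs; the stream
// constants and the detectorOptions variable mirror the hunks above, while
// setTaskOptions() and build() are assumed parts of the (unshown) chain.
TaskRunner runner =
    TaskRunner.create(
        context,
        TaskInfo.<ObjectDetectorOptions>builder()
            .setTaskName(ObjectDetector.class.getSimpleName())
            .setTaskRunningModeName(detectorOptions.runningMode().name())
            .setTaskGraphName(TASK_GRAPH_NAME)
            .setInputStreams(INPUT_STREAMS)
            .setOutputStreams(OUTPUT_STREAMS)
            .setTaskOptions(detectorOptions)
            .build(),
        outputHandler);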

View File

@@ -86,7 +86,7 @@ describe('convertBaseOptionsToProto()', () => {
   it('can enable CPU delegate', async () => {
     const baseOptionsProto = await convertBaseOptionsToProto({
       modelAssetBuffer: new Uint8Array(mockBytes),
-      delegate: 'cpu',
+      delegate: 'CPU',
     });
     expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
   });
@@ -94,7 +94,7 @@ describe('convertBaseOptionsToProto()', () => {
   it('can enable GPU delegate', async () => {
     const baseOptionsProto = await convertBaseOptionsToProto({
       modelAssetBuffer: new Uint8Array(mockBytes),
-      delegate: 'gpu',
+      delegate: 'GPU',
     });
     expect(baseOptionsProto.toObject()).toEqual({
       ...mockBytesResult,
@@ -117,7 +117,7 @@ describe('convertBaseOptionsToProto()', () => {
   it('can reset delegate', async () => {
     let baseOptionsProto = await convertBaseOptionsToProto({
       modelAssetBuffer: new Uint8Array(mockBytes),
-      delegate: 'gpu',
+      delegate: 'GPU',
     });
     // Clear backend
     baseOptionsProto =

View File

@@ -71,7 +71,7 @@ async function configureExternalFile(
 
 /** Configures the `acceleration` option. */
 function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
   const acceleration = proto.getAcceleration() ?? new Acceleration();
-  if (options.delegate === 'gpu') {
+  if (options.delegate === 'GPU') {
     acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
   } else {
     acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());

View File

@@ -44,22 +44,14 @@ async function isSimdSupported(): Promise<boolean> {
 }
 
 async function createFileset(
-    taskName: string, basePath: string = '.'): Promise<WasmFileset> {
-  if (await isSimdSupported()) {
-    return {
-      wasmLoaderPath:
-          `${basePath}/${taskName}_wasm_internal.js`,
-      wasmBinaryPath:
-          `${basePath}/${taskName}_wasm_internal.wasm`,
-    };
-  } else {
-    return {
-      wasmLoaderPath:
-          `${basePath}/${taskName}_wasm_nosimd_internal.js`,
-      wasmBinaryPath:
-          `${basePath}/${taskName}_wasm_nosimd_internal.wasm`,
-    };
-  }
+    taskName: string, basePath: string = ''): Promise<WasmFileset> {
+  const suffix =
+      await isSimdSupported() ? 'wasm_internal' : 'wasm_nosimd_internal';
+  return {
+    wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`,
+    wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`,
+  };
 }
 
 // tslint:disable:class-as-namespace

View File

@@ -31,7 +31,7 @@ export declare interface BaseOptions {
   modelAssetBuffer?: Uint8Array|undefined;
 
   /** Overrides the default backend to use for the provided model. */
-  delegate?: 'cpu'|'gpu'|undefined;
+  delegate?: 'CPU'|'GPU'|undefined;
 }
 
 /** Options to configure MediaPipe Tasks in general. */

View File

@@ -1028,7 +1028,9 @@ export class GraphRunner {
     // Set up our TS listener to receive any packets for this stream, and
     // additionally reformat our Uint8Array into a Float32Array for the user.
     this.setListener(outputStreamName, (data: Uint8Array) => {
-      const floatArray = new Float32Array(data.buffer);  // Should be very fast
+      // Should be very fast
+      const floatArray =
+          new Float32Array(data.buffer, data.byteOffset, data.length / 4);
       callbackFcn(floatArray);
     });

View File

@@ -490,10 +490,10 @@ setuptools.setup(
         'Operating System :: MacOS :: MacOS X',
         'Operating System :: Microsoft :: Windows',
         'Operating System :: POSIX :: Linux',
-        'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
         'Programming Language :: Python :: 3 :: Only',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',