diff --git a/.github/ISSUE_TEMPLATE/13-solution-issue.md b/.github/ISSUE_TEMPLATE/13-solution-issue.md index 9297edf6b..bf0d613c9 100644 --- a/.github/ISSUE_TEMPLATE/13-solution-issue.md +++ b/.github/ISSUE_TEMPLATE/13-solution-issue.md @@ -1,6 +1,6 @@ --- name: "Solution (legacy) Issue" -about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc. +about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc. labels: type:support --- diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 8c552834e..11589425d 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic # For static images: IMAGE_FILES = [] +BG_COLOR = (192, 192, 192) # gray with mp_holistic.Holistic( static_image_mode=True, model_complexity=2, diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index 555f7543f..4a8f0f598 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "mfcc_mel_calculators_proto", srcs = ["mfcc_mel_calculators.proto"], diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 2c143a609..b3378a74e 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -567,7 +567,7 @@ cc_library( name = "packet_thinner_calculator", srcs = ["packet_thinner_calculator.cc"], deps = [ - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:video_stream_header", @@ -584,7 +584,7 @@ cc_test( srcs = ["packet_thinner_calculator_test.cc"], deps = [ ":packet_thinner_calculator", - "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto", + ":packet_thinner_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -762,7 +762,7 @@ cc_library( srcs = ["packet_resampler_calculator.cc"], hdrs = ["packet_resampler_calculator.h"], deps = [ - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -786,7 +786,7 @@ cc_test( ], deps = [ ":packet_resampler_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", + ":packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:video_stream_header", @@ -852,10 +852,10 @@ cc_test( name = "flow_limiter_calculator_test", srcs = ["flow_limiter_calculator_test.cc"], deps = [ + ":counting_source_calculator", 
":flow_limiter_calculator", ":flow_limiter_calculator_cc_proto", - "//mediapipe/calculators/core:counting_source_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:test_calculators", @@ -1302,7 +1302,7 @@ cc_test( srcs = ["packet_sequencer_calculator_test.cc"], deps = [ ":packet_sequencer_calculator", - "//mediapipe/calculators/core:pass_through_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:subgraph", diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h index 25d90bfe6..ee886b381 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.h +++ b/mediapipe/calculators/core/get_vector_item_calculator.h @@ -47,7 +47,7 @@ namespace api2 { // calculator: "Get{SpecificType}VectorItemCalculator" // input_stream: "VECTOR:vector" // input_stream: "INDEX:index" -// input_stream: "ITEM:item" +// output_stream: "ITEM:item" // options { // [mediapipe.GetVectorItemCalculatorOptions.ext] { // item_index: 5 diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index ef3cb9896..e3c92ba52 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT"; // } // output_stream: "gated_frames" // } -class RealTimeFlowLimiterCalculator : public CalculatorBase { +// +// Please use FlowLimiterCalculator, which replaces this calculator and +// defines a few additional configuration options. +class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.") + RealTimeFlowLimiterCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { int num_data_streams = cc->Inputs().NumEntries(""); diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index 66dbdef2e..026048b79 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node { // The number of packets or timestamps we need to store to output packet[i] at // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset). int cache_size_; + bool emit_empty_packets_before_first_packet_ = false; }; MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { packet_offset_ = kOffset(cc).GetOr( cc->Options().packet_offset()); + emit_empty_packets_before_first_packet_ = + cc->Options() + .emit_empty_packets_before_first_packet(); cache_size_ = abs(packet_offset_); // An offset of zero is a no-op, but someone might still request it. if (packet_offset_ == 0) { @@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { // Ready to output oldest packet with current timestamp. kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp())); packet_cache_.pop_front(); + } else if (emit_empty_packets_before_first_packet_) { + LOG(FATAL) << "Not supported yet"; } // Store current packet for later output. 
packet_cache_.push_back(kIn(cc).packet()); diff --git a/mediapipe/calculators/core/sequence_shift_calculator.proto b/mediapipe/calculators/core/sequence_shift_calculator.proto index 15b111d71..36b0bb959 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.proto +++ b/mediapipe/calculators/core/sequence_shift_calculator.proto @@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions { optional SequenceShiftCalculatorOptions ext = 107633927; } optional int32 packet_offset = 1 [default = -1]; + + // Emits empty packets before the first delayed packet is emitted. Takes + // effect only when packet offset is set to positive. + optional bool emit_empty_packets_before_first_packet = 2 [default = false]; } diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 530dd3d4a..9aae8cfbc 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -378,8 +378,8 @@ cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], deps = [ + ":scale_image_calculator_cc_proto", ":scale_image_utils", - "//mediapipe/calculators/image:scale_image_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -747,8 +747,8 @@ cc_test( tags = ["desktop_only_test"], deps = [ ":affine_transformation", + ":image_transformation_calculator", ":warp_affine_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:image_to_tensor_converter", "//mediapipe/calculators/tensor:image_to_tensor_utils", "//mediapipe/calculators/util:from_image_calculator", diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc index c38fc8e07..361dfc902 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner constexpr GLchar kVertShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner )"; constexpr GLchar kFragShader[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) in vec2 sample_coordinate; uniform sampler2D input_texture; diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index bdac932bb..4781f1ea1 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { constexpr char kRgbaInTag[] = "RGBA_IN"; constexpr char kRgbInTag[] = "RGB_IN"; +constexpr char kBgrInTag[] = "BGR_IN"; constexpr char kBgraInTag[] = "BGRA_IN"; constexpr char kGrayInTag[] = "GRAY_IN"; constexpr char kRgbaOutTag[] = "RGBA_OUT"; @@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB -> RGBA // RGBA -> BGRA // BGRA -> RGBA +// BGR -> RGB // // This calculator only supports a single input stream and output stream at a // time. If more than one input stream or output stream is present, the @@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; // RGB_IN: The input video stream (ImageFrame, SRGB). 
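Both shader hunks above (and the matching ones in image_to_tensor_converter_gl_texture.cc further down) bump the texture coordinates from mediump to highp. The likely reasoning, sketched independently of this change: ESSL guarantees mediump only about 2^-10 relative precision, so adjacent texel centers on a large texture can land on the same representable coordinate.

```cpp
#include <cstdio>

// Back-of-the-envelope check: on a 4096-wide texture, neighboring texel
// centers differ by 1/4096 ~= 0.000244, but representable mediump values
// near 1.0 are ~2^-10 ~= 0.000977 apart, so several texels collapse onto
// one sampling coordinate; highp (fp32-like) resolves them.
int main() {
  const int width = 4096;
  const float texel_step = 1.0f / width;
  const float mediump_step_near_one = 1.0f / 1024.0f;  // 2^-10
  std::printf("texel %.6f vs mediump %.6f -> %s\n", texel_step,
              mediump_step_near_one,
              texel_step < mediump_step_near_one ? "texels collapse"
                                                 : "resolvable");
  return 0;
}
```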
// BGRA_IN: The input video stream (ImageFrame, SBGRA). // GRAY_IN: The input video stream (ImageFrame, GRAY8). +// BGR_IN: The input video stream (ImageFrame, SBGR). // // Output streams: // RGBA_OUT: The output video stream (ImageFrame, SRGBA). @@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kBgraInTag).Set(); } + if (cc->Inputs().HasTag(kBgrInTag)) { + cc->Inputs().Tag(kBgrInTag).Set(); + } + if (cc->Outputs().HasTag(kRgbOutTag)) { cc->Outputs().Tag(kRgbOutTag).Set(); } @@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) { return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA, cv::COLOR_RGBA2BGRA, cc); } + // BGR -> RGB + if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) { + return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB, + cv::COLOR_BGR2RGB, cc); + } return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Unsupported image format conversion."; diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index caade2dc3..8647e3f3f 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 5efd34041..165df8970 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter { constexpr GLchar kExtractSubRectVertexShader[] = R"( in vec4 position; - in mediump vec4 texture_coordinate; - out mediump vec2 sample_coordinate; + in highp vec4 texture_coordinate; + out highp vec2 sample_coordinate; uniform mat4 transform_matrix; void main() { @@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter { )"; constexpr GLchar kExtractSubRectFragBody[] = R"( - DEFAULT_PRECISION(mediump, float) + DEFAULT_PRECISION(highp, float) // Provided by kExtractSubRectVertexShader. in vec2 sample_coordinate; diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 1529ead8a..a679a80fd 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -22,8 +22,8 @@ cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], deps = [ + ":detections_to_rects_calculator", ":detections_to_rects_calculator_cc_proto", - "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 872944acd..83346dad1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1,4 +1,3 @@ -# # Copyright 2019 The MediaPipe Authors. 
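A minimal sketch of wiring the new BGR path through ColorConvertCalculator above; the BGR_IN/RGB_OUT tags come from this change, and the stream names are illustrative.

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Converts SBGR ImageFrames (e.g. frames originating from raw OpenCV mats)
// to SRGB using the newly supported BGR -> RGB conversion.
mediapipe::CalculatorGraphConfig MakeBgrToRgbConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "bgr_frames"
    output_stream: "rgb_frames"
    node {
      calculator: "ColorConvertCalculator"
      input_stream: "BGR_IN:bgr_frames"
      output_stream: "RGB_OUT:rgb_frames"
    }
  )pb");
}
```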
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -227,13 +226,13 @@ cc_library( ":mediapipe_internal", ], deps = [ + ":calculator_cc_proto", ":graph_service", + ":mediapipe_options_cc_proto", + ":packet_generator_cc_proto", ":packet_type", ":port", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_map", @@ -329,10 +328,10 @@ cc_library( ":thread_pool_executor", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":calculator_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", + ":thread_pool_executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -370,7 +369,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":graph_service", - "//mediapipe/framework:packet", + ":packet", "@com_google_absl//absl/status", ], ) @@ -380,7 +379,7 @@ cc_test( srcs = ["graph_service_manager_test.cc"], deps = [ ":graph_service_manager", - "//mediapipe/framework:packet", + ":packet", "//mediapipe/framework/port:gtest_main", ], ) @@ -392,6 +391,7 @@ cc_library( visibility = [":mediapipe_internal"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -408,10 +408,9 @@ cc_library( ":packet_set", ":packet_type", ":port", + ":stream_handler_cc_proto", ":timestamp", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -467,6 +466,7 @@ cc_library( hdrs = ["calculator_state.h"], visibility = [":mediapipe_internal"], deps = [ + ":calculator_cc_proto", ":counter", ":counter_factory", ":graph_service", @@ -476,7 +476,6 @@ cc_library( ":packet", ":packet_set", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", @@ -584,7 +583,7 @@ cc_library( hdrs = ["executor.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:mediapipe_options_cc_proto", + ":mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", @@ -671,11 +670,11 @@ cc_library( ":collection_item_id", ":input_stream_manager", ":input_stream_shard", + ":mediapipe_options_cc_proto", ":mediapipe_profiling", ":packet", ":packet_set", ":packet_type", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -785,12 +784,12 @@ cc_library( ":calculator_context_manager", ":collection", ":collection_item_id", + ":mediapipe_options_cc_proto", ":output_stream_manager", ":output_stream_shard", ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework:mediapipe_options_cc_proto", 
"//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -876,10 +875,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":packet", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", @@ -897,13 +896,13 @@ cc_library( ":delegating_executor", ":executor", ":packet", + ":packet_factory_cc_proto", ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", ":port", ":thread_pool_executor", ":validated_graph_config", - "//mediapipe/framework:packet_factory_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -1020,10 +1019,10 @@ cc_library( hdrs = ["status_handler.h"], visibility = ["//visibility:public"], deps = [ + ":mediapipe_options_cc_proto", ":packet_set", ":packet_type", ":port", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", @@ -1036,10 +1035,10 @@ cc_library( hdrs = ["subgraph.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":graph_service", ":graph_service_manager", ":port", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1061,7 +1060,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_framework", - "//mediapipe/framework:test_calculators_cc_proto", + ":test_calculators_cc_proto", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:integral_types", @@ -1098,7 +1097,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":executor", - "//mediapipe/framework:thread_pool_executor_cc_proto", + ":thread_pool_executor_cc_proto", "//mediapipe/framework/deps:thread_options", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", @@ -1163,22 +1162,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_base", + ":calculator_cc_proto", ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", ":packet", ":packet_generator", + ":packet_generator_cc_proto", ":packet_set", ":packet_type", ":port", ":status_handler", + ":status_handler_cc_proto", + ":stream_handler_cc_proto", ":subgraph", + ":thread_pool_executor_cc_proto", ":timestamp", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1203,11 +1202,11 @@ cc_test( name = "validated_graph_config_test", srcs = ["validated_graph_config_test.cc"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":graph_service", ":graph_service_manager", ":validated_graph_config", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/port:gtest_main", @@ -1234,6 +1233,7 @@ cc_test( linkstatic = 1, deps = [ 
":calculator_base", + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_registry", @@ -1243,7 +1243,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", @@ -1257,11 +1256,11 @@ cc_test( srcs = ["calculator_contract_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_contract", ":calculator_contract_test_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", + ":packet_generator_cc_proto", + ":status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -1369,6 +1368,7 @@ cc_test( srcs = ["calculator_context_test.cc"], linkstatic = 1, deps = [ + ":calculator_cc_proto", ":calculator_context", ":calculator_context_manager", ":calculator_state", @@ -1377,7 +1377,6 @@ cc_test( ":output_stream_shard", ":packet_set", ":packet_type", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", @@ -1404,6 +1403,7 @@ cc_test( ":executor", ":input_stream_handler", ":lifetime_tracker", + ":mediapipe_options_cc_proto", ":output_stream_poller", ":packet_set", ":packet_type", @@ -1411,13 +1411,12 @@ cc_test( ":subgraph", ":test_calculators", ":thread_pool_executor", + ":thread_pool_executor_cc_proto", ":timestamp", ":type_map", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:thread_pool_executor_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1482,12 +1481,12 @@ cc_test( ], visibility = ["//visibility:public"], deps = [ + ":calculator_cc_proto", ":calculator_framework", ":test_calculators", "//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", @@ -1631,8 +1630,8 @@ cc_test( srcs = ["packet_generator_test.cc"], deps = [ ":packet_generator", + ":packet_generator_cc_proto", ":packet_type", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3bf3ec198..b01c2b759 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -15,12 +15,17 @@ #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" -namespace mediapipe { -namespace api2 { -namespace test { +namespace mediapipe::api2::builder { +namespace { + +using ::mediapipe::api2::test::Bar; +using ::mediapipe::api2::test::FloatAdder; +using ::mediapipe::api2::test::Foo; +using ::mediapipe::api2::test::Foo2; +using ::mediapipe::api2::test::FooBar1; TEST(BuilderTest, BuildGraph) 
{ - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); auto& bar = graph.AddNode("Bar"); graph.In("IN").SetName("base") >> foo.In("BASE"); @@ -49,22 +54,19 @@ TEST(BuilderTest, BuildGraph) { } TEST(BuilderTest, CopyableSource) { - builder::Graph graph; - builder::Source a = graph[Input("A")]; - a.SetName("a"); - builder::Source b = graph[Input("B")]; - b.SetName("b"); - builder::SideSource side_a = graph[SideInput("SIDE_A")]; - side_a.SetName("side_a"); - builder::SideSource side_b = graph[SideInput("SIDE_B")]; - side_b.SetName("side_b"); - builder::Destination out = graph[Output("OUT")]; - builder::SideDestination side_out = - graph[SideOutput("SIDE_OUT")]; + Graph graph; + Source a = graph.In("A").SetName("a").Cast(); + Source b = graph.In("B").SetName("b").Cast(); + SideSource side_a = + graph.SideIn("SIDE_A").SetName("side_a").Cast(); + SideSource side_b = + graph.SideIn("SIDE_B").SetName("side_b").Cast(); + Destination out = graph.Out("OUT").Cast(); + SideDestination side_out = graph.SideOut("SIDE_OUT").Cast(); - builder::Source input = a; + Source input = a; input = b; - builder::SideSource side_input = side_b; + SideSource side_input = side_b; side_input = side_a; input >> out; @@ -83,31 +85,27 @@ TEST(BuilderTest, CopyableSource) { } TEST(BuilderTest, BuildGraphWithFunctions) { - builder::Graph graph; + Graph graph; - builder::Source base = graph[Input("IN")]; - base.SetName("base"); - builder::SideSource side = graph[SideInput("SIDE")]; - side.SetName("side"); + Source base = graph.In("IN").SetName("base").Cast(); + SideSource side = graph.SideIn("SIDE").SetName("side").Cast(); - auto foo_fn = [](builder::Source base, builder::SideSource side, - builder::Graph& graph) { + auto foo_fn = [](Source base, SideSource side, Graph& graph) { auto& foo = graph.AddNode("Foo"); - base >> foo[Input("BASE")]; - side >> foo[SideInput("SIDE")]; - return foo[Output("OUT")]; + base >> foo.In("BASE"); + side >> foo.SideIn("SIDE"); + return foo.Out("OUT")[0].Cast(); }; - builder::Source foo_out = foo_fn(base, side, graph); + Source foo_out = foo_fn(base, side, graph); - auto bar_fn = [](builder::Source in, builder::Graph& graph) { + auto bar_fn = [](Source in, Graph& graph) { auto& bar = graph.AddNode("Bar"); - in >> bar[Input("IN")]; - return bar[Output("OUT")]; + in >> bar.In("IN"); + return bar.Out("OUT")[0].Cast(); }; - builder::Source bar_out = bar_fn(foo_out, graph); - bar_out.SetName("out"); + Source bar_out = bar_fn(foo_out, graph); - bar_out >> graph[Output("OUT")]; + bar_out.SetName("out") >> graph.Out("OUT"); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -131,7 +129,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) { template void BuildGraphTypedTest() { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode(); auto& bar = graph.AddNode(); graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); @@ -161,12 +159,12 @@ void BuildGraphTypedTest() { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest(); } -TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } +TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest(); } TEST(BuilderTest, FanOut) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); auto& adder = graph.AddNode("FloatAdder"); graph.In("IN").SetName("base") >> foo.In("BASE"); @@ -194,9 +192,9 @@ TEST(BuilderTest, FanOut) { } TEST(BuilderTest, TypedMultiple) { - 
builder::Graph graph; - auto& foo = graph.AddNode(); - auto& adder = graph.AddNode(); + Graph graph; + auto& foo = graph.AddNode(); + auto& adder = graph.AddNode(); graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE")); foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0]; foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1]; @@ -222,14 +220,14 @@ TEST(BuilderTest, TypedMultiple) { } TEST(BuilderTest, TypedByPorts) { - builder::Graph graph; - auto& foo = graph.AddNode(); + Graph graph; + auto& foo = graph.AddNode(); auto& adder = graph.AddNode(); - graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase]; + graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase]; foo[Foo::kOut] >> adder[FloatAdder::kIn][0]; foo[Foo::kOut] >> adder[FloatAdder::kIn][1]; - adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut]; + adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut); CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( @@ -251,7 +249,7 @@ TEST(BuilderTest, TypedByPorts) { } TEST(BuilderTest, PacketGenerator) { - builder::Graph graph; + Graph graph; auto& generator = graph.AddPacketGenerator("FloatGenerator"); graph.SideIn("IN") >> generator.SideIn("IN"); generator.SideOut("OUT") >> graph.SideOut("OUT"); @@ -270,7 +268,7 @@ TEST(BuilderTest, PacketGenerator) { } TEST(BuilderTest, EmptyTag) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In("A").SetName("a") >> foo.In("")[0]; graph.In("C").SetName("c") >> foo.In("")[2]; @@ -302,7 +300,7 @@ TEST(BuilderTest, StringLikeTags) { const std::string kB = "B"; constexpr absl::string_view kC = "C"; - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In(kA).SetName("a") >> foo.In(kA); graph.In(kB).SetName("b") >> foo.In(kB); @@ -324,7 +322,7 @@ TEST(BuilderTest, StringLikeTags) { } TEST(BuilderTest, GraphIndexes) { - builder::Graph graph; + Graph graph; auto& foo = graph.AddNode("Foo"); graph.In(0).SetName("a") >> foo.In("")[0]; graph.In(1).SetName("c") >> foo.In("")[2]; @@ -376,28 +374,27 @@ class AnyAndSameTypeCalculator : public NodeIntf { }; TEST(BuilderTest, AnyAndSameTypeHandledProperly) { - builder::Graph graph; - builder::Source any_input = graph[Input{"GRAPH_ANY_INPUT"}]; - builder::Source int_input = graph[Input{"GRAPH_INT_INPUT"}]; + Graph graph; + Source any_input = graph.In("GRAPH_ANY_INPUT"); + Source int_input = graph.In("GRAPH_INT_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; int_input >> node[AnyAndSameTypeCalculator::kIntInput]; - builder::Source any_type_output = + Source any_type_output = node[AnyAndSameTypeCalculator::kAnyTypeOutput]; any_type_output.SetName("any_type_output"); - builder::Source same_type_output = + Source same_type_output = node[AnyAndSameTypeCalculator::kSameTypeOutput]; same_type_output.SetName("same_type_output"); - builder::Source recursive_same_type_output = + Source recursive_same_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; recursive_same_type_output.SetName("recursive_same_type_output"); - builder::Source same_int_output = - node[AnyAndSameTypeCalculator::kSameIntOutput]; + Source same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput]; same_int_output.SetName("same_int_output"); - builder::Source recursive_same_int_type_output = + Source recursive_same_int_type_output = node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; 
recursive_same_int_type_output.SetName("recursive_same_int_type_output"); @@ -420,15 +417,16 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { } TEST(BuilderTest, AnyTypeCanBeCast) { - builder::Graph graph; - builder::Source any_input = + Graph graph; + Source any_input = graph.In("GRAPH_ANY_INPUT").Cast(); auto& node = graph.AddNode("AnyAndSameTypeCalculator"); any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; - builder::Source any_type_output = - node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); - any_type_output.SetName("any_type_output"); + Source any_type_output = + node[AnyAndSameTypeCalculator::kAnyTypeOutput] + .SetName("any_type_output") + .Cast(); any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); @@ -446,11 +444,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) { } TEST(BuilderTest, MultiPortIsCastToMultiPort) { - builder::Graph graph; - builder::MultiSource any_input = graph.In("ANY_INPUT"); - builder::MultiSource int_input = any_input.Cast(); - builder::MultiDestination any_output = graph.Out("ANY_OUTPUT"); - builder::MultiDestination int_output = any_output.Cast(); + Graph graph; + MultiSource any_input = graph.In("ANY_INPUT"); + MultiSource int_input = any_input.Cast(); + MultiDestination any_output = graph.Out("ANY_OUTPUT"); + MultiDestination int_output = any_output.Cast(); int_input >> int_output; CalculatorGraphConfig expected = @@ -462,11 +460,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) { } TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { - builder::Graph graph; - builder::MultiSource any_multi_input = graph.In("ANY_INPUT"); - builder::Source any_input = any_multi_input; - builder::MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); - builder::Destination any_output = any_multi_output; + Graph graph; + MultiSource any_multi_input = graph.In("ANY_INPUT"); + Source any_input = any_multi_input; + MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); + Destination any_output = any_multi_output; any_input >> any_output; CalculatorGraphConfig expected = @@ -478,11 +476,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { } TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { - builder::Graph graph; - builder::Source int_input = graph.In("INT_INPUT").Cast(); - builder::Source any_input = graph.In("ANY_OUTPUT"); - builder::Destination int_output = graph.Out("INT_OUTPUT").Cast(); - builder::Destination any_output = graph.Out("ANY_OUTPUT"); + Graph graph; + Source int_input = graph.In("INT_INPUT").Cast(); + Source any_input = graph.In("ANY_OUTPUT"); + Destination int_output = graph.Out("INT_OUTPUT").Cast(); + Destination any_output = graph.Out("ANY_OUTPUT"); int_input >> int_output; any_input >> any_output; @@ -496,6 +494,5 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } -} // namespace test -} // namespace api2 -} // namespace mediapipe +} // namespace +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index e63d3651e..eee542640 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -557,8 +557,8 @@ class OutputSidePacketAccess { if (output_) output_->Set(ToOldPacket(std::move(packet))); } - void Set(const T& payload) { Set(MakePacket(payload)); } - void Set(T&& payload) { Set(MakePacket(std::move(payload))); } + void Set(const T& payload) { Set(api2::MakePacket(payload)); } + void Set(T&& payload) { Set(api2::MakePacket(std::move(payload))); } private: 
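The api2::MakePacket qualification just above is the kind of fix that guards against argument-dependent lookup: when the payload type lives in namespace mediapipe, an unqualified MakePacket(payload) also finds the legacy ::mediapipe::MakePacket, which can make the call ambiguous or bind the wrong factory. A self-contained repro of that failure mode, with hypothetical namespaces and types:

```cpp
namespace legacy {
struct Widget {};
struct Packet {};
template <typename T>
Packet MakePacket(const T&) { return {}; }  // found via ADL for legacy types
}  // namespace legacy

namespace api2x {
template <typename T>
struct Packet {};
template <typename T>
Packet<T> MakePacket(const T&) { return {}; }

template <typename T>
void Emit(const T& payload) {
  // Unqualified MakePacket(payload) is ambiguous for T = legacy::Widget:
  // ordinary lookup finds api2x::MakePacket, ADL adds legacy::MakePacket,
  // and both match equally well. Qualification pins the intended overload.
  auto p = api2x::MakePacket(payload);
  (void)p;
}
}  // namespace api2x

int main() {
  api2x::Emit(legacy::Widget{});
  return 0;
}
```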
OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {} diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 27bc105c8..7ff004f1e 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -20,9 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") licenses(["notice"]) -package(default_visibility = [ - "//mediapipe:__subpackages__", -]) +package_group( + name = "mediapipe_internal", + packages = [ + "//mediapipe/...", + ], +) + +package(default_visibility = ["mediapipe_internal"]) bzl_library( name = "expand_template_bzl", @@ -214,6 +219,9 @@ cc_library( name = "registration", srcs = ["registration.cc"], hdrs = ["registration.h"], + visibility = [ + "mediapipe_internal", + ], deps = [ ":registration_token", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index fdb698c48..f5a043f10 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - deps = ["//mediapipe/framework/formats:location_data_proto"], + deps = [":location_data_proto"], ) mediapipe_register_type( @@ -38,7 +38,7 @@ mediapipe_register_type( "::std::vector<::mediapipe::Detection>", "::std::vector<::mediapipe::DetectionList>", ], - deps = ["//mediapipe/framework/formats:detection_cc_proto"], + deps = [":detection_cc_proto"], ) mediapipe_proto_library( @@ -105,8 +105,8 @@ cc_library( srcs = ["matrix.cc"], hdrs = ["matrix.h"], deps = [ + ":matrix_data_cc_proto", "//mediapipe/framework:port", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -142,7 +142,7 @@ cc_library( srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -166,8 +166,8 @@ cc_library( srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], deps = [ + ":image_format_cc_proto", ":image_frame", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:opencv_core", ], ) @@ -194,7 +194,7 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", - "//mediapipe/framework/formats:location_data_cc_proto", + ":location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -245,7 +245,7 @@ cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", ], ) @@ -263,9 +263,9 @@ cc_test( size = "small", srcs = ["image_frame_opencv_test.cc"], deps = [ + ":image_format_cc_proto", ":image_frame", ":image_frame_opencv", - "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -324,8 +324,8 @@ cc_library( "//conditions:default": [], }), deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:image_frame", + ":image_format_cc_proto", + ":image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", 
"//mediapipe/framework:type_map", @@ -354,7 +354,7 @@ cc_library( hdrs = ["image_multi_pool.h"], deps = [ ":image", - "//mediapipe/framework/formats:image_frame_pool", + ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", "@com_google_absl//absl/memory", @@ -390,7 +390,7 @@ cc_library( ], deps = [ ":image", - "//mediapipe/framework/formats:image_format_cc_proto", + ":image_format_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", @@ -428,7 +428,10 @@ cc_library( "tensor.cc", "tensor_ahwb.cc", ], - hdrs = ["tensor.h"], + hdrs = [ + "tensor.h", + "tensor_internal.h", + ], copts = select({ "//mediapipe:apple": [ "-x objective-c++", @@ -452,6 +455,7 @@ cc_library( ], }), deps = [ + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index f1bbc0289..c9bb8b4ff 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -38,11 +38,11 @@ cc_library( srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], deps = [ + ":optical_flow_field_data_cc_proto", "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index fdafbff5c..3f11d368a 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -246,10 +246,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, return Tensor::OpenGlTexture2dView::Layout::kAligned; } } - // The best performance of a compute shader can be achived with textures' + // The best performance of a compute shader can be achieved with textures' // width multiple of 256. Making minimum fixed width of 256 waste memory for // small tensors. The optimal balance memory-vs-performance is power of 2. - // The texture width and height are choosen to be closer to square. + // The texture width and height are chosen to be closer to square. float power = std::log2(std::sqrt(static_cast(num_pixels))); w = 1 << static_cast(power); int h = (num_pixels + w - 1) / w; @@ -326,7 +326,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); if (!(valid_ & kValidOpenGlBuffer)) { - // If the call succeds then AHWB -> SSBO are synchronized so any usage of + // If the call succeeds then AHWB -> SSBO are synchronized so any usage of // the SSBO is correct after this call. 
if (!InsertAhwbToSsboFence()) { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); @@ -348,8 +348,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { }; } -Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const { +Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView( + uint64_t source_location_hash) const { auto lock(absl::make_unique(&view_mutex_)); + TrackAhwbUsage(source_location_hash); AllocateOpenGlBuffer(); valid_ = kValidOpenGlBuffer; return {opengl_buffer_, std::move(lock), nullptr}; @@ -385,6 +387,7 @@ void Tensor::Move(Tensor* src) { src->element_type_ = ElementType::kNone; // Mark as invalidated. cpu_buffer_ = src->cpu_buffer_; src->cpu_buffer_ = nullptr; + ahwb_tracking_key_ = src->ahwb_tracking_key_; #if MEDIAPIPE_METAL_ENABLED device_ = src->device_; src->device_ = nil; @@ -589,8 +592,10 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { return {cpu_buffer_, std::move(lock)}; } -Tensor::CpuWriteView Tensor::GetCpuWriteView() const { +Tensor::CpuWriteView Tensor::GetCpuWriteView( + uint64_t source_location_hash) const { auto lock = absl::make_unique(&view_mutex_); + TrackAhwbUsage(source_location_hash); AllocateCpuBuffer(); valid_ = kValidCpu; #ifdef MEDIAPIPE_TENSOR_USE_AHWB @@ -620,24 +625,4 @@ void Tensor::AllocateCpuBuffer() const { } } -void Tensor::SetPreferredStorageType(StorageType type) { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - if (__builtin_available(android 26, *)) { - use_ahwb_ = type == StorageType::kAhwb; - VLOG(4) << "Tensor: use of AHardwareBuffer is " - << (use_ahwb_ ? "allowed" : "not allowed"); - } -#else - VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed"; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - -Tensor::StorageType Tensor::GetPreferredStorageType() { -#ifdef MEDIAPIPE_TENSOR_USE_AHWB - return use_ahwb_ ? StorageType::kAhwb : StorageType::kDefault; -#else - return StorageType::kDefault; -#endif // MEDIAPIPE_TENSOR_USE_AHWB -} - } // namespace mediapipe diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 9d3e90b6a..8a6f02e9d 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,8 +24,9 @@ #include #include -#include "absl/memory/memory.h" +#include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" #if MEDIAPIPE_METAL_ENABLED @@ -48,6 +49,22 @@ #include "mediapipe/gpu/gl_context.h" #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 +#if defined __has_builtin +#if __has_builtin(__builtin_LINE) +#define builtin_LINE __builtin_LINE +#endif +#if __has_builtin(__builtin_FILE) +#define builtin_FILE __builtin_FILE +#endif +#endif + +#ifndef builtin_LINE +#define builtin_LINE() 0 +#endif +#ifndef builtin_FILE +#define builtin_FILE() "" +#endif + namespace mediapipe { // Tensor is a container of multi-dimensional data that supports sharing the @@ -65,7 +82,7 @@ namespace mediapipe { // GLuint buffer = view.buffer(); // Then the buffer can be bound to the GPU command buffer. // ...binding the buffer to the command buffer... -// ...commiting command buffer and releasing the view... +// ...committing command buffer and releasing the view... // // The following request for the CPU view will be blocked until the GPU view is // released and the GPU task is finished. 
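The write-view getters in the next hunks take a caller-supplied key that defaults to a hash of __builtin_FILE()/__builtin_LINE(). Because default arguments are evaluated at the call site, every caller gets a distinct, stable key without passing anything explicitly. A standalone sketch of the mechanism (this FnvHash64 mirrors the tensor_internal helper in spirit only; the macros above fall back to ""/0 on compilers without these builtins):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint64_t kFnvPrime = 0x100000001b3;

constexpr uint64_t FnvHash64(uint64_t value, uint64_t seed) {
  return (seed ^ value) * kFnvPrime;  // FNV-1a style mixing
}

constexpr uint64_t FnvHash64(const char* s,
                             uint64_t seed = 0xcbf29ce484222325) {
  return *s ? FnvHash64(s + 1, (seed ^ static_cast<uint8_t>(*s)) * kFnvPrime)
            : seed;
}

// The defaults are evaluated where CallSiteKey() is *called*, so the key
// identifies the caller, not this definition.
uint64_t CallSiteKey(const char* file = __builtin_FILE(),
                     int line = __builtin_LINE()) {
  return FnvHash64(static_cast<uint64_t>(line), FnvHash64(file));
}

int main() {
  uint64_t a = CallSiteKey();             // key for this line
  uint64_t b = CallSiteKey();             // next line: a different key
  std::printf("distinct: %d\n", a != b);  // prints 1
  return 0;
}
```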
@@ -161,7 +178,9 @@ class Tensor { using CpuReadView = CpuView; CpuReadView GetCpuReadView() const; using CpuWriteView = CpuView; - CpuWriteView GetCpuWriteView() const; + CpuWriteView GetCpuWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #if MEDIAPIPE_METAL_ENABLED // TODO: id vs. MtlBufferView. @@ -305,7 +324,9 @@ class Tensor { // A valid OpenGL context must be bound to the calling thread due to possible // GPU resource allocation. OpenGlBufferView GetOpenGlBufferReadView() const; - OpenGlBufferView GetOpenGlBufferWriteView() const; + OpenGlBufferView GetOpenGlBufferWriteView( + uint64_t source_location_hash = + tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const; #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 const Shape& shape() const { return shape_; } @@ -408,9 +429,13 @@ class Tensor { mutable std::function release_callback_; bool AllocateAHardwareBuffer(int size_alignment = 0) const; void CreateEglSyncAndFd() const; - // Use Ahwb for other views: OpenGL / CPU buffer. #endif // MEDIAPIPE_TENSOR_USE_AHWB - static inline bool use_ahwb_ = false; + // Use Ahwb for other views: OpenGL / CPU buffer. + mutable bool use_ahwb_ = false; + mutable uint64_t ahwb_tracking_key_ = 0; + // TODO: Tracks all unique tensors. Can grow to a large number. LRU + // can be more predicted. + static inline absl::flat_hash_set ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; @@ -419,6 +444,8 @@ class Tensor { void* MapAhwbToCpuRead() const; void* MapAhwbToCpuWrite() const; void MoveCpuOrSsboToAhwb() const; + // Set current tracking key, set "use ahwb" if the key is already marked. + void TrackAhwbUsage(uint64_t key) const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 mutable std::shared_ptr gl_context_; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 3c3ec8b17..466811be7 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -212,9 +212,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { CHECK(!(valid_ & kValidOpenGlTexture2d)) << "Tensor conversion between OpenGL texture and AHardwareBuffer is not " "supported."; - CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer)) - << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not " - "supported on target system."; bool transfer = !ahwb_; CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; @@ -268,6 +265,10 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( } bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { + // Mark current tracking key as Ahwb-use. + ahwb_usage_track_.insert(ahwb_tracking_key_); + use_ahwb_ = true; + if (__builtin_available(android 26, *)) { if (ahwb_ == nullptr) { AHardwareBuffer_Desc desc = {}; @@ -315,7 +316,13 @@ void Tensor::MoveCpuOrSsboToAhwb() const { ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); CHECK(error == 0) << "AHardwareBuffer_lock " << error; } - if (valid_ & kValidOpenGlBuffer) { + if (valid_ & kValidCpu) { + std::memcpy(dest, cpu_buffer_, bytes()); + // Free CPU memory because next time AHWB is mapped instead. 
+ free(cpu_buffer_); + cpu_buffer_ = nullptr; + valid_ &= ~kValidCpu; + } else if (valid_ & kValidOpenGlBuffer) { gl_context_->Run([this, dest]() { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), @@ -326,11 +333,9 @@ void Tensor::MoveCpuOrSsboToAhwb() const { }); opengl_buffer_ = GL_INVALID_INDEX; gl_context_ = nullptr; - } else if (valid_ & kValidCpu) { - std::memcpy(dest, cpu_buffer_, bytes()); - // Free CPU memory because next time AHWB is mapped instead. - free(cpu_buffer_); - cpu_buffer_ = nullptr; + // Reset OpenGL Buffer validness. The OpenGL buffer will be allocated on top + // of the Ahwb at the next request to the OpenGlBufferView. + valid_ &= ~kValidOpenGlBuffer; } else { LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; } @@ -446,6 +451,16 @@ void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { + if (ahwb_tracking_key_ == 0) { + ahwb_tracking_key_ = source_location_hash; + for (int dim : shape_.dims) { + ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); + } + } + use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); +} + #else // MEDIAPIPE_TENSOR_USE_AHWB bool Tensor::AllocateAhwbMapToSsbo() const { return false; } @@ -454,6 +469,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {} void Tensor::ReleaseAhwbStuff() {} void* Tensor::MapAhwbToCpuRead() const { return nullptr; } void* Tensor::MapAhwbToCpuWrite() const { return nullptr; } +void Tensor::TrackAhwbUsage(uint64_t key) const {} #endif // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index 7ccd9c7f5..a6ca00949 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -152,6 +152,36 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { { auto view = tensor.GetAHardwareBufferReadView(); EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } + auto ptr = tensor.GetCpuReadView().buffer(); + EXPECT_NE(ptr, nullptr); + std::vector reference; + reference.resize(num_elements); + for (int i = 0; i < num_elements; i++) { + reference[i] = static_cast(i) / 10.0f; + } + EXPECT_THAT(absl::Span(ptr, num_elements), + testing::Pointwise(testing::FloatEq(), reference)); +} + +TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { + // Request the GPU view to get the ssbo allocated internally. + // Request Ahwb view then to transform the storage into Ahwb. 
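TrackAhwbUsage above completes the feedback loop: the first tensors created at a call site run with default storage; when an AHardwareBuffer view is eventually requested, AllocateAHardwareBuffer records the tracking key, and later tensors whose key (call site plus shape) is already in the set start in AHWB mode and skip the CPU/SSBO-to-AHWB copy. A hypothetical standalone version of that loop:

```cpp
#include <cstdint>

#include "absl/container/flat_hash_set.h"

// Names here are illustrative, not the Tensor API. One instance per tensor;
// the set of AHWB-consuming creation sites is shared, as in the change.
class AhwbAdoptionTracker {
 public:
  // Called from write-view getters with the call-site/shape key. Returns
  // true when this site previously fed an AHWB consumer, i.e. the tensor
  // should allocate an AHardwareBuffer up front.
  bool OnWriteView(uint64_t key) {
    key_ = key;
    return adopted_sites_.contains(key_);
  }

  // Called when an AHardwareBuffer view is actually requested.
  void OnAhwbView() { adopted_sites_.insert(key_); }

 private:
  uint64_t key_ = 0;
  static inline absl::flat_hash_set<uint64_t> adopted_sites_;
};
```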
+ Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); + constexpr size_t num_elements = 20; + Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + RunInGlContext([&tensor] { + auto ssbo_view = tensor.GetOpenGlBufferWriteView(); + auto ssbo_name = ssbo_view.name(); + EXPECT_GT(ssbo_name, 0); + FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), + tensor.element_type()); + }); + { + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); EXPECT_NE(ptr, nullptr); diff --git a/mediapipe/framework/formats/tensor_hardware_buffer.h b/mediapipe/framework/formats/tensor_hardware_buffer.h deleted file mode 100644 index fa0241bde..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ -#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ - -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#include - -#include - -#include "mediapipe/framework/formats/tensor_buffer.h" -#include "mediapipe/framework/formats/tensor_internal.h" -#include "mediapipe/framework/formats/tensor_v2.h" - -namespace mediapipe { - -// Supports: -// - float 16 and 32 bits -// - signed / unsigned integers 8,16,32 bits -class TensorHardwareBufferView; -struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor { - using ViewT = TensorHardwareBufferView; - TensorBufferDescriptor buffer; -}; - -class TensorHardwareBufferView : public Tensor::View { - public: - TENSOR_UNIQUE_VIEW_TYPE_ID(); - ~TensorHardwareBufferView() = default; - - const TensorHardwareBufferViewDescriptor& descriptor() const override { - return descriptor_; - } - AHardwareBuffer* handle() const { return ahwb_handle_; } - - protected: - TensorHardwareBufferView(int access_capability, Tensor::View::Access access, - Tensor::View::State state, - const TensorHardwareBufferViewDescriptor& desc, - AHardwareBuffer* ahwb_handle) - : Tensor::View(kId, access_capability, access, state), - descriptor_(desc), - ahwb_handle_(ahwb_handle) {} - - private: - bool MatchDescriptor( - uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor) const override { - if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor)) - return false; - auto descriptor = - static_cast(base_descriptor); - return descriptor.buffer.format == descriptor_.buffer.format && - descriptor.buffer.size_alignment <= - descriptor_.buffer.size_alignment && - descriptor_.buffer.size_alignment % - descriptor.buffer.size_alignment == - 0; - } - const TensorHardwareBufferViewDescriptor& descriptor_; - AHardwareBuffer* ahwb_handle_ = nullptr; -}; - -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_ diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc deleted file mode 100644 index 9c223ce2c..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage.cc +++ /dev/null @@ -1,216 +0,0 @@ -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) - -#include - -#include "absl/memory/memory.h" 
-#include "absl/status/status.h" -#include "mediapipe/framework/formats/tensor_backend.h" -#include "mediapipe/framework/formats/tensor_cpu_buffer.h" -#include "mediapipe/framework/formats/tensor_hardware_buffer.h" -#include "mediapipe/framework/formats/tensor_v2.h" -#include "util/task/status_macros.h" - -namespace mediapipe { -namespace { - -class TensorCpuViewImpl : public TensorCpuView { - public: - TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access, - Tensor::View::State state, - const TensorCpuViewDescriptor& descriptor, void* pointer, - AHardwareBuffer* ahwb_handle) - : TensorCpuView(access_capabilities, access, state, descriptor, pointer), - ahwb_handle_(ahwb_handle) {} - ~TensorCpuViewImpl() { - // If handle_ is null then this view is constructed in GetViews with no - // access. - if (ahwb_handle_) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_unlock(ahwb_handle_, nullptr); - } - } - } - - private: - AHardwareBuffer* ahwb_handle_; -}; - -class TensorHardwareBufferViewImpl : public TensorHardwareBufferView { - public: - TensorHardwareBufferViewImpl( - int access_capability, Tensor::View::Access access, - Tensor::View::State state, - const TensorHardwareBufferViewDescriptor& descriptor, - AHardwareBuffer* handle) - : TensorHardwareBufferView(access_capability, access, state, descriptor, - handle) {} - ~TensorHardwareBufferViewImpl() = default; -}; - -class HardwareBufferCpuStorage : public TensorStorage { - public: - ~HardwareBufferCpuStorage() { - if (!ahwb_handle_) return; - if (__builtin_available(android 26, *)) { - AHardwareBuffer_release(ahwb_handle_); - } - } - - static absl::Status CanProvide( - int access_capability, const Tensor::Shape& shape, uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor) { - // TODO: use AHardwareBuffer_isSupported for API >= 29. - static const bool is_ahwb_supported = [] { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_Desc desc = {}; - // Aligned to the largest possible virtual memory page size. - constexpr uint32_t kPageSize = 16384; - desc.width = kPageSize; - desc.height = 1; - desc.layers = 1; - desc.format = AHARDWAREBUFFER_FORMAT_BLOB; - desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; - AHardwareBuffer* handle; - if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false; - AHardwareBuffer_release(handle); - return true; - } - return false; - }(); - if (!is_ahwb_supported) { - return absl::UnavailableError( - "AHardwareBuffer is not supported on the platform."); - } - - if (view_type_id != TensorCpuView::kId && - view_type_id != TensorHardwareBufferView::kId) { - return absl::InvalidArgumentError( - "A view type is not supported by this storage."); - } - return absl::OkStatus(); - } - - std::vector> GetViews(uint64_t latest_version) { - std::vector> result; - auto update_state = latest_version == version_ - ? 
Tensor::View::State::kUpToDate - : Tensor::View::State::kOutdated; - if (ahwb_handle_) { - result.push_back( - std::unique_ptr(new TensorHardwareBufferViewImpl( - kAccessCapability, Tensor::View::Access::kNoAccess, update_state, - hw_descriptor_, ahwb_handle_))); - - result.push_back(std::unique_ptr(new TensorCpuViewImpl( - kAccessCapability, Tensor::View::Access::kNoAccess, update_state, - cpu_descriptor_, nullptr, nullptr))); - } - return result; - } - - absl::StatusOr> GetView( - Tensor::View::Access access, const Tensor::Shape& shape, - uint64_t latest_version, uint64_t view_type_id, - const Tensor::ViewDescriptor& base_descriptor, int access_capability) { - MP_RETURN_IF_ERROR( - CanProvide(access_capability, shape, view_type_id, base_descriptor)); - const auto& buffer_descriptor = - view_type_id == TensorHardwareBufferView::kId - ? static_cast( - base_descriptor) - .buffer - : static_cast(base_descriptor) - .buffer; - if (!ahwb_handle_) { - if (__builtin_available(android 26, *)) { - AHardwareBuffer_Desc desc = {}; - desc.width = TensorBufferSize(buffer_descriptor, shape); - desc.height = 1; - desc.layers = 1; - desc.format = AHARDWAREBUFFER_FORMAT_BLOB; - // TODO: Use access capabilities to set hints. - desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | - AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN; - auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_); - if (error != 0) { - return absl::UnknownError( - absl::StrCat("Error allocating hardware buffer: ", error)); - } - // Fill all possible views to provide it as proto views. - hw_descriptor_.buffer = buffer_descriptor; - cpu_descriptor_.buffer = buffer_descriptor; - } - } - if (buffer_descriptor.format != hw_descriptor_.buffer.format || - buffer_descriptor.size_alignment > - hw_descriptor_.buffer.size_alignment || - hw_descriptor_.buffer.size_alignment % - buffer_descriptor.size_alignment > - 0) { - return absl::AlreadyExistsError( - "A view with different params is already allocated with this " - "storage"); - } - - absl::StatusOr> result; - if (view_type_id == TensorHardwareBufferView::kId) { - result = GetAhwbView(access, shape, base_descriptor); - } else { - result = GetCpuView(access, shape, base_descriptor); - } - if (result.ok()) version_ = latest_version; - return result; - } - - private: - absl::StatusOr> GetAhwbView( - Tensor::View::Access access, const Tensor::Shape& shape, - const Tensor::ViewDescriptor& base_descriptor) { - return std::unique_ptr(new TensorHardwareBufferViewImpl( - kAccessCapability, access, Tensor::View::State::kUpToDate, - hw_descriptor_, ahwb_handle_)); - } - - absl::StatusOr> GetCpuView( - Tensor::View::Access access, const Tensor::Shape& shape, - const Tensor::ViewDescriptor& base_descriptor) { - void* pointer = nullptr; - if (__builtin_available(android 26, *)) { - int error = - AHardwareBuffer_lock(ahwb_handle_, - access == Tensor::View::Access::kWriteOnly - ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN - : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, - -1, nullptr, &pointer); - if (error != 0) { - return absl::UnknownError( - absl::StrCat("Error locking hardware buffer: ", error)); - } - } - return std::unique_ptr( - new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly - ? 
Tensor::View::AccessCapability::kWrite - : Tensor::View::AccessCapability::kRead, - access, Tensor::View::State::kUpToDate, - cpu_descriptor_, pointer, ahwb_handle_)); - } - - static constexpr int kAccessCapability = - Tensor::View::AccessCapability::kRead | - Tensor::View::AccessCapability::kWrite; - TensorHardwareBufferViewDescriptor hw_descriptor_; - AHardwareBuffer* ahwb_handle_ = nullptr; - - TensorCpuViewDescriptor cpu_descriptor_; - uint64_t version_ = 0; -}; -TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage); - -} // namespace -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || - // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc b/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc deleted file mode 100644 index 0afa9899f..000000000 --- a/mediapipe/framework/formats/tensor_hardware_buffer_cpu_storage_test.cc +++ /dev/null @@ -1,76 +0,0 @@ - -#if !defined(MEDIAPIPE_NO_JNI) && \ - (__ANDROID_API__ >= 26 || \ - defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) -#include - -#include - -#include "mediapipe/framework/formats/tensor_cpu_buffer.h" -#include "mediapipe/framework/formats/tensor_hardware_buffer.h" -#include "mediapipe/framework/formats/tensor_v2.h" -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" - -namespace mediapipe { - -namespace { - -class TensorHardwareBufferTest : public ::testing::Test { - public: - TensorHardwareBufferTest() {} - ~TensorHardwareBufferTest() override {} -}; - -TEST_F(TensorHardwareBufferTest, TestFloat32) { - Tensor tensor{Tensor::Shape({1})}; - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorHardwareBufferViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - EXPECT_NE(view->handle(), nullptr); - } - { - const auto& const_tensor = tensor; - MP_ASSERT_OK_AND_ASSIGN( - auto view, - const_tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = - TensorBufferDescriptor::Format::kFloat32}})); - EXPECT_NE(view->data(), nullptr); - } -} - -TEST_F(TensorHardwareBufferTest, TestInt8Padding) { - Tensor tensor{Tensor::Shape({1})}; - - { - MP_ASSERT_OK_AND_ASSIGN( - auto view, - tensor.GetView( - TensorHardwareBufferViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8, - .size_alignment = 4}})); - EXPECT_NE(view->handle(), nullptr); - } - { - const auto& const_tensor = tensor; - MP_ASSERT_OK_AND_ASSIGN( - auto view, - const_tensor.GetView( - TensorCpuViewDescriptor{ - .buffer = {.format = TensorBufferDescriptor::Format::kInt8}})); - EXPECT_NE(view->data(), nullptr); - } -} - -} // namespace - -} // namespace mediapipe - -#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || - // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor_internal.h index 1231a991c..c223c5b1d 100644 --- a/mediapipe/framework/formats/tensor_internal.h +++ b/mediapipe/framework/formats/tensor_internal.h @@ -18,8 +18,6 @@ #include #include -#include "mediapipe/framework/tool/type_util.h" - namespace mediapipe { // Generates unique view id at compile-time using FILE and LINE. 
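The hunk below reworks `FnvHash64` into a two-argument combining step plus a character loop built on top of it, so a single byte (or another value) can be folded into an existing hash. For reference, a minimal Python sketch of the same 64-bit FNV-1a scheme (illustrative only, not part of the MediaPipe sources):

```python
# 64-bit FNV-1a, mirroring the constexpr helpers in tensor_internal.h.
FNV_PRIME = 0x00000100000001B3
FNV_OFFSET_BIAS = 0xCBF29CE484222325
MASK64 = (1 << 64) - 1  # uint64_t arithmetic wraps modulo 2**64

def fnv_combine(hash_value: int, value: int) -> int:
    # The two-argument overload: (value2 ^ value1) * kFnvPrime.
    return ((value ^ hash_value) * FNV_PRIME) & MASK64

def fnv_hash64(data: bytes, hash_value: int = FNV_OFFSET_BIAS) -> int:
    # Iterative form of the recursive char* overload.
    for byte in data:
        hash_value = fnv_combine(hash_value, byte)
    return hash_value

assert fnv_hash64(b"") == FNV_OFFSET_BIAS  # empty input returns the offset bias
```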
@@ -41,10 +39,12 @@ namespace tensor_internal { // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function constexpr uint64_t kFnvPrime = 0x00000100000001B3; constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325; -constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { - return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime); +constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) { + return (value2 ^ value1) * kFnvPrime; +} +constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { + return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0])); } - template struct TypeList { static constexpr std::size_t size{sizeof...(Ts)}; diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 01ef6ee86..68a9af52d 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -88,8 +88,8 @@ cc_library( srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], deps = [ + ":default_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -110,8 +110,8 @@ cc_library( srcs = ["fixed_size_input_stream_handler.cc"], deps = [ ":default_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], alwayslink = 1, ) @@ -159,13 +159,13 @@ cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], deps = [ + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -177,10 +177,10 @@ cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], deps = [ + ":timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:timestamp", - "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -243,6 +243,7 @@ cc_test( srcs = ["set_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", ":mux_input_stream_handler", "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", @@ -251,7 +252,6 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", ], ) @@ -272,13 +272,13 @@ cc_test( srcs = ["fixed_size_input_stream_handler_test.cc"], deps = [ ":fixed_size_input_stream_handler", + ":fixed_size_input_stream_handler_cc_proto", 
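+        # A leading ":" is Bazel shorthand for a target in the current package,
+        # e.g. ":foo" here means "//mediapipe/framework/stream_handler:foo".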
"//mediapipe/calculators/core:counting_source_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], @@ -289,11 +289,11 @@ cc_test( srcs = ["sync_set_input_stream_handler_test.cc"], deps = [ ":sync_set_input_stream_handler", + ":sync_set_input_stream_handler_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 89cb802da..193343a90 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,6 +299,7 @@ mediapipe_cc_test( requires_full_emulation = False, deps = [ ":node_chain_subgraph_cc_proto", + ":node_chain_subgraph_options_lib", ":options_field_util", ":options_registry", ":options_syntax_util", @@ -313,7 +314,6 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", - "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", "@com_google_absl//absl/strings", ], @@ -422,9 +422,9 @@ cc_library( srcs = ["source.cc"], visibility = ["//visibility:public"], deps = [ + ":source_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:source_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], @@ -485,13 +485,13 @@ cc_library( hdrs = ["template_expander.h"], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/strings", ], ) @@ -506,6 +506,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:proto_descriptor_cc_proto", @@ -515,7 +516,6 @@ cc_library( "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:calculator_graph_template_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -661,8 +661,8 @@ cc_library( hdrs = ["simulation_clock_executor.h"], visibility = ["//visibility:public"], deps = [ + ":simulation_clock", "//mediapipe/framework:thread_pool_executor", - "//mediapipe/framework/tool:simulation_clock", ], ) @@ -789,10 +789,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", 
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", ], ) @@ -805,6 +805,7 @@ cc_library( deps = [ ":container_util", ":options_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/deps:mathutil", @@ -814,7 +815,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -841,6 +841,7 @@ cc_library( ], deps = [ ":container_util", + ":switch_container_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_shard", @@ -850,7 +851,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", - "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, ) @@ -893,6 +893,7 @@ cc_library( ":container_util", ":name_util", ":subgraph_expansion", + ":switch_container_cc_proto", ":switch_demux_calculator", ":switch_mux_calculator", "//mediapipe/calculators/core:packet_sequencer_calculator", @@ -904,7 +905,6 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:switch_container_cc_proto", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 009eb3f9e..cc5e50dfc 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -564,6 +564,7 @@ cc_library( name = "gpu_shared_data_internal_stub", visibility = ["//visibility:private"], deps = [ + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", @@ -571,7 +572,6 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", - "//mediapipe/gpu:gl_context_options_cc_proto", ], ) @@ -592,7 +592,7 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - "//mediapipe/gpu:gl_context_options_cc_proto", + ":gl_context_options_cc_proto", ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", @@ -833,10 +833,10 @@ cc_library( deps = [ ":gl_base", ":gl_simple_shaders", + ":scale_mode_cc_proto", ":shader_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:scale_mode_cc_proto", ], ) @@ -907,8 +907,8 @@ proto_library( srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) @@ -930,6 +930,7 @@ cc_library( deps = [ ":gl_calculator_helper", ":gl_quad_renderer", + ":gl_scaler_calculator_cc_proto", ":gl_simple_shaders", ":shader_util", "//mediapipe/framework:calculator_framework", @@ -937,7 +938,6 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", - "//mediapipe/gpu:gl_scaler_calculator_cc_proto", ], alwayslink = 1, ) @@ -950,13 +950,13 @@ cc_library( ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", + 
":gl_surface_sink_calculator_cc_proto", ":gpu_buffer", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -966,8 +966,8 @@ proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ + ":scale_mode_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/gpu:scale_mode_proto", ], ) diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 05700ba17..fc1e5484e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -15,10 +15,13 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; +import android.graphics.PixelFormat; +import android.media.Image; import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImageProperties; +import com.google.mediapipe.framework.image.MediaImageExtractor; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator { } return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); } - + if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) { + Image mediaImage = MediaImageExtractor.extract(image); + if (mediaImage.getFormat() != PixelFormat.RGBA_8888) { + throw new UnsupportedOperationException("Android media image must use RGBA_8888 config."); + } + return createImage( + mediaImage.getPlanes()[0].getBuffer(), + mediaImage.getWidth(), + mediaImage.getHeight(), + /* numChannels= */ 4); + } // Unsupported type. throw new UnsupportedOperationException( "Unsupported Image container type: " + properties.getStorageType()); diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index efaec34a7..63ea7854b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -14,6 +14,10 @@ package com.google.mediapipe.framework; +import com.google.common.flogger.FluentLogger; +import java.util.HashSet; +import java.util.Set; + /** * A {@link TextureFrame} that represents a texture produced by MediaPipe. * @@ -21,6 +25,7 @@ package com.google.mediapipe.framework; * method. */ public class GraphTextureFrame implements TextureFrame { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private long nativeBufferHandle; // We cache these to be able to get them without a JNI call. private int textureName; @@ -30,6 +35,8 @@ public class GraphTextureFrame implements TextureFrame { // True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait // when calling getTextureName(). 
private final boolean deferredSync; + private final Set activeConsumerContextHandleSet = new HashSet<>(); + private int refCount = 1; GraphTextureFrame(long nativeHandle, long timestamp) { this(nativeHandle, timestamp, false); @@ -54,17 +61,19 @@ public class GraphTextureFrame implements TextureFrame { * condition if release() is called after the if-check for nativeBufferHandle is already passed. */ @Override - public int getTextureName() { + public synchronized int getTextureName() { // Return special texture id 0 if handle is 0 i.e. frame is already released. if (nativeBufferHandle == 0) { return 0; } - // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using - // PacketGetter.getTextureFrameDeferredSync(). - if (deferredSync) { - // Note that, if a CPU wait has already been done, the sync point will have been - // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. - nativeGpuWait(nativeBufferHandle); + if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) { + // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using + // PacketGetter.getTextureFrameDeferredSync(). + if (deferredSync) { + // Note that, if a CPU wait has already been done, the sync point will have been + // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. + nativeGpuWait(nativeBufferHandle); + } } return textureName; } @@ -86,15 +95,31 @@ public class GraphTextureFrame implements TextureFrame { return timestamp; } + @Override + public boolean supportsRetain() { + return true; + } + + @Override + public synchronized void retain() { + // TODO: check that refCount is > 0 and handle is not 0. + refCount++; + } + /** * Releases a reference to the underlying buffer. * *

The consumer calls this when it is done using the texture. */ @Override - public void release() { - GlSyncToken consumerToken = - new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + public synchronized void release() { + GlSyncToken consumerToken = null; + // Note that this remove should be moved to the other overload of release when b/68808951 is + // addressed. + if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { + consumerToken = + new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); + } release(consumerToken); } @@ -108,18 +133,40 @@ public class GraphTextureFrame implements TextureFrame { * currently cannot create a GlSyncToken, so they cannot call this method. */ @Override - public void release(GlSyncToken consumerSyncToken) { - if (nativeBufferHandle != 0) { - long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken(); - nativeReleaseBuffer(nativeBufferHandle, token); - nativeBufferHandle = 0; + public synchronized void release(GlSyncToken consumerSyncToken) { + if (nativeBufferHandle == 0) { + if (consumerSyncToken != null) { + logger.atWarning().log("release with sync token, but handle is 0"); + } + return; } + if (consumerSyncToken != null) { + long token = consumerSyncToken.nativeToken(); + nativeDidRead(nativeBufferHandle, token); + // We should remove the token's context from activeConsumerContextHandleSet here, but for now + // we do it in the release(void) overload. consumerSyncToken.release(); } + + refCount--; + if (refCount <= 0) { + nativeReleaseBuffer(nativeBufferHandle); + nativeBufferHandle = 0; + } } - private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken); + @Override + protected void finalize() throws Throwable { + if (refCount >= 0 || nativeBufferHandle != 0) { + logger.atWarning().log("release was not called before finalize"); + } + if (!activeConsumerContextHandleSet.isEmpty()) { + logger.atWarning().log("active consumers did not release with sync before finalize"); + } + } + + private native void nativeReleaseBuffer(long nativeHandle); private native int nativeGetTextureName(long nativeHandle); private native int nativeGetWidth(long nativeHandle); @@ -128,4 +175,8 @@ public class GraphTextureFrame implements TextureFrame { private native void nativeGpuWait(long nativeHandle); private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle); + + private native long nativeGetCurrentExternalContextHandle(); + + private native void nativeDidRead(long nativeHandle, long consumerSyncToken); } diff --git a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java index babfd2958..76eaf39df 100644 --- a/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/TextureFrame.java @@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback { */ @Override void release(GlSyncToken syncToken); + + /** + * If this method returns true, this object supports the retain method, and can be used with + * multiple consumers. Call retain for each additional consumer beyond the first; each consumer + * should call release. + */ + default boolean supportsRetain() { + return false; + } + + /** Increments the reference count. Only available with some implementations of TextureFrame. 
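+   * The default implementation throws UnsupportedOperationException; callers should check
+   * supportsRetain() before calling this method.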
*/ + default void retain() { + throw new UnsupportedOperationException(); + } } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 84df89260..dd99cccd4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -15,20 +15,16 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" using mediapipe::GlTextureBufferSharedPtr; JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + JNIEnv* env, jobject thiz, jlong nativeHandle) { GlTextureBufferSharedPtr* buffer = reinterpret_cast(nativeHandle); - if (consumerSyncToken) { - mediapipe::GlSyncToken& token = - *reinterpret_cast(consumerSyncToken); - (*buffer)->DidRead(token); - } delete buffer; } @@ -84,3 +80,18 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( } return reinterpret_cast(token); } + +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) { + return reinterpret_cast( + mediapipe::GlContext::GetCurrentNativeContext()); +} + +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + GlTextureBufferSharedPtr* buffer = + reinterpret_cast(nativeHandle); + mediapipe::GlSyncToken& token = + *reinterpret_cast(consumerSyncToken); + (*buffer)->DidRead(token); +} diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index 45637bb31..41c531fff 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -26,7 +26,7 @@ extern "C" { // Releases a native mediapipe::GpuBuffer. 
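+// Note: the consumer's sync token is no longer passed here; consumers signal
+// completed reads via nativeDidRead() before the final release.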
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( - JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + JNIEnv* env, jobject thiz, jlong nativeHandle); JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( JNIEnv* env, jobject thiz, jlong nativeHandle); @@ -44,6 +44,12 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( + JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken); + +JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( + nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 9899a145b..b37088764 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.text import text_classifier + +# Remove duplicated and non-public API +del python diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py b/mediapipe/model_maker/python/text/text_classifier/__init__.py index 618e51645..697461969 100644 --- a/mediapipe/model_maker/python/text/text_classifier/__init__.py +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions SupportedModels = model_spec.SupportedModels TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier_options.TextClassifierOptions + +# Remove duplicated and non-public API +del hyperparameters +del dataset +del model_options +del model_spec +del preprocessor # pylint: disable=undefined-variable +del text_classifier +del text_classifier_options diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index c285702d2..1a338e345 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor from mediapipe.model_maker.python.text.text_classifier import text_classifier_options from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer -from official.nlp import optimization def _validate(options: text_classifier_options.TextClassifierOptions): @@ -417,8 +416,22 @@ class _BertClassifier(TextClassifier): total_steps = self._hparams.steps_per_epoch * self._hparams.epochs warmup_steps = int(total_steps * 0.1) initial_lr = self._hparams.learning_rate - self._optimizer = optimization.create_optimizer(initial_lr, total_steps, - warmup_steps) + # Implements linear decay of the learning rate. 
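+    # The schedule decays linearly from initial_lr to 0 over total_steps; when
+    # warmup_steps > 0 it is wrapped in a linear warmup phase (model_util.WarmUp).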
+ lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=initial_lr, + decay_steps=total_steps, + end_learning_rate=0.0, + power=1.0) + if warmup_steps: + lr_schedule = model_util.WarmUp( + initial_learning_rate=initial_lr, + decay_schedule_fn=lr_schedule, + warmup_steps=warmup_steps) + + self._optimizer = tf.keras.optimizers.experimental.AdamW( + lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0) + self._optimizer.exclude_from_weight_decay( + var_names=["LayerNorm", "layer_norm", "bias"]) def _save_vocab(self, vocab_filepath: str): tf.io.gfile.copy( diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 9123e36b0..cbdff7cf3 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -146,6 +146,8 @@ py_test( tags = ["notsan"], deps = [ ":gesture_recognizer_import", + ":hyperparameters", + ":model_options", "//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py index dc6923fac..a302e8d79 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py @@ -25,3 +25,12 @@ HParams = hyperparameters.HParams Dataset = dataset.Dataset HandDataPreprocessingParams = dataset.HandDataPreprocessingParams GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions + +# Remove duplicated and non-public API +del constants # pylint: disable=undefined-variable +del dataset +del gesture_recognizer +del gesture_recognizer_options +del hyperparameters +del metadata_writer # pylint: disable=undefined-variable +del model_options diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index f297d8640..556d2fcd7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier): batch_size=None, dtype=tf.float32, name='hand_embedding') - - x = tf.keras.layers.BatchNormalization()(inputs) - x = tf.keras.layers.ReLU()(x) + x = inputs dropout_rate = self._model_options.dropout_rate - x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x) + for i, width in enumerate(self._model_options.layer_widths): + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) + x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x) + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.ReLU()(x) + x = tf.keras.layers.Dropout(rate=dropout_rate)(x) outputs = tf.keras.layers.Dense( self._num_classes, activation='softmax', - name='custom_gesture_recognizer')( + name='custom_gesture_recognizer_out')( x) self._model = tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 280fc6a82..4fdb74225 100644 --- 
a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -23,6 +23,8 @@ import tensorflow as tf from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer +from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters +from mediapipe.model_maker.python.vision.gesture_recognizer import model_options from mediapipe.tasks.python.test import test_utils _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' @@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase): self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): - model_options = gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -60,12 +62,38 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) - def test_export_gesture_recognizer_model(self): - model_options = gesture_recognizer.ModelOptions() + @unittest_mock.patch.object( + tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) + def test_gesture_recognizer_model_layer_widths(self, mock_dense): + layer_widths = [64, 32] + mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) + model = gesture_recognizer.GestureRecognizer.create( + train_data=self._train_data, + validation_data=self._validation_data, + options=gesture_recognizer_options) + expected_calls = [ + unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}') + for i, w in enumerate(layer_widths) + ] + expected_calls.append( + unittest_mock.call( + len(self._train_data.label_names), + activation='softmax', + name='custom_gesture_recognizer_out')) + self.assertLen(mock_dense.call_args_list, len(expected_calls)) + mock_dense.assert_has_calls(expected_calls) + self._test_accuracy(model) + + def test_export_gesture_recognizer_model(self): + mo = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=mo, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._validation_data, @@ -102,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase): self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( - gesture_recognizer.hyperparameters, + hyperparameters, 'HParams', autospec=True, return_value=gesture_recognizer.HParams(epochs=1)) @unittest_mock.patch.object( - gesture_recognizer.model_options, + model_options, 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) @@ -122,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase): mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): - 
model_options = gesture_recognizer.ModelOptions() + mo = gesture_recognizer.ModelOptions() hparams = gesture_recognizer.HParams( export_dir=tempfile.mkdtemp(), epochs=2) gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( - model_options=model_options, hparams=hparams) + model_options=mo, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create(
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py index 79a84c792..1870437d4 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py
@@ -14,6 +14,7 @@ """Configurable model options for gesture recognizer models.""" import dataclasses +from typing import List @dataclasses.dataclass
@@ -23,5 +24,10 @@ class GestureRecognizerModelOptions: Attributes: dropout_rate: The fraction of the input units to drop, used in dropout layer. + layer_widths: A list of hidden layer widths for the gesture model. Each + element in the list will create a new hidden layer with the specified + width. The hidden layers are separated with BatchNorm, Dropout, and ReLU. + Defaults to an empty list (no hidden layers). """ dropout_rate: float = 0.05 + layer_widths: List[int] = dataclasses.field(default_factory=list)
diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index 29ae189e9..d7c47a359 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD
@@ -121,7 +121,9 @@ py_library( srcs = ["image_classifier_test.py"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"], deps = [ + ":hyperparameters", ":image_classifier_import", + ":model_options", "//mediapipe/tasks/python/test:test_utils", ], )
diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py index 3d0543cd2..0f964ef66 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py
@@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions ModelSpec = model_spec.ModelSpec SupportedModels = model_spec.SupportedModels ImageClassifierOptions = image_classifier_options.ImageClassifierOptions + +# Remove duplicated and non-public API +del dataset +del hyperparameters +del image_classifier +del image_classifier_options +del model_options +del model_spec +del train_image_classifier_lib # pylint: disable=undefined-variable
diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 252659edc..6ca21d334 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py
@@ -24,6 +24,8 @@ import numpy as np import tensorflow as tf from mediapipe.model_maker.python.vision import image_classifier +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters +from mediapipe.model_maker.python.vision.image_classifier import model_options from mediapipe.tasks.python.test import test_utils
@@ -159,15 +161,15 @@
class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( - image_classifier.hyperparameters, + hyperparameters, 'HParams', autospec=True, - return_value=image_classifier.HParams(epochs=1)) + return_value=hyperparameters.HParams(epochs=1)) @unittest_mock.patch.object( - image_classifier.model_options, + model_options, 'ImageClassifierModelOptions', autospec=True, - return_value=image_classifier.ModelOptions()) + return_value=model_options.ImageClassifierModelOptions()) def test_create_hparams_and_model_options_if_none_in_image_classifier_options( self, mock_hparams, mock_model_options): options = image_classifier.ImageClassifierOptions( diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 9b3c9f906..d7e4a950f 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,5 +1,5 @@ absl-py -mediapipe==0.9.1 +mediapipe==0.9.0.1 numpy opencv-python tensorflow>=2.10 diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index 117d20974..cd9124948 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -28,6 +28,8 @@ import PIL.Image from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame +TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' + Image = image.Image ImageFormat = image_frame.ImageFormat @@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) + def test_image_create_from_cvmat(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + mat = cv2.imread(image_path).astype(np.uint8) + mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB) + rgb_image = Image(image_format=ImageFormat.SRGB, data=mat) + self.assertEqual(rgb_image.width, 720) + self.assertEqual(rgb_image.height, 382) + self.assertEqual(rgb_image.channels, 3) + self.assertEqual(rgb_image.image_format, ImageFormat.SRGB) + self.assertTrue(np.array_equal(mat, rgb_image.numpy_view())) + + def test_image_create_from_file(self): + image_path = os.path.join(os.path.dirname(__file__), + 'solutions/testdata/hands.jpg') + loaded_image = Image.create_from_file(image_path) + self.assertEqual(loaded_image.width, 720) + self.assertEqual(loaded_image.height, 382) + self.assertEqual(loaded_image.channels, 3) + self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/python/packet_test.py b/mediapipe/python/packet_test.py index e1a4c12af..16fc37c87 100644 --- a/mediapipe/python/packet_test.py +++ b/mediapipe/python/packet_test.py @@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase): p.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p), 0.42) self.assertEqual(p.timestamp, 0) - p2 = packet_creator.create_float(np.float(0.42)) + p2 = packet_creator.create_float(float(0.42)) p2.timestamp = 0 self.assertAlmostEqual(packet_getter.get_float(p2), 0.42) self.assertEqual(p2.timestamp, 0) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 5d8663143..1bcca12ff 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -48,16 +48,20 @@ void ImageSubmodule(pybind11::module* module) { become immutable after creation. 
Creation examples: - import cv2 - cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat) - gray_frame = mp.Image( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) - from PIL import Image - pil_img = Image.new('RGB', (60, 30), color = 'red') - image = mp.Image( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ```python + import cv2 + cv_mat = cv2.imread(input_file)[:, :, ::-1] + rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat) + gray_frame = mp.Image( + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + + from PIL import Image + pil_img = Image.new('RGB', (60, 30), color = 'red') + image = mp.Image( + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + ``` The pixel data in an Image can be retrieved as a numpy ndarray by calling `Image.numpy_view()`. The returned numpy ndarray is a reference to the @@ -65,15 +69,18 @@ void ImageSubmodule(pybind11::module* module) { numpy ndarray, it's required to obtain a copy of it. Pixel data retrieval examples: - for channel in range(num_channel): - for col in range(width): - for row in range(height): - print(image[row, col, channel]) - output_ndarray = image.numpy_view() - print(output_ndarray[0, 0, 0]) - copied_ndarray = np.copy(output_ndarray) - copied_ndarray[0,0,0] = 0 + ```python + for channel in range(num_channel): + for col in range(width): + for row in range(height): + print(image[row, col, channel]) + + output_ndarray = image.numpy_view() + print(output_ndarray[0, 0, 0]) + copied_ndarray = np.copy(output_ndarray) + copied_ndarray[0,0,0] = 0 + ``` )doc", py::dynamic_attr()); @@ -156,9 +163,11 @@ void ImageSubmodule(pybind11::module* module) { An unwritable numpy ndarray. Examples: + ``` output_ndarray = image.numpy_view() copied_ndarray = np.copy(output_ndarray) copied_ndarray[0,0,0] = 0 + ``` )doc"); image.def( @@ -191,10 +200,12 @@ void ImageSubmodule(pybind11::module* module) { IndexError: If the index is invalid or out of bounds. Examples: + ``` for channel in range(num_channel): for col in range(width): for row in range(height): print(image[row, col, channel]) + ``` )doc"); image @@ -224,7 +235,9 @@ void ImageSubmodule(pybind11::module* module) { A boolean. Examples: + ``` image.is_aligned(16) + ``` )doc"); image.def_static( diff --git a/mediapipe/python/pybind/image_frame.cc b/mediapipe/python/pybind/image_frame.cc index a7fc6bfe4..bc7a9753d 100644 --- a/mediapipe/python/pybind/image_frame.cc +++ b/mediapipe/python/pybind/image_frame.cc @@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) { Creation examples: import cv2 cv_mat = cv2.imread(input_file)[:, :, ::-1] - rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat) + rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat) gray_frame = mp.ImageFrame( - format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) + image_format=ImageFormat.GRAY, + data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY)) from PIL import Image pil_img = Image.new('RGB', (60, 30), color = 'red') image_frame = mp.ImageFrame( - format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) + image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img)) The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling `ImageFrame.numpy_view()`. 
The returned numpy ndarray is a reference to the
diff --git a/mediapipe/tasks/ios/common/BUILD b/mediapipe/tasks/ios/common/BUILD index 0d00c423f..5f13f8d5c 100644 --- a/mediapipe/tasks/ios/common/BUILD +++ b/mediapipe/tasks/ios/common/BUILD
@@ -23,4 +23,3 @@ objc_library( ], module_name = "MPPCommon", ) -
diff --git a/mediapipe/tasks/ios/common/sources/MPPCommon.h b/mediapipe/tasks/ios/common/sources/MPPCommon.h index 1f450370e..7ce791d12 100644 --- a/mediapipe/tasks/ios/common/sources/MPPCommon.h +++ b/mediapipe/tasks/ios/common/sources/MPPCommon.h
@@ -1,25 +1,25 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import <Foundation/Foundation.h> NS_ASSUME_NONNULL_BEGIN /** - * @enum TFLSupportErrorCode - * This enum specifies error codes for TensorFlow Lite Task Library. - * It maintains a 1:1 mapping to TfLiteSupportErrorCode of C libray. + * @enum MPPTasksErrorCode + * This enum specifies error codes for MediaPipe Task Library. + * It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library. */ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
@@ -48,16 +48,16 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { MPPTasksErrorCodeFileReadError, // I/O error when mmap-ing file. MPPTasksErrorCodeFileMmapError, - // ZIP I/O error when unpacMPPTasksErrorCodeing the zip file. + // ZIP I/O error when unpacking the zip file. MPPTasksErrorCodeFileZipError, // TensorFlow Lite metadata error codes. - // Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer. + // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200, - // No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed. + // No such associated file within metadata, or file has not been packed. MPPTasksErrorCodeMetadataAssociatedFileNotFoundError, - // ZIP I/O error when unpacMPPTasksErrorCodeing an associated file. + // ZIP I/O error when unpacking an associated file. MPPTasksErrorCodeMetadataAssociatedFileZipError, // Inconsistency error between the metadata and actual TF Lite model. // E.g.: number of labels and output tensor values differ.
@@ -167,11 +167,10 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) { // Task graph config is invalid.
MPPTasksErrorCodeInvalidTaskGraphConfigError, + // The first error code in MPPTasksErrorCode (for internal use only). MPPTasksErrorCodeFirst = MPPTasksErrorCodeError, - /** - * The last error code in TFLSupportErrorCode (for internal use only). - */ + // The last error code in MPPTasksErrorCode (for internal use only). MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError, } NS_SWIFT_NAME(TasksErrorCode);
diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h index 8a90856c7..407d87aba 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h
@@ -24,7 +24,7 @@ extern NSString *const MPPTasksErrorDomain; @interface MPPCommonUtils : NSObject /** - * Creates and saves an NSError in the Mediapipe task library domain, with the given code and + * Creates and saves an NSError in the MediaPipe task library domain, with the given code and * description. * * @param code Error code.
@@ -51,9 +51,9 @@ extern NSString *const MPPTasksErrorDomain; description:(NSString *)description; /** - * Converts an absl status to an NSError. + * Converts an absl::Status to an NSError. * - * @param status absl status. + * @param status absl::Status. * @param error Pointer to the memory location where the created error should be saved. If `nil`, * no error will be saved. */
@@ -61,15 +61,15 @@ extern NSString *const MPPTasksErrorDomain; /** * Allocates a block of memory with the specified size and returns a pointer to it. If memory - * cannot be allocated because of an invalid memSize, it saves an error. In other cases, it + * cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it * terminates program execution. * * @param memSize size of memory to be allocated * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no * error will be saved. * - * @return Pointer to the allocated block of memory on successfull allocation. nil in case as - * error is encountered because of invalid memSize. If failure is due to any other reason, method + * @return Pointer to the allocated block of memory on successful allocation. `nil` in case an + * error is encountered because of an invalid `memSize`. If failure is due to any other reason, method * terminates program execution. */ + (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;
diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm index 574f2ef9a..4d4880a87 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.mm
@@ -24,7 +24,7 @@ #include "mediapipe/tasks/cc/common.h" /** Error domain of MediaPipe task library errors. */ -NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; +NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks"; @implementation MPPCommonUtils
@@ -68,7 +68,7 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks"; if (status.ok()) { return YES; } - // Payload of absl::Status created by the Media Pipe task library stores an appropriate value of + // Payload of absl::Status created by the MediaPipe task library stores an appropriate value of
The integer value corresponding to the MediaPipeTasksStatus enum // stored in the payload is extracted here to later map to the appropriate error code to be // returned. In cases where the enum is not stored in (payload is NULL or the payload string diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h index 8c4981642..7bf5744f7 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.h @@ -17,25 +17,38 @@ NS_ASSUME_NONNULL_BEGIN /** - * Holds settings for any single iOS Mediapipe classification task. + * Holds settings for any single iOS MediaPipe classification task. */ NS_SWIFT_NAME(ClassifierOptions) @interface MPPClassifierOptions : NSObject -/** If set, all classes in this list will be filtered out from the results . */ -@property(nonatomic, copy) NSArray *labelDenyList; - -/** If set, all classes not in this list will be filtered out from the results . */ -@property(nonatomic, copy) NSArray *labelAllowList; - -/** Display names local for display names*/ +/** The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ @property(nonatomic, copy) NSString *displayNamesLocale; -/** Results with score threshold greater than this value are returned . */ +/** The maximum number of top-scored classification results to return. If < 0, + * all available results will be returned. If 0, an invalid argument error is + * returned. + */ +@property(nonatomic) NSInteger maxResults; + +/** Score threshold to override the one provided in the model metadata (if any). + * Results below this value are rejected. + */ @property(nonatomic) float scoreThreshold; -/** Limit to the number of classes that can be returned in results. */ -@property(nonatomic) NSInteger maxResults; +/** The allowlist of category names. If non-empty, detection results whose + * category name is not in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryDenylist. + */ +@property(nonatomic, copy) NSArray *categoryAllowlist; + +/** The denylist of category names. If non-empty, detection results whose + * category name is in this set will be filtered out. Duplicate or unknown + * category names are ignored. Mutually exclusive with categoryAllowlist. 
+ */ +@property(nonatomic, copy) NSArray *categoryDenylist; @end diff --git a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m index 52dce23e4..accb6c7dd 100644 --- a/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m +++ b/mediapipe/tasks/ios/components/processors/sources/MPPClassifierOptions.m @@ -30,8 +30,8 @@ classifierOptions.scoreThreshold = self.scoreThreshold; classifierOptions.maxResults = self.maxResults; - classifierOptions.labelDenyList = self.labelDenyList; - classifierOptions.labelAllowList = self.labelAllowList; + classifierOptions.categoryDenylist = self.categoryDenylist; + classifierOptions.categoryAllowlist = self.categoryAllowlist; classifierOptions.displayNamesLocale = self.displayNamesLocale; return classifierOptions; diff --git a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm index 25e657599..efe9572e1 100644 --- a/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm +++ b/mediapipe/tasks/ios/components/processors/utils/sources/MPPClassifierOptions+Helpers.mm @@ -20,17 +20,23 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto } @implementation MPPClassifierOptions (Helpers) + - (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto { + classifierOptionsProto->Clear(); + if (self.displayNamesLocale) { classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); } + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); - for (NSString *category in self.labelAllowList) { + + for (NSString *category in self.categoryAllowlist) { classifierOptionsProto->add_category_allowlist(category.cppString); } - for (NSString *category in self.labelDenyList) { + for (NSString *category in self.categoryDenylist) { classifierOptionsProto->add_category_denylist(category.cppString); } } diff --git a/mediapipe/tasks/ios/core/BUILD b/mediapipe/tasks/ios/core/BUILD index 58f7389ac..adc37d901 100644 --- a/mediapipe/tasks/ios/core/BUILD +++ b/mediapipe/tasks/ios/core/BUILD @@ -16,19 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +objc_library( + name = "MPPBaseOptions", + srcs = ["sources/MPPBaseOptions.m"], + hdrs = ["sources/MPPBaseOptions.h"], +) + objc_library( name = "MPPTaskOptions", srcs = ["sources/MPPTaskOptions.m"], hdrs = ["sources/MPPTaskOptions.h"], - copts = [ - "-ObjC++", - "-std=c++17", - ], deps = [ ":MPPBaseOptions", ], ) +objc_library( + name = "MPPTaskResult", + srcs = ["sources/MPPTaskResult.m"], + hdrs = ["sources/MPPTaskResult.h"], +) + +objc_library( + name = "MPPTaskOptionsProtocol", + hdrs = ["sources/MPPTaskOptionsProtocol.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + ], +) + objc_library( name = "MPPTaskInfo", srcs = ["sources/MPPTaskInfo.mm"], @@ -64,32 +80,12 @@ objc_library( ) objc_library( - name = "MPPTaskResult", - srcs = ["sources/MPPTaskResult.m"], - hdrs = ["sources/MPPTaskResult.h"], -) - -objc_library( - name = "MPPBaseOptions", - srcs = ["sources/MPPBaseOptions.m"], - hdrs = ["sources/MPPBaseOptions.h"], -) - -objc_library( - name = "MPPTaskOptionsProtocol", - hdrs = ["sources/MPPTaskOptionsProtocol.h"], - deps = [ - 
"//mediapipe/framework:calculator_options_cc_proto", - ], -) - -objc_library( - name = "MPPTaskManager", - srcs = ["sources/MPPTaskManager.mm"], - hdrs = ["sources/MPPTaskManager.h"], + name = "MPPTaskRunner", + srcs = ["sources/MPPTaskRunner.mm"], + hdrs = ["sources/MPPTaskRunner.h"], deps = [ "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", ], ) diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index fca660fae..ae4c9eba1 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -17,7 +17,6 @@ #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" - NS_ASSUME_NONNULL_BEGIN /** @@ -55,7 +54,7 @@ NS_ASSUME_NONNULL_BEGIN outputStreams:(NSArray *)outputStreams taskOptions:(id)taskOptions enableFlowLimiting:(BOOL)enableFlowLimiting - error:(NSError **)error; + error:(NSError **)error NS_DESIGNATED_INITIALIZER; /** * Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance. diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm index 7d2fd6f28..be3c8cbf7 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.mm @@ -24,9 +24,9 @@ namespace { using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig; using Node = ::mediapipe::CalculatorGraphConfig::Node; -using ::mediapipe::InputStreamInfo; using ::mediapipe::CalculatorOptions; using ::mediapipe::FlowLimiterCalculatorOptions; +using ::mediapipe::InputStreamInfo; } // namespace @implementation MPPTaskInfo @@ -82,45 +82,46 @@ using ::mediapipe::FlowLimiterCalculatorOptions; graph_config.add_output_stream(cpp_output_stream); } - if (self.enableFlowLimiting) { - Node *flow_limit_calculator_node = graph_config.add_node(); - - flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); - - InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); - input_stream_info->set_tag_index("FINISHED"); - input_stream_info->set_back_edge(true); - - FlowLimiterCalculatorOptions *flow_limit_calculator_options = - flow_limit_calculator_node->mutable_options()->MutableExtension( - FlowLimiterCalculatorOptions::ext); - flow_limit_calculator_options->set_max_in_flight(1); - flow_limit_calculator_options->set_max_in_queue(1); - - for (NSString *inputStream in self.inputStreams) { - graph_config.add_input_stream(inputStream.cppString); - - NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; - flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); - - NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; - task_subgraph_node->add_input_stream(taskInputStream.cppString); - - NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; - flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); - } - - NSString *firstOutputStream = self.outputStreams[0]; - auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; - flow_limit_calculator_node->add_input_stream(finished_output_stream); - } else { + if (!self.enableFlowLimiting) { for (NSString *inputStream in self.inputStreams) { auto cpp_input_stream = 
inputStream.cppString; task_subgraph_node->add_input_stream(cpp_input_stream); graph_config.add_input_stream(cpp_input_stream); } + return graph_config; } + Node *flow_limit_calculator_node = graph_config.add_node(); + + flow_limit_calculator_node->set_calculator("FlowLimiterCalculator"); + + InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info(); + input_stream_info->set_tag_index("FINISHED"); + input_stream_info->set_back_edge(true); + + FlowLimiterCalculatorOptions *flow_limit_calculator_options = + flow_limit_calculator_node->mutable_options()->MutableExtension( + FlowLimiterCalculatorOptions::ext); + flow_limit_calculator_options->set_max_in_flight(1); + flow_limit_calculator_options->set_max_in_queue(1); + + for (NSString *inputStream in self.inputStreams) { + graph_config.add_input_stream(inputStream.cppString); + + NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream]; + flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString); + + NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream]; + task_subgraph_node->add_input_stream(taskInputStream.cppString); + + NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream]; + flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString); + } + + NSString *firstOutputStream = self.outputStreams[0]; + auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString; + flow_limit_calculator_node->add_input_stream(finished_output_stream); + return graph_config; } diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h index fa11cd38e..ee2f7d032 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.h @@ -1,14 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #import #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" @@ -19,27 +22,13 @@ NS_ASSUME_NONNULL_BEGIN * this class. */ NS_SWIFT_NAME(TaskOptions) + @interface MPPTaskOptions : NSObject /** * Base options for configuring the Mediapipe task. 
*/ @property(nonatomic, copy) MPPBaseOptions *baseOptions; -/** - * Initializes a new `MPPTaskOptions` with the absolute path to the model file - * stored locally on the device, set to the given the model path. - * - * @discussion The external model file must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the - * necessary metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. - * - * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. - * - * @return An instance of `MPPTaskOptions` initialized to the given model path. - */ -- (instancetype)initWithModelPath:(NSString *)modelPath; - @end NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m index f71d275be..fe74517c3 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptions.m @@ -1,17 +1,17 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" #import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h" @@ -25,12 +25,12 @@ return self; } -- (instancetype)initWithModelPath:(NSString *)modelPath { - self = [self init]; - if (self) { - _baseOptions.modelAssetPath = modelPath; - } - return self; +- (id)copyWithZone:(NSZone *)zone { + MPPTaskOptions *taskOptions = [[MPPTaskOptions alloc] init]; + + taskOptions.baseOptions = self.baseOptions; + + return taskOptions; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h index 18543e9ef..44fba4c0b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h @@ -1,26 +1,29 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #import <Foundation/Foundation.h>

 #include "mediapipe/framework/calculator_options.pb.h"

 NS_ASSUME_NONNULL_BEGIN

 /**
- * Any mediapipe task options should confirm to this protocol.
+ * Any MediaPipe task options should conform to this protocol.
  */
 @protocol MPPTaskOptionsProtocol

 /**
- * Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto.
+ * Copies the iOS MediaPipe task options to a mediapipe::CalculatorOptions proto.
  */
 - (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;
diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h
index e4845c26d..d15d4f258 100644
--- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h
@@ -1,30 +1,36 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #import <Foundation/Foundation.h>

 NS_ASSUME_NONNULL_BEGIN

 /**
- * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
+ * MediaPipe Tasks result base class. Any MediaPipe task result class should extend
  * this class.
*/ NS_SWIFT_NAME(TaskResult) + @interface MPPTaskResult : NSObject /** - * Base options for configuring the Mediapipe task. + * Timestamp that is associated with the task result object. */ -@property(nonatomic, assign, readonly) long timeStamp; +@property(nonatomic, assign, readonly) long timestamp; -- (instancetype)initWithTimeStamp:(long)timeStamp; +- (instancetype)init NS_UNAVAILABLE; + +- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 6a79ea7a9..7088eb246 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -1,27 +1,31 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ #import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h" @implementation MPPTaskResult -- (instancetype)initWithTimeStamp:(long)timeStamp { - self = [self init]; +- (instancetype)initWithTimestamp:(long)timestamp { + self = [super init]; if (self) { - _timeStamp = timeStamp; + _timestamp = timestamp; } return self; } +- (id)copyWithZone:(NSZone *)zone { + return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp]; +} + @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h new file mode 100644 index 000000000..6561e136d --- /dev/null +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -0,0 +1,50 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#import <Foundation/Foundation.h>
+
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/tasks/cc/core/task_runner.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * This class is used to create and call appropriate methods on the C++ Task Runner.
+ */
+@interface MPPTaskRunner : NSObject
+
+/**
+ * Initializes a new `MPPTaskRunner` with the MediaPipe task graph config proto.
+ *
+ * @param graphConfig A MediaPipe task graph config proto.
+ *
+ * @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
+ */
+- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error NS_DESIGNATED_INITIALIZER;
+
+- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process:
+    (const mediapipe::tasks::core::PacketMap &)packetMap;
+
+- (void)close;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
new file mode 100644
index 000000000..e08d0bc1b
--- /dev/null
+++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
@@ -0,0 +1,56 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
+#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
+
+namespace {
+using ::mediapipe::CalculatorGraphConfig;
+using ::mediapipe::Packet;
+using ::mediapipe::tasks::core::PacketMap;
+using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
+}  // namespace
+
+@interface MPPTaskRunner () {
+  // The underlying C++ task runner.
+  std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
+}
+@end
+
+@implementation MPPTaskRunner
+
+- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
+                                        error:(NSError **)error {
+  self = [super init];
+  if (self) {
+    auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
+
+    if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
+      return nil;
+    }
+
+    _cppTaskRunner = std::move(taskRunnerResult.value());
+  }
+  return self;
+}
+
+- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap {
+  return _cppTaskRunner->Process(packetMap);
+}
+
+- (void)close {
+  _cppTaskRunner->Close();
+}
+
+@end
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java
index d78685fe3..4e5cd7655 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java
@@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi {
         TaskRunner.create(
             context,
             TaskInfo.builder()
+                .setTaskName(AudioClassifier.class.getSimpleName())
+                .setTaskRunningModeName(options.runningMode().name())
                 .setTaskGraphName(TASK_GRAPH_NAME)
                 .setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index 4bc505d84..077f28ca2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(AudioEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 31f885267..5f7101776 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -22,6 +22,7 @@ android_library( ], manifest = "AndroidManifest.xml", deps = [ + ":logging", "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", "//mediapipe/framework:calculator_java_proto_lite", @@ -37,11 +38,22 @@ android_library( ], ) +android_library( + name = "logging", + srcs = glob( + ["logging/*.java"], + ), + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar") mediapipe_tasks_core_aar( name = "tasks_core", - srcs = glob(["*.java"]) + [ + srcs = glob(["**/*.java"]) + [ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 12f8be8ba..310f5739c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -32,6 +32,12 @@ public abstract class TaskInfo { /** Builder for {@link TaskInfo}. */ @AutoValue.Builder public abstract static class Builder { + /** Sets the MediaPipe task name. */ + public abstract Builder setTaskName(String value); + + /** Sets the MediaPipe task running mode name. */ + public abstract Builder setTaskRunningModeName(String value); + /** Sets the MediaPipe task graph name. */ public abstract Builder setTaskGraphName(String value); @@ -71,6 +77,10 @@ public abstract class TaskInfo { } } + abstract String taskName(); + + abstract String taskRunningModeName(); + abstract String taskGraphName(); abstract T taskOptions(); @@ -82,7 +92,7 @@ public abstract class TaskInfo { abstract Boolean enableFlowLimiting(); public static Builder builder() { - return new AutoValue_TaskInfo.Builder(); + return new AutoValue_TaskInfo.Builder().setTaskName("").setTaskRunningModeName(""); } /* Returns a list of the output stream names without the stream tags. 
*/ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index e6fc91cf6..1a128c538 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.tasks.core.logging.TasksStatsLogger; +import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable { private final Graph graph; private final ModelResourcesCache modelResourcesCache; private final AndroidPacketCreator packetCreator; + private final TasksStatsLogger statsLogger; private long lastSeenTimestamp = Long.MIN_VALUE; private ErrorListener errorListener; @@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable { Context context, TaskInfo taskInfo, OutputHandler outputHandler) { + TasksStatsLogger statsLogger = + TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName()); AndroidAssetUtil.initializeNativeAssetManager(context); Graph mediapipeGraph = new Graph(); mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig()); @@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable { mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache); mediapipeGraph.addMultiStreamCallback( taskInfo.outputStreamNames(), - outputHandler::run, - /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges()); + packets -> { + outputHandler.run(packets); + statsLogger.recordInvocationEnd(packets.get(0).getTimestamp()); + }, + /* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges()); mediapipeGraph.startRunningGraph(); // Waits until all calculators are opened and the graph is fully started. mediapipeGraph.waitUntilGraphIdle(); - return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler); + return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger); } /** @@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable { * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs. */ public synchronized TaskResult process(Map inputs) { - addPackets(inputs, generateSyntheticTimestamp()); + long syntheticInputTimestamp = generateSyntheticTimestamp(); + // TODO: Support recording GPU input arrival. 
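+    // The synthetic timestamp also serves as the identifier of this invocation: a stats
+    // logger implementation can match it against the output packet timestamp that is
+    // later passed to recordInvocationEnd().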
+ statsLogger.recordCpuInputArrival(syntheticInputTimestamp); + addPackets(inputs, syntheticInputTimestamp); graph.waitUntilGraphIdle(); lastSeenTimestamp = outputHandler.getLatestOutputTimestamp(); return outputHandler.retrieveCachedTaskResult(); @@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized TaskResult process(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); graph.waitUntilGraphIdle(); return outputHandler.retrieveCachedTaskResult(); @@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable { */ public synchronized void send(Map inputs, long inputTimestamp) { validateInputTimstamp(inputTimestamp); + statsLogger.recordCpuInputArrival(inputTimestamp); addPackets(inputs, inputTimestamp); } @@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); } catch (MediaPipeException e) { reportError(e); } @@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable { // Waits until all calculators are opened and the graph is fully restarted. graph.waitUntilGraphIdle(); graphStarted.set(true); + statsLogger.logSessionStart(); } catch (MediaPipeException e) { reportError(e); } @@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable { graphStarted.set(false); graph.closeAllPacketSources(); graph.waitUntilGraphDone(); + statsLogger.logSessionEnd(); if (modelResourcesCache != null) { modelResourcesCache.release(); } @@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable { private TaskRunner( Graph graph, ModelResourcesCache modelResourcesCache, - OutputHandler outputHandler) { + OutputHandler outputHandler, + TasksStatsLogger statsLogger) { this.outputHandler = outputHandler; this.graph = graph; this.modelResourcesCache = modelResourcesCache; this.packetCreator = new AndroidPacketCreator(graph); + this.statsLogger = statsLogger; graphStarted.set(true); + this.statsLogger.logSessionStart(); } /** Reports error. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java new file mode 100644 index 000000000..c10b5d224 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsDummyLogger.java @@ -0,0 +1,78 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core.logging; + +import android.content.Context; + +/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. */ +public class TasksStatsDummyLogger implements TasksStatsLogger { + + /** + * Creates the MediaPipe Tasks stats dummy logger. + * + * @param context a {@link Context}. + * @param taskNameStr the task api name. 
+   * @param taskRunningModeStr the task running mode string representation.
+   */
+  public static TasksStatsDummyLogger create(
+      Context context, String taskNameStr, String taskRunningModeStr) {
+    return new TasksStatsDummyLogger();
+  }
+
+  private TasksStatsDummyLogger() {}
+
+  /** Logs the start of a MediaPipe Tasks API session. */
+  @Override
+  public void logSessionStart() {}
+
+  /**
+   * Records MediaPipe Tasks API receiving CPU input data.
+   *
+   * @param packetTimestamp the input packet timestamp that acts as the identifier of the api
+   *     invocation.
+   */
+  @Override
+  public void recordCpuInputArrival(long packetTimestamp) {}
+
+  /**
+   * Records MediaPipe Tasks API receiving GPU input data.
+   *
+   * @param packetTimestamp the input packet timestamp that acts as the identifier of the api
+   *     invocation.
+   */
+  @Override
+  public void recordGpuInputArrival(long packetTimestamp) {}
+
+  /**
+   * Records the end of a MediaPipe Tasks API invocation.
+   *
+   * @param packetTimestamp the output packet timestamp that acts as the identifier of the api
+   *     invocation.
+   */
+  @Override
+  public void recordInvocationEnd(long packetTimestamp) {}
+
+  /** Logs the MediaPipe Tasks API periodic invocation report. */
+  @Override
+  public void logInvocationReport(StatsSnapshot stats) {}
+
+  /** Logs the Tasks API session end event. */
+  @Override
+  public void logSessionEnd() {}
+
+  /** Logs the MediaPipe Tasks API initialization error. */
+  @Override
+  public void logInitError() {}
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java
new file mode 100644
index 000000000..c726e7d0d
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/logging/TasksStatsLogger.java
@@ -0,0 +1,98 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.core.logging;
+
+import com.google.auto.value.AutoValue;
+
+/** The stats logger interface that defines what MediaPipe Tasks events to log. */
+public interface TasksStatsLogger {
+  /** Task stats snapshot.
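+   * Field summary (inferred from the accessor names below): the input counts track how
+   * many CPU/GPU input packets arrived, the finished/dropped counts track how many
+   * invocations completed or were dropped, and the latency fields are in milliseconds.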
+   */
+  @AutoValue
+  abstract static class StatsSnapshot {
+    static StatsSnapshot create(
+        int cpuInputCount,
+        int gpuInputCount,
+        int finishedCount,
+        int droppedCount,
+        long totalLatencyMs,
+        long peakLatencyMs,
+        long elapsedTimeMs) {
+      return new AutoValue_TasksStatsLogger_StatsSnapshot(
+          cpuInputCount,
+          gpuInputCount,
+          finishedCount,
+          droppedCount,
+          totalLatencyMs,
+          peakLatencyMs,
+          elapsedTimeMs);
+    }
+
+    static StatsSnapshot createDefault() {
+      return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0);
+    }
+
+    abstract int cpuInputCount();
+
+    abstract int gpuInputCount();
+
+    abstract int finishedCount();
+
+    abstract int droppedCount();
+
+    abstract long totalLatencyMs();
+
+    abstract long peakLatencyMs();
+
+    abstract long elapsedTimeMs();
+  }
+
+  /** Logs the start of a MediaPipe Tasks API session. */
+  public void logSessionStart();
+
+  /**
+   * Records MediaPipe Tasks API receiving CPU input data.
+   *
+   * @param packetTimestamp the input packet timestamp that acts as the identifier of the api
+   *     invocation.
+   */
+  public void recordCpuInputArrival(long packetTimestamp);
+
+  /**
+   * Records MediaPipe Tasks API receiving GPU input data.
+   *
+   * @param packetTimestamp the input packet timestamp that acts as the identifier of the api
+   *     invocation.
+   */
+  public void recordGpuInputArrival(long packetTimestamp);
+
+  /**
+   * Records the end of a MediaPipe Tasks API invocation.
+   *
+   * @param packetTimestamp the output packet timestamp that acts as the identifier of the api
+   *     invocation.
+   */
+  public void recordInvocationEnd(long packetTimestamp);
+
+  /** Logs the MediaPipe Tasks API periodic invocation report. */
+  public void logInvocationReport(StatsSnapshot stats);
+
+  /** Logs the Tasks API session end event. */
+  public void logSessionEnd();
+
+  /** Logs the MediaPipe Tasks API initialization error. */
+  public void logInitError();
+
+  // TODO: Logs more error types.
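+
+  // A typical session, as driven by TaskRunner, is expected to look roughly like:
+  //
+  //   logger.logSessionStart();
+  //   logger.recordCpuInputArrival(inputTimestamp);  // once per process()/send() call
+  //   logger.recordInvocationEnd(outputTimestamp);   // matched via the packet timestamp
+  //   logger.logSessionEnd();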
+} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 0ea91a9f8..edb78a191 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextClassifier.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 9b464d0e8..28f351d4b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(TextEmbedder.class.getSimpleName()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index e9e74a067..a933d2f65 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(GestureRecognizer.class.getSimpleName()) + .setTaskRunningModeName(recognizerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java index a9270d347..1d08ab928 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java @@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(HandLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 8990f46fd..38482797c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + 
.setTaskName(ImageClassifier.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index af053d860..488927257 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ImageEmbedder.class.getSimpleName()) + .setTaskRunningModeName(options.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 769b9137f..d706189ee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { TaskRunner.create( context, TaskInfo.builder() + .setTaskName(ObjectDetector.class.getSimpleName()) + .setTaskRunningModeName(detectorOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) .setOutputStreams(OUTPUT_STREAMS) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts index 46c2277e9..6d58be68f 100644 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ b/mediapipe/tasks/web/components/processors/base_options.test.ts @@ -86,7 +86,7 @@ describe('convertBaseOptionsToProto()', () => { it('can enable CPU delegate', async () => { const baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'cpu', + delegate: 'CPU', }); expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); }); @@ -94,7 +94,7 @@ describe('convertBaseOptionsToProto()', () => { it('can enable GPU delegate', async () => { const baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'gpu', + delegate: 'GPU', }); expect(baseOptionsProto.toObject()).toEqual({ ...mockBytesResult, @@ -117,7 +117,7 @@ describe('convertBaseOptionsToProto()', () => { it('can reset delegate', async () => { let baseOptionsProto = await convertBaseOptionsToProto({ modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'gpu', + delegate: 'GPU', }); // Clear backend baseOptionsProto = diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index 16d562262..97b62b784 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -71,7 +71,7 @@ async function configureExternalFile( /** Configues the `acceleration` option. */ function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { const acceleration = proto.getAcceleration() ?? 
new Acceleration(); - if (options.delegate === 'gpu') { + if (options.delegate === 'GPU') { acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); } else { acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts index d4691243b..9917035a4 100644 --- a/mediapipe/tasks/web/core/fileset_resolver.ts +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -44,22 +44,14 @@ async function isSimdSupported(): Promise { } async function createFileset( - taskName: string, basePath: string = '.'): Promise { - if (await isSimdSupported()) { - return { - wasmLoaderPath: - `${basePath}/${taskName}_wasm_internal.js`, - wasmBinaryPath: - `${basePath}/${taskName}_wasm_internal.wasm`, - }; - } else { - return { - wasmLoaderPath: - `${basePath}/${taskName}_wasm_nosimd_internal.js`, - wasmBinaryPath: - `${basePath}/${taskName}_wasm_nosimd_internal.wasm`, - }; - } + taskName: string, basePath: string = ''): Promise { + const suffix = + await isSimdSupported() ? 'wasm_internal' : 'wasm_nosimd_internal'; + + return { + wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`, + wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`, + }; } // tslint:disable:class-as-namespace diff --git a/mediapipe/tasks/web/core/task_runner_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts index aa0b4a028..5f23cd4bf 100644 --- a/mediapipe/tasks/web/core/task_runner_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -31,7 +31,7 @@ export declare interface BaseOptions { modelAssetBuffer?: Uint8Array|undefined; /** Overrides the default backend to use for the provided model. */ - delegate?: 'cpu'|'gpu'|undefined; + delegate?: 'CPU'|'GPU'|undefined; } /** Options to configure MediaPipe Tasks in general. */ diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index a9bb979af..ef866bc91 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -1028,7 +1028,9 @@ export class GraphRunner { // Set up our TS listener to receive any packets for this stream, and // additionally reformat our Uint8Array into a Float32Array for the user. this.setListener(outputStreamName, (data: Uint8Array) => { - const floatArray = new Float32Array(data.buffer); // Should be very fast + // Should be very fast + const floatArray = + new Float32Array(data.buffer, data.byteOffset, data.length / 4); callbackFcn(floatArray); }); diff --git a/setup.py b/setup.py index b072a850e..992430cf1 100644 --- a/setup.py +++ b/setup.py @@ -490,10 +490,10 @@ setuptools.setup( 'Operating System :: MacOS :: MacOS X', 'Operating System :: Microsoft :: Windows', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence',