From 7fb37c80e88495d6b50f39a9e9d348e2c0796a2c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 May 2022 15:29:57 -0700 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 19a829ffd755edb43e54d20c0e7b9348512d5108 --- .bazelrc | 7 + WORKSPACE | 5 +- docs/getting_started/android.md | 2 +- docs/getting_started/gpu_support.md | 15 + docs/getting_started/hello_world_ios.md | 2 +- docs/getting_started/ios.md | 12 +- docs/tools/tracing_and_profiling.md | 19 +- mediapipe/BUILD | 10 + mediapipe/calculators/core/BUILD | 34 ++ .../calculators/core/begin_loop_calculator.cc | 4 + .../calculators/core/end_loop_calculator.cc | 5 + .../core/get_vector_item_calculator.cc | 32 ++ .../core/get_vector_item_calculator.h | 77 +++++ .../core/split_vector_calculator.cc | 3 + .../core/vector_size_calculator.cc | 32 ++ .../calculators/core/vector_size_calculator.h | 64 ++++ .../image/scale_image_calculator.cc | 21 +- mediapipe/calculators/tensor/BUILD | 10 +- .../tensor/image_to_tensor_calculator.cc | 68 ++-- .../tensor/image_to_tensor_calculator.proto | 9 + .../tensor/image_to_tensor_calculator_test.cc | 118 ++++--- .../image_to_tensor_converter_gl_buffer.cc | 10 +- .../image_to_tensor_converter_gl_texture.cc | 10 +- .../tensor/image_to_tensor_converter_metal.cc | 11 +- .../image_to_tensor_converter_opencv.cc | 43 ++- .../tensor/inference_calculator.cc | 14 + .../calculators/tensor/inference_calculator.h | 17 +- .../tensor/inference_calculator.proto | 3 + .../tensor/inference_calculator_cpu.cc | 102 +++--- .../tensor/inference_calculator_gl.cc | 67 ++-- .../tensor/inference_calculator_metal.cc | 47 ++- .../tensors_to_detections_calculator.cc | 212 +++++++++++-- .../tensors_to_detections_calculator.proto | 41 +++ .../tensor_to_image_frame_calculator.cc | 8 +- mediapipe/calculators/tflite/BUILD | 3 +- .../tflite_custom_op_resolver_calculator.cc | 42 ++- mediapipe/calculators/util/BUILD | 36 +++ .../detection_label_id_to_text_calculator.cc | 54 +++- ...etection_label_id_to_text_calculator.proto | 6 +- .../detection_transformation_calculator.cc | 298 ++++++++++++++++++ ...etection_transformation_calculator_test.cc | 287 +++++++++++++++++ .../calculators/video/tracking_graph_test.cc | 2 +- .../autoflip/quality/kinematic_path_solver.h | 2 +- .../media_sequence/read_demo_dataset.py | 1 - mediapipe/examples/ios/facedetectioncpu/BUILD | 2 +- mediapipe/examples/ios/facedetectiongpu/BUILD | 2 +- mediapipe/examples/ios/faceeffect/BUILD | 2 +- mediapipe/examples/ios/facemeshgpu/BUILD | 2 +- mediapipe/examples/ios/handdetectiongpu/BUILD | 2 +- mediapipe/examples/ios/handtrackinggpu/BUILD | 2 +- mediapipe/examples/ios/helloworld/BUILD | 2 +- .../examples/ios/holistictrackinggpu/BUILD | 2 +- mediapipe/examples/ios/iristrackinggpu/BUILD | 2 +- mediapipe/examples/ios/link_local_profiles.py | 1 - .../examples/ios/objectdetectioncpu/BUILD | 2 +- .../examples/ios/objectdetectiongpu/BUILD | 2 +- .../ios/objectdetectiontrackinggpu/BUILD | 2 +- mediapipe/examples/ios/posetrackinggpu/BUILD | 2 +- .../examples/ios/selfiesegmentationgpu/BUILD | 2 +- mediapipe/framework/BUILD | 8 +- mediapipe/framework/api2/packet.h | 5 +- mediapipe/framework/api2/packet_nc.cc | 8 +- mediapipe/framework/api2/packet_test.cc | 17 + mediapipe/framework/api2/port_test.cc | 11 + mediapipe/framework/calculator_contract.h | 15 +- mediapipe/framework/calculator_graph.cc | 155 +++++---- mediapipe/framework/calculator_graph.h | 30 +- .../calculator_graph_side_packet_test.cc | 27 +- mediapipe/framework/calculator_node.cc | 8 +- 
mediapipe/framework/calculator_node.h | 12 +- mediapipe/framework/formats/BUILD | 17 + mediapipe/framework/formats/tensor.h | 4 +- mediapipe/framework/graph_service.h | 68 +++- mediapipe/framework/graph_service_manager.h | 2 + .../framework/graph_service_manager_test.cc | 8 +- mediapipe/framework/graph_service_test.cc | 7 + mediapipe/framework/input_stream_handler.cc | 9 +- mediapipe/framework/input_stream_handler.h | 2 +- mediapipe/framework/input_stream_manager.cc | 5 + mediapipe/framework/input_stream_manager.h | 9 +- .../framework/input_stream_manager_test.cc | 2 + mediapipe/framework/mediapipe_cc_test.bzl | 2 + mediapipe/framework/port/BUILD | 3 +- mediapipe/framework/test_service.cc | 10 +- mediapipe/framework/test_service.h | 18 ++ mediapipe/framework/tool/BUILD | 2 +- mediapipe/framework/tool/test_util.cc | 4 +- mediapipe/framework/tool/test_util.h | 4 +- mediapipe/gpu/BUILD | 21 +- mediapipe/gpu/gl_context.cc | 34 +- mediapipe/gpu/gl_context.h | 49 +++ mediapipe/gpu/gl_context_eagl.cc | 3 +- mediapipe/gpu/gl_context_egl.cc | 3 +- mediapipe/gpu/gl_context_nsgl.cc | 3 +- mediapipe/gpu/gl_context_webgl.cc | 3 +- mediapipe/gpu/gl_quad_renderer.cc | 2 +- mediapipe/gpu/gl_quad_renderer.h | 2 +- mediapipe/gpu/gpu_buffer_format.cc | 2 +- mediapipe/gpu/gpu_service.cc | 3 +- mediapipe/gpu/gpu_service.h | 11 +- mediapipe/gpu/gpu_shared_data_internal.cc | 2 +- mediapipe/gpu/graph_support.h | 6 +- .../image_frame_to_gpu_buffer_calculator.cc | 36 ++- .../com/google/mediapipe/components/BUILD | 14 +- .../components/CameraXPreviewHelper.java | 19 ++ .../components/ExternalTextureConverter.java | 57 +++- .../components/MicrophoneHelper.java | 36 ++- .../google/mediapipe/framework/jni/graph.cc | 12 +- .../framework/jni/register_natives.cc | 3 + mediapipe/modules/face_geometry/README.md | 16 +- .../modules/face_geometry/face_geometry.pbtxt | 6 +- .../face_geometry_from_detection.pbtxt | 8 +- .../face_geometry_from_landmarks.pbtxt | 6 +- .../face_landmark/face_landmark_cpu.pbtxt | 4 +- .../face_landmark/face_landmark_gpu.pbtxt | 4 +- .../palm_detection/palm_detection_cpu.pbtxt | 4 +- .../palm_detection/palm_detection_gpu.pbtxt | 4 +- .../selfie_segmentation_cpu.pbtxt | 4 +- .../selfie_segmentation_gpu.pbtxt | 4 +- mediapipe/objc/BUILD | 5 - mediapipe/python/pybind/image.cc | 14 +- mediapipe/python/pybind/packet_creator.cc | 17 +- mediapipe/python/solution_base.py | 7 + mediapipe/python/solutions/holistic.py | 8 +- mediapipe/python/solutions/objectron.py | 6 +- mediapipe/python/solutions/pose.py | 8 +- mediapipe/util/BUILD | 19 ++ mediapipe/util/label_map.proto | 40 +++ mediapipe/util/label_map_util.cc | 78 +++++ mediapipe/util/label_map_util.h | 34 ++ mediapipe/util/tflite/BUILD | 12 + mediapipe/util/tflite/error_reporter.cc | 49 +++ mediapipe/util/tflite/error_reporter.h | 52 +++ setup.py | 3 +- setup_android_sdk_and_ndk.sh | 2 +- third_party/opencv_macos.BUILD | 26 +- 136 files changed, 2572 insertions(+), 555 deletions(-) create mode 100644 mediapipe/calculators/core/get_vector_item_calculator.cc create mode 100644 mediapipe/calculators/core/get_vector_item_calculator.h create mode 100644 mediapipe/calculators/core/vector_size_calculator.cc create mode 100644 mediapipe/calculators/core/vector_size_calculator.h create mode 100644 mediapipe/calculators/util/detection_transformation_calculator.cc create mode 100644 mediapipe/calculators/util/detection_transformation_calculator_test.cc create mode 100644 mediapipe/util/label_map.proto create mode 100644 mediapipe/util/label_map_util.cc create mode 
100644 mediapipe/util/label_map_util.h
 create mode 100644 mediapipe/util/tflite/error_reporter.cc
 create mode 100644 mediapipe/util/tflite/error_reporter.h

diff --git a/.bazelrc b/.bazelrc
index 73e15b32b..724dd23fd 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -32,6 +32,9 @@ build:macos --copt=-w
 # Sets the default Apple platform to macOS.
 build --apple_platform_type=macos
 
+# Compile ObjC++ files with C++17
+build --per_file_copt=.*\.mm\$@-std=c++17
+
 # Allow debugging with XCODE
 build --apple_generate_dsym
 
@@ -88,6 +91,10 @@ build:darwin_x86_64 --apple_platform_type=macos
 build:darwin_x86_64 --macos_minimum_os=10.12
 build:darwin_x86_64 --cpu=darwin_x86_64
 
+build:darwin_arm64 --apple_platform_type=macos
+build:darwin_arm64 --macos_minimum_os=10.16
+build:darwin_arm64 --cpu=darwin_arm64
+
 # This bazelrc file is meant to be written by a setup script.
 try-import %workspace%/.configure.bazelrc

diff --git a/WORKSPACE b/WORKSPACE
index 7bdc114ff..e85e34d84 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -202,7 +202,10 @@ new_local_repository(
 new_local_repository(
     name = "macos_opencv",
     build_file = "@//third_party:opencv_macos.BUILD",
-    path = "/usr/local/opt/opencv@3",
+    # For local MacOS builds, the path should point to an opencv@3 installation.
+    # If you edit the path here, you will also need to update the corresponding
+    # prefix in "opencv_macos.BUILD".
+    path = "/usr/local",
 )
 
 new_local_repository(

diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md
index b3f6c5df4..ad2b42216 100644
--- a/docs/getting_started/android.md
+++ b/docs/getting_started/android.md
@@ -53,7 +53,7 @@ the following:
 
 ```bash
 $ echo "android_sdk_repository(name = \"androidsdk\")" >> WORKSPACE
-$ echo "android_ndk_repository(name = \"androidndk\")" >> WORKSPACE
+$ echo "android_ndk_repository(name = \"androidndk\", api_level=21)" >> WORKSPACE
 ```
 
 In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch

diff --git a/docs/getting_started/gpu_support.md b/docs/getting_started/gpu_support.md
index 38bab9be3..b4f4aa180 100644
--- a/docs/getting_started/gpu_support.md
+++ b/docs/getting_started/gpu_support.md
@@ -59,6 +59,21 @@ OpenGL ES profile shading language version string: OpenGL ES GLSL ES 3.20
 OpenGL ES profile extensions:
 ```
 
+If you have connected to your computer through SSH and see the following
+output when you probe for GPU information:
+
+```bash
+glxinfo | grep -i opengl
+Error: unable to open display
+```
+
+Try re-establishing your SSH connection with the `-X` option and try again. For
+example:
+
+```bash
+ssh -X <user>@<host>
+```
+
 *Notice the ES 3.20 text above.*
 
 You need to see ES 3.1 or greater printed in order to perform TFLite inference

diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md
index 4591b5f33..dd75d416a 100644
--- a/docs/getting_started/hello_world_ios.md
+++ b/docs/getting_started/hello_world_ios.md
@@ -131,7 +131,7 @@ Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build
 rules:
 
 ```
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 load(
     "@build_bazel_rules_apple//apple:ios.bzl",

diff --git a/docs/getting_started/ios.md b/docs/getting_started/ios.md
index cd11828af..6ed3192fa 100644
--- a/docs/getting_started/ios.md
+++ b/docs/getting_started/ios.md
@@ -32,9 +32,14 @@ example apps, start from
 
     xcode-select --install
 
-3.  Install [Bazel](https://bazel.build/).
+3.  Install [Bazelisk](https://github.com/bazelbuild/bazelisk).
-   We recommend using [Homebrew](https://brew.sh/) to get the latest version.
+   We recommend using [Homebrew](https://brew.sh/) to get the latest versions.
+
+   ```bash
+   brew install bazelisk
+   ```
 
 4.  Set Python 3.7 as the default Python version and install the Python "six"
     library. This is needed for TensorFlow.
 
@@ -187,6 +192,9 @@ Note: When you ask Xcode to run an app, by default it will use the Debug
 configuration. Some of our demos are computationally heavy; you may want to use
 the Release configuration for better performance.
 
+Note: Due to an incompatibility caused by one of our dependencies, MediaPipe
+cannot be used for apps running on the iPhone Simulator on Apple Silicon (M1).
+
 Tip: To switch build configuration in Xcode, click on the target menu, choose
 "Edit Scheme...", select the Run action, and switch the Build Configuration
 from Debug to Release. Note that this is set independently for each target.

diff --git a/docs/tools/tracing_and_profiling.md b/docs/tools/tracing_and_profiling.md
index 30b4bd993..2d712461f 100644
--- a/docs/tools/tracing_and_profiling.md
+++ b/docs/tools/tracing_and_profiling.md
@@ -258,13 +258,14 @@ Many of the following settings are advanced and not recommended for general
 usage. Consult
 [Enabling tracing and profiling](#enabling-tracing-and-profiling) for a
 friendlier introduction.
 
-histogram_interval_size_usec :Specifies the size of the runtimes histogram
-intervals (in microseconds) to generate the histogram of the Process() time. The
-last interval extends to +inf. If not specified, the interval is 1000000 usec =
-1 sec.
+histogram_interval_size_usec
+: Specifies the size of the runtimes histogram intervals (in microseconds) to
+  generate the histogram of the `Process()` time. The last interval extends to
+  +inf. If not specified, the interval is 1000000 usec = 1 sec.
 
-num_histogram_intervals :Specifies the number of intervals to generate the
-histogram of the `Process()` runtime. If not specified, one interval is used.
+num_histogram_intervals
+: Specifies the number of intervals to generate the histogram of the
+  `Process()` runtime. If not specified, one interval is used.
 
 enable_profiler
 : If true, the profiler starts profiling when graph is initialized.
 
@@ -288,7 +289,7 @@ trace_event_types_disabled
 
 trace_log_path
 : The output directory and base-name prefix for trace log files. Log files are
-  written to: StrCat(trace_log_path, index, "`.binarypb`")
+  written to: `StrCat(trace_log_path, index, ".binarypb")`
 
 trace_log_count
 : The number of trace log files retained. The trace log files are named
 
@@ -310,8 +311,8 @@ trace_log_instant_events
 
 trace_log_interval_count
 : The number of trace log intervals per file. The total log duration is:
-  `trace_log_interval_usec * trace_log_file_count * trace_log_interval_count`.
-  The default value specifies 10 intervals per file.
+  `trace_log_interval_usec * trace_log_count * trace_log_interval_count`. The
+  default value specifies 10 intervals per file.
 
 trace_log_disabled
 : An option to turn ON/OFF writing trace files to disk. Saving trace files to

diff --git a/mediapipe/BUILD b/mediapipe/BUILD
index 1171ea6f0..3187c0cf7 100644
--- a/mediapipe/BUILD
+++ b/mediapipe/BUILD
@@ -75,6 +75,7 @@ alias(
     actual = select({
         ":macos_i386": ":macos_i386",
         ":macos_x86_64": ":macos_x86_64",
+        ":macos_arm64": ":macos_arm64",
         "//conditions:default": ":macos_i386",  # Arbitrarily chosen from above.
    }),
     visibility = ["//visibility:public"],
 
@@ -119,6 +120,15 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "macos_arm64",
+    values = {
+        "apple_platform_type": "macos",
+        "cpu": "darwin_arm64",
+    },
+    visibility = ["//visibility:public"],
+)
+
 [
     config_setting(
         name = arch,

diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD
index 3cb0dd018..ff0a5d663 100644
--- a/mediapipe/calculators/core/BUILD
+++ b/mediapipe/calculators/core/BUILD
@@ -214,6 +214,7 @@ cc_library(
         "//mediapipe/framework:collection_item_id",
         "//mediapipe/framework:packet",
         "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:integral_types",
@@ -1257,3 +1258,36 @@ cc_test(
         "@com_google_absl//absl/time",
     ],
 )
+
+cc_library(
+    name = "get_vector_item_calculator",
+    srcs = ["get_vector_item_calculator.cc"],
+    hdrs = ["get_vector_item_calculator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:status",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "vector_size_calculator",
+    srcs = ["vector_size_calculator.cc"],
+    hdrs = ["vector_size_calculator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/port:status",
+    ],
+    alwayslink = 1,
+)

diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc
index e698e194c..5bf8e65fc 100644
--- a/mediapipe/calculators/core/begin_loop_calculator.cc
+++ b/mediapipe/calculators/core/begin_loop_calculator.cc
@@ -28,6 +28,10 @@ typedef BeginLoopCalculator<std::vector<NormalizedLandmarkList>>
     BeginLoopNormalizedLandmarkListVectorCalculator;
 REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator);
 
+// A calculator to process std::vector<int>.
+typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntCalculator;
+REGISTER_CALCULATOR(BeginLoopIntCalculator);
+
 // A calculator to process std::vector<NormalizedRect>.
typedef BeginLoopCalculator<std::vector<NormalizedRect>>
     BeginLoopNormalizedRectCalculator;

diff --git a/mediapipe/calculators/core/end_loop_calculator.cc b/mediapipe/calculators/core/end_loop_calculator.cc
index 2a366f992..d21bc03a4 100644
--- a/mediapipe/calculators/core/end_loop_calculator.cc
+++ b/mediapipe/calculators/core/end_loop_calculator.cc
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/util/render_data.pb.h"
@@ -50,4 +51,8 @@ REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
 typedef EndLoopCalculator<std::vector<Tensor>> EndLoopTensorCalculator;
 REGISTER_CALCULATOR(EndLoopTensorCalculator);
 
+typedef EndLoopCalculator<std::vector<Detection>>
+    EndLoopDetectionCalculator;
+REGISTER_CALCULATOR(EndLoopDetectionCalculator);
+
 }  // namespace mediapipe

diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc
new file mode 100644
index 000000000..56a2f3304
--- /dev/null
+++ b/mediapipe/calculators/core/get_vector_item_calculator.cc
@@ -0,0 +1,32 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/core/get_vector_item_calculator.h"
+
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+
+using GetLandmarkListVectorItemCalculator =
+    GetVectorItemCalculator<mediapipe::LandmarkList>;
+REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
+
+using GetClassificationListVectorItemCalculator =
+    GetVectorItemCalculator<mediapipe::ClassificationList>;
+REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe

diff --git a/mediapipe/calculators/core/get_vector_item_calculator.h b/mediapipe/calculators/core/get_vector_item_calculator.h
new file mode 100644
index 000000000..21009a30b
--- /dev/null
+++ b/mediapipe/calculators/core/get_vector_item_calculator.h
@@ -0,0 +1,77 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
+
+#include <vector>
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/port/status.h"
+
+namespace mediapipe {
+namespace api2 {
+
+// A calculator to return an item from the vector by its index.
+//
+// Inputs:
+//   VECTOR - std::vector<T>
+//     Vector to take an item from.
+//   INDEX - int
+//     Index of the item to return.
+//
+// Outputs:
+//   ITEM - T
+//     Item from the vector at given index.
+//
+// Example config:
+//   node {
+//     calculator: "Get{SpecificType}VectorItemCalculator"
+//     input_stream: "VECTOR:vector"
+//     input_stream: "INDEX:index"
+//     output_stream: "ITEM:item"
+//   }
+//
+template <typename T>
+class GetVectorItemCalculator : public Node {
+ public:
+  static constexpr Input<std::vector<T>> kIn{"VECTOR"};
+  static constexpr Input<int> kIdx{"INDEX"};
+  static constexpr Output<T> kOut{"ITEM"};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    if (kIn(cc).IsEmpty() || kIdx(cc).IsEmpty()) {
+      return absl::OkStatus();
+    }
+
+    const std::vector<T>& items = kIn(cc).Get();
+    const int idx = kIdx(cc).Get();
+
+    RET_CHECK_LT(idx, items.size());
+    kOut(cc).Send(items[idx]);
+
+    return absl::OkStatus();
+  }
+};
+
+}  // namespace api2
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_

diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc
index a80136be7..b76722de9 100644
--- a/mediapipe/calculators/core/split_vector_calculator.cc
+++ b/mediapipe/calculators/core/split_vector_calculator.cc
@@ -83,4 +83,7 @@ REGISTER_CALCULATOR(SplitClassificationListVectorCalculator);
 typedef SplitVectorCalculator<uint64, false> SplitUint64tVectorCalculator;
 REGISTER_CALCULATOR(SplitUint64tVectorCalculator);
 
+typedef SplitVectorCalculator<float, false> SplitFloatVectorCalculator;
+REGISTER_CALCULATOR(SplitFloatVectorCalculator);
+
 }  // namespace mediapipe

diff --git a/mediapipe/calculators/core/vector_size_calculator.cc b/mediapipe/calculators/core/vector_size_calculator.cc
new file mode 100644
index 000000000..bcbe22741
--- /dev/null
+++ b/mediapipe/calculators/core/vector_size_calculator.cc
@@ -0,0 +1,32 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/core/vector_size_calculator.h"
+
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+
+using LandmarkListVectorSizeCalculator =
+    VectorSizeCalculator<mediapipe::LandmarkList>;
+REGISTER_CALCULATOR(LandmarkListVectorSizeCalculator);
+
+using ClassificationListVectorSizeCalculator =
+    VectorSizeCalculator<mediapipe::ClassificationList>;
+REGISTER_CALCULATOR(ClassificationListVectorSizeCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe

diff --git a/mediapipe/calculators/core/vector_size_calculator.h b/mediapipe/calculators/core/vector_size_calculator.h
new file mode 100644
index 000000000..06aa422ff
--- /dev/null
+++ b/mediapipe/calculators/core/vector_size_calculator.h
@@ -0,0 +1,64 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
+
+#include <vector>
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/status.h"
+
+namespace mediapipe {
+namespace api2 {
+
+// A calculator to return vector size.
+//
+// Inputs:
+//   VECTOR - std::vector<T>
+//     Vector whose size to return.
+//
+// Outputs:
+//   SIZE - int
+//     Size of the input vector.
+//
+// Example config:
+//   node {
+//     calculator: "{SpecificType}VectorSizeCalculator"
+//     input_stream: "VECTOR:vector"
+//     output_stream: "SIZE:vector_size"
+//   }
+//
+template <typename T>
+class VectorSizeCalculator : public Node {
+ public:
+  static constexpr Input<std::vector<T>> kIn{"VECTOR"};
+  static constexpr Output<int> kOut{"SIZE"};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    if (kIn(cc).IsEmpty()) {
+      return absl::OkStatus();
+    }
+    kOut(cc).Send(kIn(cc).Get().size());
+    return absl::OkStatus();
+  }
+};
+
+}  // namespace api2
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_

diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc
index f6596b3fd..2870e0022 100644
--- a/mediapipe/calculators/image/scale_image_calculator.cc
+++ b/mediapipe/calculators/image/scale_image_calculator.cc
@@ -421,6 +421,10 @@ absl::Status ScaleImageCalculator::InitializeFromOptions() {
     alignment_boundary_ = options_.alignment_boundary();
   }
 
+  if (options_.has_output_format()) {
+    output_format_ = options_.output_format();
+  }
+
   downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient()));
 
   return absl::OkStatus();
@@ -433,13 +437,17 @@ absl::Status ScaleImageCalculator::ValidateImageFormats() const {
       << "The output image format was set to UNKNOWN.";
   // TODO Remove these conditions.
  RET_CHECK(output_format_ == ImageFormat::SRGB ||
+            output_format_ == ImageFormat::SRGBA ||
             (input_format_ == output_format_ &&
              output_format_ == ImageFormat::YCBCR420P))
       << "Outputting YCbCr420P images from SRGB input is not yet supported";
   RET_CHECK(input_format_ == output_format_ ||
-            input_format_ == ImageFormat::YCBCR420P)
+            (input_format_ == ImageFormat::YCBCR420P &&
+             output_format_ == ImageFormat::SRGB) ||
+            (input_format_ == ImageFormat::SRGB &&
+             output_format_ == ImageFormat::SRGBA))
       << "Conversion of the color space (except from "
-         "YCbCr420P to SRGB) is not yet supported.";
+         "YCbCr420P to SRGB or SRGB to SRGBA) is not yet supported.";
 
   return absl::OkStatus();
 }
@@ -604,6 +612,15 @@ absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) {
           .Add(output_image.release(), cc->InputTimestamp());
       return absl::OkStatus();
     }
+  } else if (input_format_ == ImageFormat::SRGB &&
+             output_format_ == ImageFormat::SRGBA) {
+    image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
+    cv::Mat input_mat = ::mediapipe::formats::MatView(image_frame);
+    converted_image_frame.Reset(ImageFormat::SRGBA, image_frame->Width(),
+                                image_frame->Height(), alignment_boundary_);
+    cv::Mat output_mat = ::mediapipe::formats::MatView(&converted_image_frame);
+    cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA, 4);
+    image_frame = &converted_image_frame;
   } else {
     image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
     MP_RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame));

diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index d41fa2c63..586fb0dd3 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -28,7 +28,9 @@ package(default_visibility = ["//visibility:private"])
 
 exports_files(
     glob(["testdata/image_to_tensor/*"]),
-    visibility = ["//mediapipe/calculators/image:__subpackages__"],
+    visibility = [
+        "//mediapipe/calculators/image:__subpackages__",
+    ],
 )
 
 selects.config_setting_group(
@@ -64,15 +66,16 @@ cc_library(
         ":inference_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:packet",
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
         "//mediapipe/framework/tool:subgraph_expansion",
-        "//mediapipe/util/tflite:config",
         "//mediapipe/util/tflite:tflite_model_loader",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@org_tensorflow//tensorflow/lite:framework",
+        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
     ],
     alwayslink = 1,
@@ -91,6 +94,7 @@ cc_library(
         "//mediapipe/util/tflite:tflite_gpu_runner",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
+        "@org_tensorflow//tensorflow/lite:framework_stable",
        "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
         "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
     ],
@@ -142,6 +146,8 @@ cc_library(
         ":inference_calculator_interface",
         "@com_google_absl//absl/memory",
         "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
+        "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite/c:c_api_types",
     ] + select({
         "//conditions:default": [
             "//mediapipe/util:cpu_util",

diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc
index 9900610e5..cd854beee 100644
---
a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -142,22 +142,35 @@ class ImageToTensorCalculator : public Node { cc->Options(); RET_CHECK(options.has_output_tensor_float_range() || - options.has_output_tensor_int_range()) + options.has_output_tensor_int_range() || + options.has_output_tensor_uint_range()) << "Output tensor range is required."; if (options.has_output_tensor_float_range()) { RET_CHECK_LT(options.output_tensor_float_range().min(), options.output_tensor_float_range().max()) << "Valid output float tensor range is required."; } + if (options.has_output_tensor_uint_range()) { + RET_CHECK_LT(options.output_tensor_uint_range().min(), + options.output_tensor_uint_range().max()) + << "Valid output uint tensor range is required."; + RET_CHECK_GE(options.output_tensor_uint_range().min(), 0) + << "The minimum of the output uint tensor range must be " + "non-negative."; + RET_CHECK_LE(options.output_tensor_uint_range().max(), 255) + << "The maximum of the output uint tensor range must be less than or " + "equal to 255."; + } if (options.has_output_tensor_int_range()) { RET_CHECK_LT(options.output_tensor_int_range().min(), options.output_tensor_int_range().max()) << "Valid output int tensor range is required."; - RET_CHECK_GE(options.output_tensor_int_range().min(), 0) - << "The minimum of the output int tensor range must be non-negative."; - RET_CHECK_LE(options.output_tensor_int_range().max(), 255) + RET_CHECK_GE(options.output_tensor_int_range().min(), -128) + << "The minimum of the output int tensor range must be greater than " + "or equal to -128."; + RET_CHECK_LE(options.output_tensor_int_range().max(), 127) << "The maximum of the output int tensor range must be less than or " - "equal to 255."; + "equal to 127."; } RET_CHECK_GT(options.output_tensor_width(), 0) << "Valid output tensor width is required."; @@ -187,15 +200,19 @@ class ImageToTensorCalculator : public Node { options_ = cc->Options(); output_width_ = options_.output_tensor_width(); output_height_ = options_.output_tensor_height(); - is_int_output_ = options_.has_output_tensor_int_range(); - range_min_ = - is_int_output_ - ? static_cast(options_.output_tensor_int_range().min()) - : options_.output_tensor_float_range().min(); - range_max_ = - is_int_output_ - ? 
static_cast(options_.output_tensor_int_range().max()) - : options_.output_tensor_float_range().max(); + is_float_output_ = options_.has_output_tensor_float_range(); + if (options_.has_output_tensor_uint_range()) { + range_min_ = + static_cast(options_.output_tensor_uint_range().min()); + range_max_ = + static_cast(options_.output_tensor_uint_range().max()); + } else if (options_.has_output_tensor_int_range()) { + range_min_ = static_cast(options_.output_tensor_int_range().min()); + range_max_ = static_cast(options_.output_tensor_int_range().max()); + } else { + range_min_ = options_.output_tensor_float_range().min(); + range_max_ = options_.output_tensor_float_range().max(); + } return absl::OkStatus(); } @@ -275,6 +292,17 @@ class ImageToTensorCalculator : public Node { } } + Tensor::ElementType GetOutputTensorType() { + if (is_float_output_) { + return Tensor::ElementType::kFloat32; + } + if (range_min_ < 0) { + return Tensor::ElementType::kInt8; + } else { + return Tensor::ElementType::kUInt8; + } + } + absl::StatusOr> GetInputImage( CalculatorContext* cc) { if (kIn(cc).IsConnected()) { @@ -305,7 +333,7 @@ class ImageToTensorCalculator : public Node { const Image& image) { // Lazy initialization of the GPU or CPU converter. if (image.UsesGpu()) { - if (is_int_output_) { + if (!is_float_output_) { return absl::UnimplementedError( "ImageToTensorConverter for the input GPU image currently doesn't " "support quantization."); @@ -337,11 +365,9 @@ class ImageToTensorCalculator : public Node { } else { if (!cpu_converter_) { #if !MEDIAPIPE_DISABLE_OPENCV - ASSIGN_OR_RETURN(cpu_converter_, - CreateOpenCvConverter( - cc, GetBorderMode(), - is_int_output_ ? Tensor::ElementType::kUInt8 - : Tensor::ElementType::kFloat32)); + ASSIGN_OR_RETURN( + cpu_converter_, + CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType())); #else LOG(FATAL) << "Cannot create image to tensor opencv converter since " "MEDIAPIPE_DISABLE_OPENCV is defined."; @@ -356,7 +382,7 @@ class ImageToTensorCalculator : public Node { mediapipe::ImageToTensorCalculatorOptions options_; int output_width_ = 0; int output_height_ = 0; - bool is_int_output_ = false; + bool is_float_output_ = false; float range_min_ = 0.0f; float range_max_ = 1.0f; }; diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto index bf8ba160d..780ee8021 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto @@ -39,6 +39,14 @@ message ImageToTensorCalculatorOptions { optional int64 max = 2; } + // Range of uint values [min, max]. + // min, must be strictly less than max. + // Please note that UIntRange is supported for CPU tensors only. + message UIntRange { + optional uint64 min = 1; + optional uint64 max = 2; + } + // Pixel extrapolation methods. See @border_mode. 
enum BorderMode { BORDER_UNSPECIFIED = 0; @@ -58,6 +66,7 @@ message ImageToTensorCalculatorOptions { oneof range { FloatRange output_tensor_float_range = 4; IntRange output_tensor_int_range = 7; + UIntRange output_tensor_uint_range = 8; } // For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 4e35e3be6..07a5f9fe1 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -76,12 +76,21 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, } std::string output_tensor_range; if (output_int_tensor) { - output_tensor_range = absl::Substitute(R"(output_tensor_int_range { + if (range_min < 0) { + output_tensor_range = absl::Substitute(R"(output_tensor_int_range { min: $0 max: $1 })", - static_cast(range_min), - static_cast(range_max)); + static_cast(range_min), + static_cast(range_max)); + } else { + output_tensor_range = absl::Substitute(R"(output_tensor_uint_range { + min: $0 + max: $1 + })", + static_cast(range_min), + static_cast(range_max)); + } } else { output_tensor_range = absl::Substitute(R"(output_tensor_float_range { min: $0 @@ -141,9 +150,15 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, auto view = tensor.GetCpuReadView(); cv::Mat tensor_mat; if (output_int_tensor) { - EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, - const_cast(view.buffer())); + if (range_min < 0) { + EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8); + tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3, + const_cast(view.buffer())); + } else { + EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); + tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, + const_cast(view.buffer())); + } } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3, @@ -190,25 +205,28 @@ const std::vector kInputTypesToTest = {InputType::kImageFrame, InputType::kImage}; void RunTest(cv::Mat input, cv::Mat expected_result, - std::vector float_range, std::vector int_range, - int tensor_width, int tensor_height, bool keep_aspect, + std::vector> float_ranges, + std::vector> int_ranges, int tensor_width, + int tensor_height, bool keep_aspect, absl::optional border_mode, const mediapipe::NormalizedRect& roi) { - ASSERT_EQ(2, float_range.size()); - ASSERT_EQ(2, int_range.size()); for (auto input_type : kInputTypesToTest) { - RunTestWithInputImagePacket( - input_type == InputType::kImageFrame ? MakeImageFramePacket(input) - : MakeImagePacket(input), - expected_result, float_range[0], float_range[1], tensor_width, - tensor_height, keep_aspect, border_mode, roi, - /*output_int_tensor=*/false); - RunTestWithInputImagePacket( - input_type == InputType::kImageFrame ? MakeImageFramePacket(input) - : MakeImagePacket(input), - expected_result, int_range[0], int_range[1], tensor_width, - tensor_height, keep_aspect, border_mode, roi, - /*output_int_tensor=*/true); + for (auto float_range : float_ranges) { + RunTestWithInputImagePacket( + input_type == InputType::kImageFrame ? 
MakeImageFramePacket(input) + : MakeImagePacket(input), + expected_result, float_range.first, float_range.second, tensor_width, + tensor_height, keep_aspect, border_mode, roi, + /*output_int_tensor=*/false); + } + for (auto int_range : int_ranges) { + RunTestWithInputImagePacket( + input_type == InputType::kImageFrame ? MakeImageFramePacket(input) + : MakeImagePacket(input), + expected_result, int_range.first, int_range.second, tensor_width, + tensor_height, keep_aspect, border_mode, roi, + /*output_int_tensor=*/true); + } } } @@ -224,8 +242,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, /*border mode*/ {}, roi); } @@ -242,8 +260,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_keep_aspect_border_zero.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, BorderMode::kZero, roi); } @@ -260,8 +278,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_keep_aspect_with_rotation.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, BorderMode::kReplicate, roi); } @@ -279,8 +297,8 @@ TEST(ImageToTensorCalculatorTest, GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, BorderMode::kZero, roi); } @@ -298,8 +316,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { GetRgb( "/mediapipe/calculators/" "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*float_range=*/{-1.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, BorderMode::kReplicate, roi); } @@ -316,8 +334,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_with_rotation_border_zero.png"), - /*float_range=*/{-1.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, BorderMode::kZero, roi); } @@ -333,8 +351,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/large_sub_rect.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, 
BorderMode::kReplicate, roi); } @@ -351,8 +369,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, BorderMode::kZero, roi); } @@ -369,8 +387,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kReplicate, roi); } @@ -387,8 +405,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "large_sub_rect_keep_aspect_border_zero.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kZero, roi); } @@ -405,8 +423,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "large_sub_rect_keep_aspect_with_rotation.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*border_mode=*/{}, roi); } @@ -424,8 +442,8 @@ TEST(ImageToTensorCalculatorTest, GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*border_mode=*/BorderMode::kZero, roi); } @@ -441,8 +459,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/noop_except_range.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kReplicate, roi); } @@ -458,8 +476,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/noop_except_range.png"), - /*float_range=*/{0.0f, 1.0f}, - /*int_range=*/{0, 255}, + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kZero, roi); } diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index d01916f3e..ddc7ff85e 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -268,10 +268,12 @@ class GlProcessor : public ImageToTensorConverter { const RotatedRect& 
roi, const Size& output_dims, float range_min, float range_max) override { - if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { - return InvalidArgumentError( - absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", - static_cast(input.format()))); + if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && + input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + return InvalidArgumentError(absl::StrCat( + "Only 4-channel texture input formats are supported, passed format: ", + static_cast(input.format()))); } constexpr int kNumChannels = 3; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index e81621b76..6f035e67b 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -172,10 +172,12 @@ class GlProcessor : public ImageToTensorConverter { const RotatedRect& roi, const Size& output_dims, float range_min, float range_max) override { - if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { - return InvalidArgumentError( - absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", - static_cast(input.format()))); + if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && + input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + return InvalidArgumentError(absl::StrCat( + "Only 4-channel texture input formats are supported, passed format: ", + static_cast(input.format()))); } constexpr int kNumChannels = 3; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index 9714faa51..cfabae333 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -352,11 +352,12 @@ class MetalProcessor : public ImageToTensorConverter { const RotatedRect& roi, const Size& output_dims, float range_min, float range_max) override { - if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { - return InvalidArgumentError( - absl::StrCat("Only BGRA/RGBA textures are supported, passed " - "format: ", - static_cast(input.format()))); + if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && + input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + return InvalidArgumentError(absl::StrCat( + "Only 4-channel texture input formats are supported, passed format: ", + static_cast(input.format()))); } @autoreleasepool { diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 45e027439..6d36e5878 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -45,7 +45,19 @@ class OpenCvProcessor : public ImageToTensorConverter { border_mode_ = cv::BORDER_CONSTANT; break; } - mat_type_ = tensor_type == Tensor::ElementType::kUInt8 ? 
CV_8UC3 : CV_32FC3; + switch (tensor_type_) { + case Tensor::ElementType::kInt8: + mat_type_ = CV_8SC3; + break; + case Tensor::ElementType::kFloat32: + mat_type_ = CV_32FC3; + break; + case Tensor::ElementType::kUInt8: + mat_type_ = CV_8UC3; + break; + default: + mat_type_ = -1; + } } absl::StatusOr Convert(const mediapipe::Image& input, @@ -65,12 +77,22 @@ class OpenCvProcessor : public ImageToTensorConverter { output_dims.width, kNumChannels}); auto buffer_view = tensor.GetCpuWriteView(); cv::Mat dst; - if (tensor_type_ == Tensor::ElementType::kUInt8) { - dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, - buffer_view.buffer()); - } else { - dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, - buffer_view.buffer()); + switch (tensor_type_) { + case Tensor::ElementType::kInt8: + dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + buffer_view.buffer()); + break; + case Tensor::ElementType::kFloat32: + dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + buffer_view.buffer()); + break; + case Tensor::ElementType::kUInt8: + dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + buffer_view.buffer()); + break; + default: + return InvalidArgumentError( + absl::StrCat("Unsupported tensor type: ", tensor_type_)); } const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y), @@ -124,6 +146,13 @@ class OpenCvProcessor : public ImageToTensorConverter { absl::StatusOr> CreateOpenCvConverter( CalculatorContext* cc, BorderMode border_mode, Tensor::ElementType tensor_type) { + if (tensor_type != Tensor::ElementType::kInt8 && + tensor_type != Tensor::ElementType::kFloat32 && + tensor_type != Tensor::ElementType::kUInt8) { + return absl::InvalidArgumentError(absl::StrCat( + "Tensor type is currently not supported by OpenCvProcessor, type: ", + tensor_type)); + } return absl::make_unique(border_mode, tensor_type); } diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 0311612ff..c143c9901 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -21,7 +21,9 @@ #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/tool/subgraph_expansion.h" +#include "tensorflow/lite/core/api/op_resolver.h" namespace mediapipe { namespace api2 { @@ -67,5 +69,17 @@ absl::StatusOr> InferenceCalculator::GetModelAsPacket( "Must specify TFLite model as path or loaded model."); } +absl::StatusOr> +InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) { + if (kSideInOpResolver(cc).IsConnected()) { + return kSideInOpResolver(cc).As(); + } else if (kSideInCustomOpResolver(cc).IsConnected()) { + return kSideInCustomOpResolver(cc).As(); + } + return PacketAdopting( + std::make_unique< + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>()); +} + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index 1c54bc46e..b5f3a0a15 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -27,6 +27,7 @@ #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/tflite/tflite_model_loader.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include 
"tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -55,8 +56,11 @@ namespace api2 { // TENSORS - Vector of Tensors // // Input side packet: +// DEPRECATED: Prefer to use the "OP_RESOLVER" input side packet instead. // CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, // instead of the builtin one. +// OP_RESOLVER (optional) - Use to provide tflite op resolver +// (tflite::OpResolver) // MODEL (optional) - Use to specify TfLite model // (std::unique_ptr>) @@ -95,15 +99,21 @@ namespace api2 { class InferenceCalculator : public NodeIntf { public: static constexpr Input> kInTensors{"TENSORS"}; + // Deprecated. Prefers to use "OP_RESOLVER" input side packet instead. + // TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the + // migration. static constexpr SideInput::Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; + static constexpr SideInput::Optional kSideInOpResolver{ + "OP_RESOLVER"}; static constexpr SideInput::Optional kSideInModel{"MODEL"}; static constexpr Output> kOutTensors{"TENSORS"}; static constexpr SideInput< mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{ "DELEGATE"}; - MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, - kOutTensors, kDelegate); + MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, + kSideInOpResolver, kSideInModel, kOutTensors, + kDelegate); protected: using TfLiteDelegatePtr = @@ -111,6 +121,9 @@ class InferenceCalculator : public NodeIntf { absl::StatusOr> GetModelAsPacket( CalculatorContext* cc); + + absl::StatusOr> GetOpResolverAsPacket( + CalculatorContext* cc); }; struct InferenceCalculatorSelector : public InferenceCalculator { diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 04f8d141d..46552803b 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -116,6 +116,9 @@ message InferenceCalculatorOptions { // to ensure there is no clash of the tokens. If unspecified, NNAPI will // not try caching the compilation. optional string model_token = 2; + // The name of an accelerator to be used for NNAPI delegate, e.g. + // "google-edgetpu". When not specified, it will be selected by NNAPI. + optional string accelerator_name = 3; } message Xnnpack { // Number of threads for XNNPACK delegate. 
(By default, calculator tries diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index d4f2224c5..8ac8ce31f 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -19,7 +19,7 @@ #include "absl/memory/memory.h" #include "mediapipe/calculators/tensor/inference_calculator.h" - +#include "tensorflow/lite/interpreter_builder.h" #if defined(MEDIAPIPE_ANDROID) #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #endif // ANDROID @@ -28,6 +28,7 @@ #include "mediapipe/util/cpu_util.h" #endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" namespace mediapipe { @@ -61,6 +62,17 @@ int GetXnnpackNumThreads( return GetXnnpackDefaultNumThreads(); } +template +void CopyTensorBuffer(const Tensor& input_tensor, + tflite::Interpreter* interpreter, + int input_tensor_index) { + auto input_tensor_view = input_tensor.GetCpuReadView(); + auto input_tensor_buffer = input_tensor_view.buffer(); + T* local_tensor_buffer = + interpreter->typed_input_tensor(input_tensor_index); + std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes()); +} + } // namespace class InferenceCalculatorCpuImpl @@ -73,15 +85,16 @@ class InferenceCalculatorCpuImpl absl::Status Close(CalculatorContext* cc) override; private: - absl::Status LoadModel(CalculatorContext* cc); - absl::Status LoadDelegate(CalculatorContext* cc); - absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); + absl::Status InitInterpreter(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc, + tflite::InterpreterBuilder* interpreter_builder); + absl::Status AllocateTensors(); // TfLite requires us to keep the model alive as long as the interpreter is. Packet model_packet_; std::unique_ptr interpreter_; TfLiteDelegatePtr delegate_; - bool has_quantized_input_; + TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType; }; absl::Status InferenceCalculatorCpuImpl::UpdateContract( @@ -94,8 +107,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract( } absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { - MP_RETURN_IF_ERROR(LoadModel(cc)); - return LoadDelegateAndAllocateTensors(cc); + return InitInterpreter(cc); } absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { @@ -108,19 +120,23 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { // Read CPU input into tensors. for (int i = 0; i < input_tensors.size(); ++i) { - const Tensor* input_tensor = &input_tensors[i]; - auto input_tensor_view = input_tensor->GetCpuReadView(); - if (has_quantized_input_) { - // TODO: Support more quantized tensor types. 
- auto input_tensor_buffer = input_tensor_view.buffer(); - uint8* local_tensor_buffer = interpreter_->typed_input_tensor(i); - std::memcpy(local_tensor_buffer, input_tensor_buffer, - input_tensor->bytes()); - } else { - auto input_tensor_buffer = input_tensor_view.buffer(); - float* local_tensor_buffer = interpreter_->typed_input_tensor(i); - std::memcpy(local_tensor_buffer, input_tensor_buffer, - input_tensor->bytes()); + switch (input_tensor_type_) { + case TfLiteType::kTfLiteFloat16: + case TfLiteType::kTfLiteFloat32: { + CopyTensorBuffer(input_tensors[i], interpreter_.get(), i); + break; + } + case TfLiteType::kTfLiteUInt8: { + CopyTensorBuffer(input_tensors[i], interpreter_.get(), i); + break; + } + case TfLiteType::kTfLiteInt8: { + CopyTensorBuffer(input_tensors[i], interpreter_.get(), i); + break; + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported input tensor type:", input_tensor_type_)); } } @@ -150,39 +166,34 @@ absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) { return absl::OkStatus(); } -absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { +absl::Status InferenceCalculatorCpuImpl::InitInterpreter( + CalculatorContext* cc) { ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); - - tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); - RET_CHECK(interpreter_); - + ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); + const auto& op_resolver = op_resolver_packet.Get(); + tflite::InterpreterBuilder interpreter_builder(model, op_resolver); + MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder)); #if defined(__EMSCRIPTEN__) - interpreter_->SetNumThreads(1); + interpreter_builder.SetNumThreads(1); #else - interpreter_->SetNumThreads( + interpreter_builder.SetNumThreads( cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ - return absl::OkStatus(); + RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk); + RET_CHECK(interpreter_); + return AllocateTensors(); } -absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors( - CalculatorContext* cc) { - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - - // AllocateTensors() can be called only after ModifyGraphWithDelegate. +absl::Status InferenceCalculatorCpuImpl::AllocateTensors() { RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - has_quantized_input_ = - interpreter_->tensor(interpreter_->inputs()[0])->quantization.type == - kTfLiteAffineQuantization; + input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type; return absl::OkStatus(); } -absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { +absl::Status InferenceCalculatorCpuImpl::LoadDelegate( + CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) { const auto& calculator_opts = cc->Options(); auto opts_delegate = calculator_opts.delegate(); @@ -211,18 +222,20 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { if (nnapi_requested) { // Attempt to use NNAPI. // If not supported, the default CPU delegate will be created and used. - interpreter_->SetAllowFp16PrecisionForFp32(1); tflite::StatefulNnApiDelegate::Options options; const auto& nnapi = opts_delegate.nnapi(); + options.allow_fp16 = true; // Set up cache_dir and model_token for NNAPI compilation cache. 
options.cache_dir = nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr; options.model_token = nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr; + options.accelerator_name = nnapi.has_accelerator_name() + ? nnapi.accelerator_name().c_str() + : nullptr; delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), [](TfLiteDelegate*) {}); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); + interpreter_builder->AddDelegate(delegate_.get()); return absl::OkStatus(); } #endif // MEDIAPIPE_ANDROID @@ -239,8 +252,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { GetXnnpackNumThreads(opts_has_delegate, opts_delegate); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), &TfLiteXNNPackDelegateDelete); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); + interpreter_builder->AddDelegate(delegate_.get()); } return absl::OkStatus(); diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index 8b998d665..dfdf7382c 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -22,6 +22,7 @@ #include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/util/tflite/config.h" +#include "tensorflow/lite/interpreter_builder.h" #if MEDIAPIPE_TFLITE_GL_INFERENCE #include "mediapipe/gpu/gl_calculator_helper.h" @@ -52,9 +53,11 @@ class InferenceCalculatorGlImpl private: absl::Status ReadGpuCaches(); absl::Status SaveGpuCaches(); - absl::Status LoadModel(CalculatorContext* cc); - absl::Status LoadDelegate(CalculatorContext* cc); - absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); + absl::Status InitInterpreter(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc, + tflite::InterpreterBuilder* interpreter_builder); + absl::Status BindBuffersToTensors(); + absl::Status AllocateTensors(); absl::Status InitTFLiteGPURunner(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. @@ -137,17 +140,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { #endif // MEDIAPIPE_ANDROID } - // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner - // for everything. - if (!use_advanced_gpu_api_) { - MP_RETURN_IF_ERROR(LoadModel(cc)); - } - MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { return use_advanced_gpu_api_ ? 
InitTFLiteGPURunner(cc) - : LoadDelegateAndAllocateTensors(cc); + : InitInterpreter(cc); })); return absl::OkStatus(); } @@ -292,12 +289,6 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() { absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( CalculatorContext* cc) { - ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); - const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); - // Create runner tflite::gpu::InferenceOptions options; options.priority1 = allow_precision_loss_ @@ -335,6 +326,10 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( break; } } + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); + const auto& op_resolver = op_resolver_packet.Get(); MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( model, op_resolver, /*allow_quant_ops=*/true)); @@ -355,31 +350,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { +absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) { ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); - - tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); - RET_CHECK(interpreter_); - + ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); + const auto& op_resolver = op_resolver_packet.Get(); + tflite::InterpreterBuilder interpreter_builder(model, op_resolver); + MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder)); #if defined(__EMSCRIPTEN__) - interpreter_->SetNumThreads(1); + interpreter_builder.SetNumThreads(1); #else - interpreter_->SetNumThreads( + interpreter_builder.SetNumThreads( cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ - + RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk); + RET_CHECK(interpreter_); + MP_RETURN_IF_ERROR(BindBuffersToTensors()); + MP_RETURN_IF_ERROR(AllocateTensors()); return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( - CalculatorContext* cc) { - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - - // AllocateTensors() can be called only after ModifyGraphWithDelegate. +absl::Status InferenceCalculatorGlImpl::AllocateTensors() { RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. RET_CHECK_NE( @@ -388,7 +379,8 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( return absl::OkStatus(); } -absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { +absl::Status InferenceCalculatorGlImpl::LoadDelegate( + CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) { // Configure and create the delegate. 
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); options.compile_options.precision_loss_allowed = @@ -399,7 +391,11 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { options.compile_options.inline_parameters = 1; delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options), &TfLiteGpuDelegateDelete); + interpreter_builder->AddDelegate(delegate_.get()); + return absl::OkStatus(); +} +absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() { // Get input image sizes. const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { @@ -431,11 +427,6 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { output_indices[i]), kTfLiteOk); } - - // Must call this last. - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); - return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index 49e042290..ae0a5e38d 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -90,9 +90,10 @@ class InferenceCalculatorMetalImpl absl::Status Close(CalculatorContext* cc) override; private: - absl::Status LoadModel(CalculatorContext* cc); - absl::Status LoadDelegate(CalculatorContext* cc); - absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); + absl::Status InitInterpreter(CalculatorContext* cc); + void AddDelegate(CalculatorContext* cc, + tflite::InterpreterBuilder* interpreter_builder); + absl::Status CreateConverters(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. Packet model_packet_; @@ -127,11 +128,9 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); - MP_RETURN_IF_ERROR(LoadModel(cc)); - gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); - return LoadDelegateAndAllocateTensors(cc); + return InitInterpreter(cc); } absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { @@ -199,27 +198,20 @@ absl::Status InferenceCalculatorMetalImpl::Close(CalculatorContext* cc) { return absl::OkStatus(); } -absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { +absl::Status InferenceCalculatorMetalImpl::InitInterpreter( + CalculatorContext* cc) { ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); - - tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); + const auto& op_resolver = op_resolver_packet.Get(); + tflite::InterpreterBuilder interpreter_builder(model, op_resolver); + AddDelegate(cc, &interpreter_builder); + interpreter_builder.SetNumThreads( + cc->Options().cpu_num_thread()); + RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk); RET_CHECK(interpreter_); - interpreter_->SetNumThreads( - cc->Options().cpu_num_thread()); - - return absl::OkStatus(); -} - -absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors( - CalculatorContext* cc) { - 
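Taken together, the CPU, GL, and Metal hunks in this patch converge on one pattern: delegates are registered on the tflite::InterpreterBuilder before the interpreter is built, rather than applied afterwards via ModifyGraphWithDelegate(). A minimal sketch of the new order, with model, op_resolver, delegate, and num_threads as stand-in names:

tflite::InterpreterBuilder builder(model, op_resolver);
builder.AddDelegate(delegate.get());  // registered before building
builder.SetNumThreads(num_threads);
std::unique_ptr<tflite::Interpreter> interpreter;
RET_CHECK_EQ(builder(&interpreter), kTfLiteOk);  // builds with the delegate applied
// AllocateTensors() now runs after the delegate is already in place.
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);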
MP_RETURN_IF_ERROR(LoadDelegate(cc)); - - // AllocateTensors() can be called only after ModifyGraphWithDelegate. + MP_RETURN_IF_ERROR(CreateConverters(cc)); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. RET_CHECK_NE( @@ -228,7 +220,8 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors( return absl::OkStatus(); } -absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { +void InferenceCalculatorMetalImpl::AddDelegate( + CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) { const auto& calculator_opts = cc->Options<mediapipe::InferenceCalculatorOptions>(); @@ -242,9 +235,11 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); + interpreter_builder->AddDelegate(delegate_.get()); +} +absl::Status InferenceCalculatorMetalImpl::CreateConverters( + CalculatorContext* cc) { id<MTLDevice> device = gpu_helper_.mtlDevice; // Get input image sizes. diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index c4e941f12..b1babaffb 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -91,6 +91,40 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors, } } +absl::Status CheckCustomTensorMapping( + const TensorsToDetectionsCalculatorOptions::TensorMapping& tensor_mapping) { + RET_CHECK(tensor_mapping.has_detections_tensor_index() && + tensor_mapping.has_scores_tensor_index()); + int bitmap = 0; + bitmap |= 1 << tensor_mapping.detections_tensor_index(); + bitmap |= 1 << tensor_mapping.scores_tensor_index(); + if (!tensor_mapping.has_num_detections_tensor_index() && + !tensor_mapping.has_classes_tensor_index() && + !tensor_mapping.has_anchors_tensor_index()) { + // Only allow output tensor indices 0 and 1 to be occupied. + RET_CHECK_EQ(3, bitmap) << "The custom output tensor indices should only " + "cover index 0 and 1."; + } else if (tensor_mapping.has_anchors_tensor_index()) { + RET_CHECK(!tensor_mapping.has_classes_tensor_index() && + !tensor_mapping.has_num_detections_tensor_index()); + bitmap |= 1 << tensor_mapping.anchors_tensor_index(); + // If the "anchors" tensor will be available, only allow output tensor + // indices 0, 1, and 2 to be occupied. + RET_CHECK_EQ(7, bitmap) << "The custom output tensor indices should only " + "cover index 0, 1 and 2."; + } else { + RET_CHECK(tensor_mapping.has_classes_tensor_index() && + tensor_mapping.has_num_detections_tensor_index()); + // If the "classes" and the "number of detections" tensors will be + // available, only allow output tensor indices 0, 1, 2, and 3 to be occupied.
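// Worked example of the bitmap check (index values assumed for
// illustration): with detections_tensor_index=0, scores_tensor_index=1,
// classes_tensor_index=2, and num_detections_tensor_index=3, the bitmap
// accumulates (1 << 0) | (1 << 1) | (1 << 2) | (1 << 3) == 15, so the
// RET_CHECK below passes; any duplicate or out-of-range index produces a
// different value and fails it.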
+ bitmap |= 1 << tensor_mapping.classes_tensor_index(); + bitmap |= 1 << tensor_mapping.num_detections_tensor_index(); + RET_CHECK_EQ(15, bitmap) << "The custom output tensor indices should only " + "cover index 0, 1, 2 and 3."; + } + return absl::OkStatus(); +} + } // namespace // Convert result Tensors from object detection models into MediaPipe @@ -170,13 +204,27 @@ class TensorsToDetectionsCalculator : public Node { Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, int class_id, bool flip_vertically); + bool IsClassIndexAllowed(int class_index); int num_classes_ = 0; int num_boxes_ = 0; int num_coords_ = 0; - std::set ignore_classes_; + int max_results_ = -1; - ::mediapipe::TensorsToDetectionsCalculatorOptions options_; + // Set of allowed or ignored class indices. + struct ClassIndexSet { + absl::flat_hash_set values; + bool is_allowlist; + }; + // Allowed or ignored class indices based on provided options or side packet. + // These are used to filter out the output detection results. + ClassIndexSet class_index_set_; + + TensorsToDetectionsCalculatorOptions options_; + bool scores_tensor_index_is_set_ = false; + TensorsToDetectionsCalculatorOptions::TensorMapping tensor_mapping_; + std::vector box_indices_ = {0, 1, 2, 3}; + bool has_custom_box_indices_ = false; std::vector anchors_; #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE @@ -239,6 +287,21 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) { } } } + const int num_input_tensors = kInTensors(cc)->size(); + if (!scores_tensor_index_is_set_) { + if (num_input_tensors == 2 || + num_input_tensors == kNumInputTensorsWithAnchors) { + tensor_mapping_.set_scores_tensor_index(1); + } else { + tensor_mapping_.set_scores_tensor_index(2); + } + scores_tensor_index_is_set_ = true; + } + if (gpu_processing || num_input_tensors != 4) { + // Allows custom bounding box indices when receiving 4 cpu tensors. + // Uses the default bbox indices in other cases. + RET_CHECK(!has_custom_box_indices_); + } if (gpu_processing) { if (!gpu_inited_) { @@ -263,13 +326,15 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( // Postprocessing on CPU for model without postprocessing op. E.g. output // raw score tensor and box tensor. Anchor decoding will be handled below. // TODO: Add flexible input tensor size handling. - auto raw_box_tensor = &input_tensors[0]; + auto raw_box_tensor = + &input_tensors[tensor_mapping_.detections_tensor_index()]; RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3); RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options"; RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); - auto raw_score_tensor = &input_tensors[1]; + auto raw_score_tensor = + &input_tensors[tensor_mapping_.scores_tensor_index()]; RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3); RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_); @@ -282,7 +347,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( // TODO: Support other options to load anchors. 
if (!anchors_init_) { if (input_tensors.size() == kNumInputTensorsWithAnchors) { - auto anchor_tensor = &input_tensors[2]; + auto anchor_tensor = + &input_tensors[tensor_mapping_.anchors_tensor_index()]; RET_CHECK_EQ(anchor_tensor->shape().dims.size(), 2); RET_CHECK_EQ(anchor_tensor->shape().dims[0], num_boxes_); RET_CHECK_EQ(anchor_tensor->shape().dims[1], kNumCoordsPerBox); @@ -308,7 +374,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( float max_score = -std::numeric_limits::max(); // Find the top score for box i. for (int score_idx = 0; score_idx < num_classes_; ++score_idx) { - if (ignore_classes_.find(score_idx) == ignore_classes_.end()) { + if (IsClassIndexAllowed(score_idx)) { auto score = raw_scores[i * num_classes_ + score_idx]; if (options_.sigmoid_score()) { if (options_.has_score_clipping_thresh()) { @@ -338,23 +404,26 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( // Postprocessing on CPU with postprocessing op (e.g. anchor decoding and // non-maximum suppression) within the model. RET_CHECK_EQ(input_tensors.size(), 4); - - auto num_boxes_tensor = &input_tensors[3]; + auto num_boxes_tensor = + &input_tensors[tensor_mapping_.num_detections_tensor_index()]; RET_CHECK_EQ(num_boxes_tensor->shape().dims.size(), 1); RET_CHECK_EQ(num_boxes_tensor->shape().dims[0], 1); - auto detection_boxes_tensor = &input_tensors[0]; + auto detection_boxes_tensor = + &input_tensors[tensor_mapping_.detections_tensor_index()]; RET_CHECK_EQ(detection_boxes_tensor->shape().dims.size(), 3); RET_CHECK_EQ(detection_boxes_tensor->shape().dims[0], 1); const int max_detections = detection_boxes_tensor->shape().dims[1]; RET_CHECK_EQ(detection_boxes_tensor->shape().dims[2], num_coords_); - auto detection_classes_tensor = &input_tensors[1]; + auto detection_classes_tensor = + &input_tensors[tensor_mapping_.classes_tensor_index()]; RET_CHECK_EQ(detection_classes_tensor->shape().dims.size(), 2); RET_CHECK_EQ(detection_classes_tensor->shape().dims[0], 1); RET_CHECK_EQ(detection_classes_tensor->shape().dims[1], max_detections); - auto detection_scores_tensor = &input_tensors[2]; + auto detection_scores_tensor = + &input_tensors[tensor_mapping_.scores_tensor_index()]; RET_CHECK_EQ(detection_scores_tensor->shape().dims.size(), 2); RET_CHECK_EQ(detection_scores_tensor->shape().dims[0], 1); RET_CHECK_EQ(detection_scores_tensor->shape().dims[1], max_detections); @@ -394,12 +463,14 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( -> absl::Status { if (!anchors_init_) { if (input_tensors.size() == kNumInputTensorsWithAnchors) { - auto read_view = input_tensors[2].GetOpenGlBufferReadView(); + auto read_view = input_tensors[tensor_mapping_.anchors_tensor_index()] + .GetOpenGlBufferReadView(); glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView(); glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); - glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, - input_tensors[2].bytes()); + glCopyBufferSubData( + GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + input_tensors[tensor_mapping_.anchors_tensor_index()].bytes()); } else if (!kInAnchors(cc).IsEmpty()) { const auto& anchors = *kInAnchors(cc); auto anchors_view = raw_anchors_buffer_->GetCpuWriteView(); @@ -418,7 +489,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( auto decoded_boxes_view = decoded_boxes_buffer_->GetOpenGlBufferWriteView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, decoded_boxes_view.name()); - auto input0_view = 
input_tensors[0].GetOpenGlBufferReadView(); + auto input0_view = + input_tensors[tensor_mapping_.detections_tensor_index()] + .GetOpenGlBufferReadView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input0_view.name()); auto raw_anchors_view = raw_anchors_buffer_->GetOpenGlBufferReadView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, raw_anchors_view.name()); @@ -427,7 +500,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( // Score boxes. glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, scored_boxes_view.name()); - auto input1_view = input_tensors[1].GetOpenGlBufferReadView(); + auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()] + .GetOpenGlBufferReadView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input1_view.name()); glUseProgram(score_program_); glDispatchCompute(num_boxes_, 1, 1); @@ -459,7 +533,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( if (input_tensors.size() == kNumInputTensorsWithAnchors) { RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); auto command_buffer = [gpu_helper_ commandBuffer]; - auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer); + auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()] + .GetMtlBufferReadView(command_buffer); auto dest_buffer = raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer); id blit_command = @@ -468,7 +543,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( sourceOffset:0 toBuffer:dest_buffer.buffer() destinationOffset:0 - size:input_tensors[2].bytes()]; + size:input_tensors[tensor_mapping_ + .anchors_tensor_index()] + .bytes()]; [blit_command endEncoding]; [command_buffer commit]; } else if (!kInAnchors(cc).IsEmpty()) { @@ -495,7 +572,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( auto decoded_boxes_view = decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer); [command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0]; - auto input0_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()] + .GetMtlBufferReadView(command_buffer); [command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1]; auto raw_anchors_view = raw_anchors_buffer_->GetMtlBufferReadView(command_buffer); @@ -507,7 +585,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( [command_encoder setComputePipelineState:score_program_]; [command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0]; - auto input1_view = input_tensors[1].GetMtlBufferReadView(command_buffer); + auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()] + .GetMtlBufferReadView(command_buffer); [command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1]; MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); @@ -570,6 +649,10 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { num_classes_ = options_.num_classes(); num_boxes_ = options_.num_boxes(); num_coords_ = options_.num_coords(); + CHECK_NE(options_.max_results(), 0) + << "The maximum number of the top-scored detection results must be " + "non-zero."; + max_results_ = options_.max_results(); // Currently only support 2D when num_values_per_keypoint equals to 2. 
CHECK_EQ(options_.num_values_per_keypoint(), 2); @@ -581,15 +664,55 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { if (kSideInIgnoreClasses(cc).IsConnected()) { RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty()); + RET_CHECK(options_.allow_classes().empty()); + class_index_set_.is_allowlist = false; for (int ignore_class : *kSideInIgnoreClasses(cc)) { - ignore_classes_.insert(ignore_class); + class_index_set_.values.insert(ignore_class); + } + } else if (!options_.allow_classes().empty()) { + RET_CHECK(options_.ignore_classes().empty()); + class_index_set_.is_allowlist = true; + for (int i = 0; i < options_.allow_classes_size(); ++i) { + class_index_set_.values.insert(options_.allow_classes(i)); } } else { + class_index_set_.is_allowlist = false; for (int i = 0; i < options_.ignore_classes_size(); ++i) { - ignore_classes_.insert(options_.ignore_classes(i)); + class_index_set_.values.insert(options_.ignore_classes(i)); } } + if (options_.has_tensor_mapping()) { + RET_CHECK_OK(CheckCustomTensorMapping(options_.tensor_mapping())); + tensor_mapping_ = options_.tensor_mapping(); + scores_tensor_index_is_set_ = true; + } else { + // Assigns the default tensor indices. + tensor_mapping_.set_detections_tensor_index(0); + tensor_mapping_.set_classes_tensor_index(1); + tensor_mapping_.set_anchors_tensor_index(2); + tensor_mapping_.set_num_detections_tensor_index(3); + // The scores tensor index needs to be determined based on the number of + // model's output tensors, which will be available in the first invocation + // of the Process() method. + tensor_mapping_.set_scores_tensor_index(-1); + scores_tensor_index_is_set_ = false; + } + + if (options_.has_box_boundaries_indices()) { + box_indices_ = {options_.box_boundaries_indices().ymin(), + options_.box_boundaries_indices().xmin(), + options_.box_boundaries_indices().ymax(), + options_.box_boundaries_indices().xmax()}; + int bitmap = 0; + for (int i : box_indices_) { + bitmap |= 1 << i; + } + RET_CHECK_EQ(bitmap, 15) << "The custom box boundaries indices should only " + "cover index 0, 1, 2, and 3."; + has_custom_box_indices_ = true; + } + return absl::OkStatus(); } @@ -661,14 +784,22 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( const float* detection_boxes, const float* detection_scores, const int* detection_classes, std::vector* output_detections) { for (int i = 0; i < num_boxes_; ++i) { + if (max_results_ > 0 && output_detections->size() == max_results_) { + break; + } if (options_.has_min_score_thresh() && detection_scores[i] < options_.min_score_thresh()) { continue; } + if (!IsClassIndexAllowed(detection_classes[i])) { + continue; + } const int box_offset = i * num_coords_; Detection detection = ConvertToDetection( - detection_boxes[box_offset + 0], detection_boxes[box_offset + 1], - detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], + /*box_ymin=*/detection_boxes[box_offset + box_indices_[0]], + /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]], + /*box_ymax=*/detection_boxes[box_offset + box_indices_[2]], + /*box_xmax=*/detection_boxes[box_offset + box_indices_[3]], detection_scores[i], detection_classes[i], options_.flip_vertically()); const auto& bbox = detection.location_data().relative_bounding_box(); if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || @@ -910,7 +1041,7 @@ void main() { options_.has_score_clipping_thresh() ? 1 : 0, options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() : 0, - !ignore_classes_.empty() ? 
1 : 0); + !IsClassIndexAllowed(0)); // # filter classes supported is hardware dependent. int max_wg_size; // typically <= 1024 @@ -919,7 +1050,14 @@ void main() { CHECK_LT(num_classes_, max_wg_size) << "# classes must be < " << max_wg_size; // TODO support better filtering. - CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + if (class_index_set_.is_allowlist) { + CHECK_EQ(class_index_set_.values.size(), + IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) + << "Only all classes >= class 0 or >= class 1"; + } else { + CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1) + << "Only ignore class 0 is allowed"; + } // Shader program { @@ -1126,10 +1264,17 @@ kernel void scoreKernel( options_.has_score_clipping_thresh() ? 1 : 0, options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() : 0, - ignore_classes_.size() ? 1 : 0); + !IsClassIndexAllowed(0)); // TODO support better filtering. - CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + if (class_index_set_.is_allowlist) { + CHECK_EQ(class_index_set_.values.size(), + IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) + << "Only all classes >= class 0 or >= class 1"; + } else { + CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1) + << "Only ignore class 0 is allowed"; + } { // Shader program @@ -1161,5 +1306,16 @@ kernel void scoreKernel( return absl::OkStatus(); } +bool TensorsToDetectionsCalculator::IsClassIndexAllowed(int class_index) { + if (class_index_set_.values.empty()) { + return true; + } + if (class_index_set_.is_allowlist) { + return class_index_set_.values.contains(class_index); + } else { + return !class_index_set_.values.contains(class_index); + } +} + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto index 364eb5cce..c9d6b69da 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto @@ -57,7 +57,12 @@ message TensorsToDetectionsCalculatorOptions { optional bool reverse_output_order = 14 [default = false]; // The ids of classes that should be ignored when decoding the score for // each predicted box. Can be overridden with IGNORE_CLASSES side packet. + // `ignore_classes` and `allow_classes` are mutually exclusive. repeated int32 ignore_classes = 8; + // The ids of classes that should be allowed when decoding the score for + // each predicted box. `ignore_classes` and `allow_classes` are mutually + // exclusive. + repeated int32 allow_classes = 21 [packed = true]; optional bool sigmoid_score = 15 [default = false]; optional float score_clipping_thresh = 16; @@ -71,4 +76,40 @@ message TensorsToDetectionsCalculatorOptions { // Score threshold for preserving decoded detections. optional float min_score_thresh = 19; + + // The maximum number of detection results to return. If < 0, all + // available results will be returned. + // For the detection models that have a built-in non-max-suppression op, the + // output detections are the top-scored results. Otherwise, the output + // detections are the first N results that have higher scores than + // `min_score_thresh`. + optional int32 max_results = 20 [default = -1]; + + // The custom model output tensor mapping. + // The indices of the "detections" tensor and the "scores" tensor are always + // required. If the model outputs an "anchors" tensor, `anchors_tensor_index` + // must be specified. If the model outputs both the "classes" and "number of + // detections" tensors, `classes_tensor_index` and + // `num_detections_tensor_index` must be set.
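A hedged sketch of driving these new fields from C++, for a model whose outputs arrive in a nonstandard order (the output ordering and the class id below are assumptions for illustration, not taken from this patch):

mediapipe::TensorsToDetectionsCalculatorOptions options;
// Keep at most five detections of class 0 and drop everything else.
options.add_allow_classes(0);  // mutually exclusive with ignore_classes
options.set_max_results(5);
// Assumed model output order: [scores, detections, classes, num_detections].
auto* mapping = options.mutable_tensor_mapping();
mapping->set_scores_tensor_index(0);
mapping->set_detections_tensor_index(1);
mapping->set_classes_tensor_index(2);
mapping->set_num_detections_tensor_index(3);
// CheckCustomTensorMapping() accepts this mapping: its index bitmap is 15.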
If the model outputs an "anchors" tensor, `anchors_tensor_index` + // must be specified. If the model outputs both "classes" tensor and "number + // of detections" tensors, `classes_tensor_index` and + // `num_detections_tensor_index` must be set. + message TensorMapping { + optional int32 detections_tensor_index = 1; + optional int32 classes_tensor_index = 2; + optional int32 scores_tensor_index = 3; + optional int32 num_detections_tensor_index = 4; + optional int32 anchors_tensor_index = 5; + } + optional TensorMapping tensor_mapping = 22; + + // Represents the bounding box by using the combination of boundaries, + // {ymin, xmin, ymax, xmax}. + // The default order is {ymin, xmin, ymax, xmax}. + message BoxBoundariesIndices { + optional int32 ymin = 1 [default = 0]; + optional int32 xmin = 2 [default = 1]; + optional int32 ymax = 3 [default = 2]; + optional int32 xmax = 4 [default = 3]; + } + oneof box_indices { + BoxBoundariesIndices box_boundaries_indices = 23; + } } diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc index 622e76850..200a273f6 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc @@ -121,8 +121,12 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) { if (d > 255) d = 255; buffer[i] = d; } - output = ::absl::make_unique(format, width, height, - width * depth, buffer.release()); + output = ::absl::make_unique( + format, width, height, width * depth, buffer.release(), + [total_size](uint8* ptr) { + ::operator delete[](ptr, total_size, + std::align_val_t(EIGEN_MAX_ALIGN_BYTES)); + }); } else if (input_tensor.dtype() == tensorflow::DT_UINT8) { if (scale_factor_ != 1.0) { return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor"); diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 55616bb83..b132db01d 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -121,10 +121,11 @@ cc_library( deps = [ ":tflite_custom_op_resolver_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:packet", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", "//mediapipe/util/tflite:cpu_op_resolver", "//mediapipe/util/tflite:op_resolver", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], alwayslink = 1, ) diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc index 11e27dff1..950d742a9 100644 --- a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc @@ -12,14 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <memory> + #include "mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.pb.h" +#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/tflite/cpu_op_resolver.h" #include "mediapipe/util/tflite/op_resolver.h" +#include "tensorflow/lite/core/api/op_resolver.h" namespace mediapipe { +namespace { +constexpr char kOpResolverTag[] = "OP_RESOLVER"; +} // namespace + // This calculator creates a custom op resolver as a side packet that can be // used in TfLiteInferenceCalculator. The current custom op resolver supports // the following custom ops on CPU and GPU: // MaxPoolArgmax // MaxUnpooling // -// Usage example: +// Usage examples: +// +// For use with TfLiteInferenceCalculator: // node { // calculator: "TfLiteCustomOpResolverCalculator" // output_side_packet: "op_resolver" // node_options: { // [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { // use_gpu: true // } // } // } +// +// For use with InferenceCalculator: +// node { +// calculator: "TfLiteCustomOpResolverCalculator" +// output_side_packet: "OP_RESOLVER:op_resolver" +// node_options: { +// [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] { +// use_gpu: true +// } +// } +// } class TfLiteCustomOpResolverCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { + if (cc->OutputSidePackets().HasTag(kOpResolverTag)) { + cc->OutputSidePackets().Tag(kOpResolverTag).Set<tflite::OpResolver>(); + } else { cc->OutputSidePackets() .Index(0) .Set<tflite::ops::builtin::BuiltinOpResolver>(); + } return absl::OkStatus(); } @@ -59,7 +84,14 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase { op_resolver = absl::make_unique<mediapipe::CpuOpResolver>(); } - cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release())); + if (cc->OutputSidePackets().HasTag(kOpResolverTag)) { + cc->OutputSidePackets() + .Tag(kOpResolverTag) + .Set(mediapipe::api2::PacketAdopting<tflite::OpResolver>( + std::move(op_resolver))); + } else { + cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release())); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 0c1a3ce3c..d00fd09ff 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -54,6 +54,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/util:label_map_proto", ], ) @@ -304,6 +305,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/util:resource_util", + "//mediapipe/util:label_map_cc_proto", ] + select({ "//mediapipe:android": [ "//mediapipe/util/android/file/base", @@ -350,6 +352,40 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detection_transformation_calculator", + srcs = ["detection_transformation_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "detection_transformation_calculator_test", + size = "small", + srcs = ["detection_transformation_calculator_test.cc"], + deps = [
":detection_transformation_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + cc_library( name = "non_max_suppression_calculator", srcs = ["non_max_suppression_calculator.cc"], diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 07fe791f6..20d1c1cbd 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/container/node_hash_map.h" #include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/resource_util.h" #if defined(MEDIAPIPE_MOBILE) @@ -53,8 +53,11 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override; private: - absl::node_hash_map label_map_; - ::mediapipe::DetectionLabelIdToTextCalculatorOptions options_; + // Local label map built from the calculator options' `label_map_path` or + // `label` field. + LabelMap local_label_map_; + bool keep_label_id_; + const LabelMap& GetLabelMap(CalculatorContext* cc); }; REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); @@ -69,13 +72,16 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract( absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - options_ = + const auto& options = cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>(); - if (options_.has_label_map_path()) { + if (options.has_label_map_path()) { + RET_CHECK(!options.has_label_map() && options.label().empty()) + << "Only can set one of the following fields in the CalculatorOptions: " + "label_map_path, label, and label_map."; std::string string_path; ASSIGN_OR_RETURN(string_path, - PathToResourceAsFile(options_.label_map_path())); + PathToResourceAsFile(options.label_map_path())); std::string label_map_string; MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); @@ -83,13 +89,21 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { std::string line; int i = 0; while (std::getline(stream, line)) { - label_map_[i++] = line; + LabelMapItem item; + item.set_name(line); + (*local_label_map_.mutable_index_to_item())[i++] = item; } - } else { - for (int i = 0; i < options_.label_size(); ++i) { - label_map_[i] = options_.label(i); + } else if (!options.label().empty()) { + RET_CHECK(!options.has_label_map()) + << "Only can set one of the following fields in the CalculatorOptions: " + "label_map_path, label, and label_map."; + for (int i = 0; i < options.label_size(); ++i) { + LabelMapItem item; + item.set_name(options.label(i)); + (*local_label_map_.mutable_index_to_item())[i] = item; } } + keep_label_id_ = options.keep_label_id(); return 
absl::OkStatus(); } @@ -101,13 +115,18 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { Detection& output_detection = output_detections.back(); bool has_text_label = false; for (const int32 label_id : output_detection.label_id()) { - if (label_map_.find(label_id) != label_map_.end()) { - output_detection.add_label(label_map_[label_id]); + if (GetLabelMap(cc).index_to_item().find(label_id) != + GetLabelMap(cc).index_to_item().end()) { + auto item = GetLabelMap(cc).index_to_item().at(label_id); + output_detection.add_label(item.name()); + if (item.has_display_name()) { + output_detection.add_display_name(item.display_name()); + } has_text_label = true; } } // Remove label_id field if text labels exist. - if (has_text_label && !options_.keep_label_id()) { + if (has_text_label && !keep_label_id_) { output_detection.clear_label_id(); } } @@ -117,4 +136,13 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } +const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap( + CalculatorContext* cc) { + return !local_label_map_.index_to_item().empty() + ? local_label_map_ + : cc->Options< + ::mediapipe::DetectionLabelIdToTextCalculatorOptions>() + .label_map(); +} + } // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto index 198ca4d65..bb1cf6098 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.proto @@ -17,6 +17,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/label_map.proto"; message DetectionLabelIdToTextCalculatorOptions { extend mediapipe.CalculatorOptions { @@ -26,7 +27,7 @@ message DetectionLabelIdToTextCalculatorOptions { // Path to a label map file for getting the actual name of detected classes. optional string label_map_path = 1; - // Alternative way to specify label map + // Alternative way to specify label map. // label: "label for id 0" // label: "label for id 1" // ... @@ -36,4 +37,7 @@ message DetectionLabelIdToTextCalculatorOptions { // could be found. By setting this field to true, it is always copied to the // output detections. optional bool keep_label_id = 3; + + // Label map. + optional LabelMap label_map = 4; } diff --git a/mediapipe/calculators/util/detection_transformation_calculator.cc b/mediapipe/calculators/util/detection_transformation_calculator.cc new file mode 100644 index 000000000..9a9db8487 --- /dev/null +++ b/mediapipe/calculators/util/detection_transformation_calculator.cc @@ -0,0 +1,298 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
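Before the new calculator's implementation, one note on the label-map change above: DetectionLabelIdToTextCalculator can now receive a structured LabelMap directly in its options, instead of a file path or bare label strings. A minimal sketch of building one in C++ (the index and the names are illustrative):

mediapipe::DetectionLabelIdToTextCalculatorOptions options;
mediapipe::LabelMapItem item;
item.set_name("dog");
item.set_display_name("Dog");  // copied into the detection when present
(*options.mutable_label_map()->mutable_index_to_item())[0] = item;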
+ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +template +T BoundedValue(T value, T upper_bound) { + T output = std::min(value, upper_bound); + if (output < 0) { + return 0; + } + return output; +} + +absl::Status ConvertRelativeBoundingBoxToBoundingBox( + const std::pair& image_size, Detection* detection) { + const int image_width = image_size.first; + const int image_height = image_size.second; + const auto& relative_bbox = + detection->location_data().relative_bounding_box(); + auto* bbox = detection->mutable_location_data()->mutable_bounding_box(); + bbox->set_xmin( + BoundedValue(relative_bbox.xmin() * image_width, image_width)); + bbox->set_ymin( + BoundedValue(relative_bbox.ymin() * image_height, image_height)); + bbox->set_width( + BoundedValue(relative_bbox.width() * image_width, image_width)); + bbox->set_height( + BoundedValue(relative_bbox.height() * image_height, image_height)); + detection->mutable_location_data()->set_format(LocationData::BOUNDING_BOX); + detection->mutable_location_data()->clear_relative_bounding_box(); + return absl::OkStatus(); +} + +absl::Status ConvertBoundingBoxToRelativeBoundingBox( + const std::pair& image_size, Detection* detection) { + int image_width = image_size.first; + int image_height = image_size.second; + const auto& bbox = detection->location_data().bounding_box(); + auto* relative_bbox = + detection->mutable_location_data()->mutable_relative_bounding_box(); + relative_bbox->set_xmin( + BoundedValue((float)bbox.xmin() / image_width, 1.0f)); + relative_bbox->set_ymin( + BoundedValue((float)bbox.ymin() / image_height, 1.0f)); + relative_bbox->set_width( + BoundedValue((float)bbox.width() / image_width, 1.0f)); + relative_bbox->set_height( + BoundedValue((float)bbox.height() / image_height, 1.0f)); + detection->mutable_location_data()->clear_bounding_box(); + detection->mutable_location_data()->set_format( + LocationData::RELATIVE_BOUNDING_BOX); + return absl::OkStatus(); +} + +absl::StatusOr GetLocationDataFormat( + const Detection& detection) { + if (!detection.has_location_data()) { + return absl::InvalidArgumentError("Detection must have location data."); + } + LocationData::Format format = detection.location_data().format(); + RET_CHECK(format == LocationData::RELATIVE_BOUNDING_BOX || + format == LocationData::BOUNDING_BOX) + << "Detection's location data format must be either " + "RELATIVE_BOUNDING_BOX or BOUNDING_BOX"; + return format; +} + +absl::StatusOr GetLocationDataFormat( + std::vector& detections) { + RET_CHECK(!detections.empty()); + LocationData::Format output_format; + ASSIGN_OR_RETURN(output_format, GetLocationDataFormat(detections[0])); + for (int i = 1; i < detections.size(); ++i) { + ASSIGN_OR_RETURN(LocationData::Format format, + GetLocationDataFormat(detections[i])); + if (output_format != format) { + return absl::InvalidArgumentError( + "Input detections have different location data formats."); + } + } + return output_format; +} + +absl::Status ConvertBoundingBox(const std::pair& image_size, + Detection* detection) { + if (!detection->has_location_data()) { + return absl::InvalidArgumentError("Detection must 
have location data."); + } + switch (detection->location_data().format()) { + case LocationData::RELATIVE_BOUNDING_BOX: + return ConvertRelativeBoundingBoxToBoundingBox(image_size, detection); + case LocationData::BOUNDING_BOX: + return ConvertBoundingBoxToRelativeBoundingBox(image_size, detection); + default: + return absl::InvalidArgumentError( + "Detection's location data format must be either " + "RELATIVE_BOUNDING_BOX or BOUNDING_BOX."); + } +} + +} // namespace + +// Transforms relative bounding box(es) to pixel bounding box(es) in a detection +// proto/detection list/detection vector, or vice versa. +// +// Inputs: +// One of the following: +// DETECTION: A Detection proto. +// DETECTIONS: A std::vector<Detection> or a DetectionList proto. +// IMAGE_SIZE: A std::pair<int, int> representing image width and height. +// +// Outputs: +// At least one of the following: +// PIXEL_DETECTION: A Detection proto with pixel bounding box. +// PIXEL_DETECTIONS: A std::vector<Detection> with pixel bounding boxes. +// PIXEL_DETECTION_LIST: A DetectionList proto with pixel bounding boxes. +// RELATIVE_DETECTION: A Detection proto with relative bounding box. +// RELATIVE_DETECTIONS: A std::vector<Detection> with relative bounding boxes. +// RELATIVE_DETECTION_LIST: A DetectionList proto with relative bounding boxes. +// +// Example config: +// For input detection(s) with relative bounding box(es): +// node { +// calculator: "DetectionTransformationCalculator" +// input_stream: "DETECTION:input_detection" +// input_stream: "IMAGE_SIZE:image_size" +// output_stream: "PIXEL_DETECTION:output_detection" +// output_stream: "PIXEL_DETECTIONS:output_detections" +// output_stream: "PIXEL_DETECTION_LIST:output_detection_list" +// } +// +// For input detection(s) with pixel bounding box(es): +// node { +// calculator: "DetectionTransformationCalculator" +// input_stream: "DETECTION:input_detection" +// input_stream: "IMAGE_SIZE:image_size" +// output_stream: "RELATIVE_DETECTION:output_detection" +// output_stream: "RELATIVE_DETECTIONS:output_detections" +// output_stream: "RELATIVE_DETECTION_LIST:output_detection_list" +// } class DetectionTransformationCalculator : public Node { + public: + static constexpr Input<Detection>::Optional kInDetection{"DETECTION"}; + static constexpr Input<OneOf<DetectionList, std::vector<Detection>>>::Optional + kInDetections{"DETECTIONS"}; + static constexpr Input<std::pair<int, int>> kInImageSize{"IMAGE_SIZE"}; + static constexpr Output<Detection>::Optional kOutPixelDetection{ + "PIXEL_DETECTION"}; + static constexpr Output<std::vector<Detection>>::Optional kOutPixelDetections{ + "PIXEL_DETECTIONS"}; + static constexpr Output<DetectionList>::Optional kOutPixelDetectionList{ + "PIXEL_DETECTION_LIST"}; + static constexpr Output<Detection>::Optional kOutRelativeDetection{ + "RELATIVE_DETECTION"}; + static constexpr Output<std::vector<Detection>>::Optional + kOutRelativeDetections{"RELATIVE_DETECTIONS"}; + static constexpr Output<DetectionList>::Optional kOutRelativeDetectionList{ + "RELATIVE_DETECTION_LIST"}; + + MEDIAPIPE_NODE_CONTRACT(kInDetection, kInDetections, kInImageSize, + kOutPixelDetection, kOutPixelDetections, + kOutPixelDetectionList, kOutRelativeDetection, + kOutRelativeDetections, kOutRelativeDetectionList); + + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK(kInImageSize(cc).IsConnected()) << "Image size must be provided."; + RET_CHECK(kInDetections(cc).IsConnected() ^ kInDetection(cc).IsConnected()); + if (kInDetections(cc).IsConnected()) { + RET_CHECK(kOutPixelDetections(cc).IsConnected() || + kOutPixelDetectionList(cc).IsConnected() || + kOutRelativeDetections(cc).IsConnected() || + kOutRelativeDetectionList(cc).IsConnected()) + << "Output must be
a container of detections."; + } + RET_CHECK(kOutPixelDetections(cc).IsConnected() || + kOutPixelDetectionList(cc).IsConnected() || + kOutPixelDetection(cc).IsConnected() || + kOutRelativeDetections(cc).IsConnected() || + kOutRelativeDetectionList(cc).IsConnected() || + kOutRelativeDetection(cc).IsConnected()) + << "Must connect at least one output stream."; + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + output_pixel_bounding_boxes_ = kOutPixelDetections(cc).IsConnected() || + kOutPixelDetectionList(cc).IsConnected() || + kOutPixelDetection(cc).IsConnected(); + output_relative_bounding_boxes_ = + kOutRelativeDetections(cc).IsConnected() || + kOutRelativeDetectionList(cc).IsConnected() || + kOutRelativeDetection(cc).IsConnected(); + RET_CHECK(output_pixel_bounding_boxes_ ^ output_relative_bounding_boxes_) + << "All output streams must have the same stream tag prefix, either " + "\"PIXEL_\" or \"RELATIVE_\"."; + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + std::pair<int, int> image_size = kInImageSize(cc).Get(); + std::vector<Detection> transformed_detections; + LocationData::Format input_location_data_format; + if (kInDetections(cc).IsConnected()) { + transformed_detections = kInDetections(cc).Visit( + [&](const DetectionList& detection_list) { + return std::vector<Detection>(detection_list.detection().begin(), + detection_list.detection().end()); + }, + [&](const std::vector<Detection>& detection_vector) { + return detection_vector; + }); + ASSIGN_OR_RETURN(input_location_data_format, + GetLocationDataFormat(transformed_detections)); + for (Detection& detection : transformed_detections) { + MP_RETURN_IF_ERROR(ConvertBoundingBox(image_size, &detection)); + } + } else { + ASSIGN_OR_RETURN(input_location_data_format, + GetLocationDataFormat(kInDetection(cc).Get())); + Detection transformed_detection(kInDetection(cc).Get()); + MP_RETURN_IF_ERROR( + ConvertBoundingBox(image_size, &transformed_detection)); + transformed_detections.push_back(transformed_detection); + } + if (input_location_data_format == LocationData::RELATIVE_BOUNDING_BOX) { + RET_CHECK(!output_relative_bounding_boxes_) + << "Input detections have relative bounding box(es), so the " + "output detections must have pixel bounding box(es)."; + if (kOutPixelDetection(cc).IsConnected()) { + kOutPixelDetection(cc).Send(transformed_detections[0]); + } + if (kOutPixelDetections(cc).IsConnected()) { + kOutPixelDetections(cc).Send(transformed_detections); + } + if (kOutPixelDetectionList(cc).IsConnected()) { + DetectionList detection_list; + for (const auto& detection : transformed_detections) { + detection_list.add_detection()->CopyFrom(detection); + } + kOutPixelDetectionList(cc).Send(detection_list); + } + } else { + RET_CHECK(!output_pixel_bounding_boxes_) + << "Input detections have pixel bounding box(es), so the " + "output detections must have relative bounding box(es)."; + if (kOutRelativeDetection(cc).IsConnected()) { + kOutRelativeDetection(cc).Send(transformed_detections[0]); + } + if (kOutRelativeDetections(cc).IsConnected()) { + kOutRelativeDetections(cc).Send(transformed_detections); + } + if (kOutRelativeDetectionList(cc).IsConnected()) { + DetectionList detection_list; + for (const auto& detection : transformed_detections) { + detection_list.add_detection()->CopyFrom(detection); + } + kOutRelativeDetectionList(cc).Send(detection_list); + } + } + return absl::OkStatus(); + } + + private: + bool output_relative_bounding_boxes_; + bool output_pixel_bounding_boxes_; +};
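// Worked example of the transformation (numbers illustrative): with
// IMAGE_SIZE = {2000, 1000} and a relative box
// {xmin: 0.25, ymin: 0.1, width: 0.5, height: 0.3}, the calculator emits the
// pixel box {xmin: 500, ymin: 100, width: 1000, height: 300}; BoundedValue()
// clamps each coordinate to [0, image dimension] along the way.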
+MEDIAPIPE_REGISTER_NODE(DetectionTransformationCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/util/detection_transformation_calculator_test.cc b/mediapipe/calculators/util/detection_transformation_calculator_test.cc
new file mode 100644
index 000000000..48291b2dc
--- /dev/null
+++ b/mediapipe/calculators/util/detection_transformation_calculator_test.cc
@@ -0,0 +1,287 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe {
+namespace {
+
+constexpr char kDetectionTag[] = "DETECTION";
+constexpr char kDetectionsTag[] = "DETECTIONS";
+constexpr char kImageSizeTag[] = "IMAGE_SIZE";
+constexpr char kPixelDetectionTag[] = "PIXEL_DETECTION";
+constexpr char kPixelDetectionListTag[] = "PIXEL_DETECTION_LIST";
+constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
+constexpr char kRelativeDetectionListTag[] = "RELATIVE_DETECTION_LIST";
+constexpr char kRelativeDetectionsTag[] = "RELATIVE_DETECTIONS";
+
+Detection DetectionWithBoundingBox(int32 xmin, int32 ymin, int32 width,
+                                   int32 height) {
+  Detection detection;
+  LocationData* location_data = detection.mutable_location_data();
+  location_data->set_format(LocationData::BOUNDING_BOX);
+  location_data->mutable_bounding_box()->set_xmin(xmin);
+  location_data->mutable_bounding_box()->set_ymin(ymin);
+  location_data->mutable_bounding_box()->set_width(width);
+  location_data->mutable_bounding_box()->set_height(height);
+  return detection;
+}
+
+Detection DetectionWithRelativeBoundingBox(float xmin, float ymin, float width,
+                                           float height) {
+  Detection detection;
+  LocationData* location_data = detection.mutable_location_data();
+  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
+  location_data->mutable_relative_bounding_box()->set_xmin(xmin);
+  location_data->mutable_relative_bounding_box()->set_ymin(ymin);
+  location_data->mutable_relative_bounding_box()->set_width(width);
+  location_data->mutable_relative_bounding_box()->set_height(height);
+  return detection;
+}
+
+std::vector<Detection> ConvertToDetectionVector(
+    const DetectionList& detection_list) {
+  std::vector<Detection> output;
+  for (const auto& detection : detection_list.detection()) {
+    output.push_back(detection);
+  }
+  return output;
+}
+
+void CheckBoundingBox(const Detection& output, const Detection& expected) {
+  const auto& output_bbox = output.location_data().bounding_box();
+  const auto& expected_bbox = expected.location_data().bounding_box();
+  EXPECT_THAT(output_bbox.xmin(), testing::Eq(expected_bbox.xmin()));
+  EXPECT_THAT(output_bbox.ymin(), testing::Eq(expected_bbox.ymin()));
+  EXPECT_THAT(output_bbox.width(), testing::Eq(expected_bbox.width()));
+  EXPECT_THAT(output_bbox.height(), testing::Eq(expected_bbox.height()));
+}
+
+void CheckRelativeBoundingBox(const Detection& output,
+                              const Detection& expected) {
+  const auto& output_bbox = output.location_data().relative_bounding_box();
+  const auto& expected_bbox = expected.location_data().relative_bounding_box();
+  EXPECT_THAT(output_bbox.xmin(), testing::FloatEq(expected_bbox.xmin()));
+  EXPECT_THAT(output_bbox.ymin(), testing::FloatEq(expected_bbox.ymin()));
+  EXPECT_THAT(output_bbox.width(), testing::FloatEq(expected_bbox.width()));
+  EXPECT_THAT(output_bbox.height(), testing::FloatEq(expected_bbox.height()));
+}
+
+void CheckOutputDetections(const std::vector<Detection>& expected,
+                           const std::vector<Detection>& output) {
+  ASSERT_EQ(output.size(), expected.size());
+  for (int i = 0; i < output.size(); ++i) {
+    auto output_format = output[i].location_data().format();
+    ASSERT_TRUE(output_format == LocationData::RELATIVE_BOUNDING_BOX ||
+                output_format == LocationData::BOUNDING_BOX);
+    ASSERT_EQ(output_format, expected[i].location_data().format());
+    if (output_format == LocationData::RELATIVE_BOUNDING_BOX) {
+      CheckRelativeBoundingBox(output[i], expected[i]);
+    }
+    if (output_format == LocationData::BOUNDING_BOX) {
+      CheckBoundingBox(output[i], expected[i]);
+    }
+  }
+}
+
+TEST(DetectionsTransformationCalculatorTest, MissingImageSize) {
+  CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "DetectionTransformationCalculator"
+        input_stream: "DETECTIONS:detections"
+        output_stream: "PIXEL_DETECTION:detection"
+      )pb"));
+
+  auto status = runner.Run();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              testing::HasSubstr("Image size must be provided"));
+}
+
+TEST(DetectionsTransformationCalculatorTest, WrongOutputType) {
+  CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "DetectionTransformationCalculator"
+        input_stream: "DETECTIONS:detections"
+        input_stream: "IMAGE_SIZE:image_size"
+        output_stream: "PIXEL_DETECTION:detection"
+      )pb"));
+
+  auto status = runner.Run();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              testing::HasSubstr("Output must be a container of detections"));
+}
+
+TEST(DetectionsTransformationCalculatorTest, WrongLocationDataFormat) {
+  CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "DetectionTransformationCalculator"
+        input_stream: "DETECTION:input_detection"
+        input_stream: "IMAGE_SIZE:image_size"
+        output_stream: "PIXEL_DETECTION:output_detection"
+      )pb"));
+
+  Detection detection;
+  detection.mutable_location_data()->set_format(LocationData::GLOBAL);
+  runner.MutableInputs()
+      ->Tag(kDetectionTag)
+      .packets.push_back(MakePacket<Detection>(detection).At(Timestamp(0)));
+  std::pair<int, int> image_size({2000, 1000});
+  runner.MutableInputs()
+      ->Tag(kImageSizeTag)
+      .packets.push_back(
+          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
+
+  auto status = runner.Run();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.message(),
+              testing::HasSubstr("location data format must be either "
+                                 "RELATIVE_BOUNDING_BOX or BOUNDING_BOX"));
+}
+
+TEST(DetectionsTransformationCalculatorTest,
+     ConvertBoundingBoxToRelativeBoundingBox) {
+  CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "DetectionTransformationCalculator"
+        input_stream: "DETECTIONS:input_detections"
+        input_stream: "IMAGE_SIZE:image_size"
+        output_stream: "RELATIVE_DETECTIONS:output_detections"
+        output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
+      )pb"));
+
+  auto detections(absl::make_unique<std::vector<Detection>>());
+  detections->push_back(DetectionWithBoundingBox(100, 200, 400, 300));
+  detections->push_back(DetectionWithBoundingBox(0, 0, 2000, 1000));
+  std::pair<int, int> image_size({2000, 1000});
+  runner.MutableInputs()
+      ->Tag(kDetectionsTag)
+      .packets.push_back(Adopt(detections.release()).At(Timestamp(0)));
+  runner.MutableInputs()
+      ->Tag(kImageSizeTag)
+      .packets.push_back(
+          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
+  MP_ASSERT_OK(runner.Run());
+
+  std::vector<Detection> expected(
+      {DetectionWithRelativeBoundingBox(0.05, 0.2, 0.2, 0.3),
+       DetectionWithRelativeBoundingBox(0, 0, 1, 1)});
+  const std::vector<Packet>& detections_output =
+      runner.Outputs().Tag(kRelativeDetectionsTag).packets;
+  ASSERT_EQ(1, detections_output.size());
+  CheckOutputDetections(expected,
+                        detections_output[0].Get<std::vector<Detection>>());
+
+  const std::vector<Packet>& detection_list_output =
+      runner.Outputs().Tag(kRelativeDetectionListTag).packets;
+  ASSERT_EQ(1, detection_list_output.size());
+  CheckOutputDetections(
+      expected,
+      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
+}
+
+TEST(DetectionsTransformationCalculatorTest,
+     ConvertRelativeBoundingBoxToBoundingBox) {
+  CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "DetectionTransformationCalculator"
+        input_stream: "DETECTIONS:input_detections"
+        input_stream: "IMAGE_SIZE:image_size"
+        output_stream: "PIXEL_DETECTIONS:output_detections"
+        output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
+      )pb"));
+
+  auto detections(absl::make_unique<std::vector<Detection>>());
+  detections->push_back(DetectionWithRelativeBoundingBox(0.05, 0.2, 0.2, 0.3));
+  detections->push_back(DetectionWithRelativeBoundingBox(0, 0, 1, 1));
+  std::pair<int, int> image_size({2000, 1000});
+  runner.MutableInputs()
+      ->Tag(kDetectionsTag)
+      .packets.push_back(Adopt(detections.release()).At(Timestamp(0)));
+  runner.MutableInputs()
+      ->Tag(kImageSizeTag)
+      .packets.push_back(
+          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
+  MP_ASSERT_OK(runner.Run());
+
+  std::vector<Detection> expected(
+      {DetectionWithBoundingBox(100, 200, 400, 300),
+       DetectionWithBoundingBox(0, 0, 2000, 1000)});
+  const std::vector<Packet>& detections_output =
+      runner.Outputs().Tag(kPixelDetectionsTag).packets;
+  ASSERT_EQ(1, detections_output.size());
+  CheckOutputDetections(expected,
+                        detections_output[0].Get<std::vector<Detection>>());
+
+  const std::vector<Packet>& detection_list_output =
+      runner.Outputs().Tag(kPixelDetectionListTag).packets;
+  ASSERT_EQ(1, detection_list_output.size());
+  CheckOutputDetections(
+      expected,
+      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
+}
+
+TEST(DetectionsTransformationCalculatorTest, ConvertSingleDetection) {
+  CalculatorRunner runner(
+      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+        calculator: "DetectionTransformationCalculator"
+        input_stream: "DETECTION:input_detection"
+        input_stream: "IMAGE_SIZE:image_size"
+        output_stream: "PIXEL_DETECTION:output_detection"
+        output_stream: "PIXEL_DETECTIONS:output_detections"
+        output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
+      )pb"));
+
+  runner.MutableInputs()
+      ->Tag(kDetectionTag)
+      .packets.push_back(MakePacket<Detection>(DetectionWithRelativeBoundingBox(
+                             0.05, 0.2, 0.2, 0.3))
+                             .At(Timestamp(0)));
+  std::pair<int, int> image_size({2000, 1000});
+  runner.MutableInputs()
+      ->Tag(kImageSizeTag)
+      .packets.push_back(
+          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
+  MP_ASSERT_OK(runner.Run());
+
+  std::vector<Detection> expected(
+      {DetectionWithBoundingBox(100, 200, 400, 300)});
+  const std::vector<Packet>& detection_output =
+      runner.Outputs().Tag(kPixelDetectionTag).packets;
+  ASSERT_EQ(1, detection_output.size());
+  CheckOutputDetections(expected, {detection_output[0].Get<Detection>()});
+
+  const std::vector<Packet>& detections_output =
+      runner.Outputs().Tag(kPixelDetectionsTag).packets;
+  ASSERT_EQ(1, detections_output.size());
+  CheckOutputDetections(expected,
+                        detections_output[0].Get<std::vector<Detection>>());
+
+  const std::vector<Packet>& detection_list_output =
+      runner.Outputs().Tag(kPixelDetectionListTag).packets;
+  ASSERT_EQ(1, detection_list_output.size());
+  CheckOutputDetections(
+      expected,
+      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
+}
+
+}  // namespace
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/video/tracking_graph_test.cc b/mediapipe/calculators/video/tracking_graph_test.cc
index e446e155c..6516bd7da 100644
--- a/mediapipe/calculators/video/tracking_graph_test.cc
+++ b/mediapipe/calculators/video/tracking_graph_test.cc
@@ -181,7 +181,7 @@ class TrackingGraphTest : public Test {
   // Each image is shifted to the right and bottom by kTranslationStep
   // pixels compared with the previous image.
   static constexpr int kTranslationStep = 10;
-  static constexpr float kEqualityTolerance = 3e-4f;
+  static constexpr float kEqualityTolerance = 1e-3f;
 };
 
 void TrackingGraphTest::ExpectBoxAtFrame(const TimedBoxProto& box, float frame,
diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h
index 94d19ff80..76c13d98e 100644
--- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h
+++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h
@@ -85,7 +85,7 @@ class KinematicPathSolver {
   double current_position_px_;
   double prior_position_px_;
   double current_velocity_deg_per_s_;
-  uint64 current_time_;
+  uint64 current_time_ = 0;
   // History of observations (second) and their time (first).
   std::deque<std::pair<uint64, double>> raw_positions_at_time_;
   // Current target position.
diff --git a/mediapipe/examples/desktop/media_sequence/read_demo_dataset.py b/mediapipe/examples/desktop/media_sequence/read_demo_dataset.py
index b84d991b5..c057addad 100644
--- a/mediapipe/examples/desktop/media_sequence/read_demo_dataset.py
+++ b/mediapipe/examples/desktop/media_sequence/read_demo_dataset.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Lint as: python3
 """Example of reading a MediaSequence dataset.
 """
diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD
index 0acd41dfd..9424fddea 100644
--- a/mediapipe/examples/ios/facedetectioncpu/BUILD
+++ b/mediapipe/examples/ios/facedetectioncpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "facedetectioncpu",
diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD
index 4ca3c267e..8ed689b4f 100644
--- a/mediapipe/examples/ios/facedetectiongpu/BUILD
+++ b/mediapipe/examples/ios/facedetectiongpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "facedetectiongpu",
diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD
index 87b59901c..50a6f68bd 100644
--- a/mediapipe/examples/ios/faceeffect/BUILD
+++ b/mediapipe/examples/ios/faceeffect/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "faceeffect",
diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD
index d10d531ca..02103ce2f 100644
--- a/mediapipe/examples/ios/facemeshgpu/BUILD
+++ b/mediapipe/examples/ios/facemeshgpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "facemeshgpu",
diff --git a/mediapipe/examples/ios/handdetectiongpu/BUILD b/mediapipe/examples/ios/handdetectiongpu/BUILD
index b43cdbacd..9b9255374 100644
--- a/mediapipe/examples/ios/handdetectiongpu/BUILD
+++ b/mediapipe/examples/ios/handdetectiongpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "handdetectiongpu",
diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD
index 66a9c64db..647b7670a 100644
--- a/mediapipe/examples/ios/handtrackinggpu/BUILD
+++ b/mediapipe/examples/ios/handtrackinggpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "handtrackinggpu",
diff --git a/mediapipe/examples/ios/helloworld/BUILD b/mediapipe/examples/ios/helloworld/BUILD
index 192996bf3..aed0c35a5 100644
--- a/mediapipe/examples/ios/helloworld/BUILD
+++ b/mediapipe/examples/ios/helloworld/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "helloworld",
diff --git a/mediapipe/examples/ios/holistictrackinggpu/BUILD b/mediapipe/examples/ios/holistictrackinggpu/BUILD
index 6d72282ed..cd10877de 100644
--- a/mediapipe/examples/ios/holistictrackinggpu/BUILD
+++ b/mediapipe/examples/ios/holistictrackinggpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "holistictrackinggpu",
diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD
index 84222760c..056447d63 100644
--- a/mediapipe/examples/ios/iristrackinggpu/BUILD
+++ b/mediapipe/examples/ios/iristrackinggpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "iristrackinggpu",
diff --git a/mediapipe/examples/ios/link_local_profiles.py b/mediapipe/examples/ios/link_local_profiles.py
index a2a91e72d..bc4a06c97 100755
--- a/mediapipe/examples/ios/link_local_profiles.py
+++ b/mediapipe/examples/ios/link_local_profiles.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Lint as: python3
 """This script is used to set up automatic provisioning for iOS examples.
 
 It scans the provisioning profiles used by Xcode, looking for one matching the
diff --git a/mediapipe/examples/ios/objectdetectioncpu/BUILD b/mediapipe/examples/ios/objectdetectioncpu/BUILD
index 5ddd12df6..7638c7413 100644
--- a/mediapipe/examples/ios/objectdetectioncpu/BUILD
+++ b/mediapipe/examples/ios/objectdetectioncpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "objectdetectioncpu",
diff --git a/mediapipe/examples/ios/objectdetectiongpu/BUILD b/mediapipe/examples/ios/objectdetectiongpu/BUILD
index b31c13f53..e07e6ada4 100644
--- a/mediapipe/examples/ios/objectdetectiongpu/BUILD
+++ b/mediapipe/examples/ios/objectdetectiongpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])  # Apache 2.0
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "objectdetectiongpu",
diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD
index 37e0b85e9..2236c5257 100644
--- a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD
+++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "objectdetectiontrackinggpu",
diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD
index 01a82cb4b..86b41ed36 100644
--- a/mediapipe/examples/ios/posetrackinggpu/BUILD
+++ b/mediapipe/examples/ios/posetrackinggpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "posetrackinggpu",
diff --git a/mediapipe/examples/ios/selfiesegmentationgpu/BUILD b/mediapipe/examples/ios/selfiesegmentationgpu/BUILD
index 884ac95a5..1ba7997ed 100644
--- a/mediapipe/examples/ios/selfiesegmentationgpu/BUILD
+++ b/mediapipe/examples/ios/selfiesegmentationgpu/BUILD
@@ -24,7 +24,7 @@ load(
 
 licenses(["notice"])
 
-MIN_IOS_VERSION = "10.0"
+MIN_IOS_VERSION = "11.0"
 
 alias(
     name = "selfiesegmentationgpu",
diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index d592f9c9c..1166c2a33 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -234,7 +234,9 @@ cc_library(
         "//mediapipe/framework/tool:options_map",
         "//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto",
         "//mediapipe/framework/tool:tag_map",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -348,6 +350,7 @@ cc_library(
         "//mediapipe/framework/tool:validate",
         "//mediapipe/framework/tool:validate_name",
         "//mediapipe/gpu:graph_support",
+        "//mediapipe/gpu:gpu_service",
         "//mediapipe/util:cpu_util",
     ] + select({
         "//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"],
@@ -416,7 +419,6 @@ cc_library(
         "//mediapipe/framework/tool:status_util",
         "//mediapipe/framework/tool:tag_map",
         "//mediapipe/framework/tool:validate_name",
-        "//mediapipe/gpu:graph_support",
        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
@@ -613,7 +615,11 @@ cc_library(
     hdrs = ["graph_service.h"],
     visibility = [":mediapipe_internal"],
     deps = [
+        ":packet",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/framework/port:statusor",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h
index 771cfb83f..80d2307ae 100644
--- a/mediapipe/framework/api2/packet.h
+++ b/mediapipe/framework/api2/packet.h
@@ -167,7 +167,6 @@ struct IsCompatibleType
 template <typename T>
 inline Packet<T> PacketBase::As() const {
   if (!payload_) return Packet<T>().At(timestamp_);
-  packet_internal::Holder<T>* typed_payload = payload_->As<T>();
   internal::CheckCompatibleType(*payload_, internal::Wrap<T>{});
   return Packet<T>(payload_).At(timestamp_);
 }
@@ -217,8 +216,8 @@ class Packet : public Packet<internal::Generic> {
   const T& operator*() const { return Get(); }
   const T* operator->() const { return &Get(); }
 
-  template <typename U>
-  T GetOr(U&& v) const {
+  template <typename U, typename TT = T>
+  std::enable_if_t<std::is_copy_constructible_v<TT>, TT> GetOr(U&& v) const {
     return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this;
   }
 
diff --git a/mediapipe/framework/api2/packet_nc.cc b/mediapipe/framework/api2/packet_nc.cc
index e59bfac8a..0699f0f2f 100644
--- a/mediapipe/framework/api2/packet_nc.cc
+++ b/mediapipe/framework/api2/packet_nc.cc
@@ -4,11 +4,15 @@ namespace api2 {
 namespace {
 
 #if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE)
-void AssignWrongPacketType() { Packet<int> p = MakePacket<float>(1.0); }
+int AssignWrongPacketType() {
+  Packet<int> p = MakePacket<float>(1.0);
+  return *p;
+}
 #elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC)
-void AssignWrongPacketType() {
+int AssignWrongPacketType() {
   Packet<> p = MakePacket<float>(1.0);
   Packet<int> p2 = p;
+  return *p2;
 }
 #endif
 
diff --git a/mediapipe/framework/api2/packet_test.cc b/mediapipe/framework/api2/packet_test.cc
index 887ba3c3e..00bc35086 100644
--- a/mediapipe/framework/api2/packet_test.cc
+++ b/mediapipe/framework/api2/packet_test.cc
@@ -264,6 +264,23 @@ TEST(PacketTest, Polymorphism) {
   EXPECT_EQ((**mutable_base).name(), "Derived");
 }
 
+class AbstractBase {
+ public:
+  virtual ~AbstractBase() = default;
+  virtual absl::string_view name() const = 0;
+};
+
+class ConcreteDerived : public AbstractBase {
+ public:
+  absl::string_view name() const override { return "ConcreteDerived"; }
+};
+
+TEST(PacketTest, PolymorphismAbstract) {
+  Packet<AbstractBase> base =
+      PacketAdopting<AbstractBase>(absl::make_unique<ConcreteDerived>());
+  EXPECT_EQ(base->name(), "ConcreteDerived");
+}
+
 }  // namespace
 }  // namespace api2
 }  // namespace mediapipe
diff --git a/mediapipe/framework/api2/port_test.cc b/mediapipe/framework/api2/port_test.cc
index 2bbae387d..c09e38452 100644
--- a/mediapipe/framework/api2/port_test.cc
+++ b/mediapipe/framework/api2/port_test.cc
@@ -40,6 +40,17 @@ TEST(PortTest, DeletedCopyConstructorInput) {
   EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT");
 }
 
+class AbstractBase {
+ public:
+  virtual ~AbstractBase() = default;
+  virtual absl::string_view name() const = 0;
+};
+
+TEST(PortTest, Abstract) {
+  static constexpr Input<AbstractBase> kInputPort{"INPUT"};
+  EXPECT_EQ(std::string(kInputPort.Tag()), "INPUT");
+}
+
 }  // namespace
 }  // namespace api2
 }  // namespace mediapipe
diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h
index 72f29bc03..fd0507bec 100644
--- a/mediapipe/framework/calculator_contract.h
+++ b/mediapipe/framework/calculator_contract.h
@@ -21,6 +21,8 @@
 #include
 
 // TODO: Move protos in another CL after the C++ code migration.
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/graph_service.h"
 #include "mediapipe/framework/mediapipe_options.pb.h"
@@ -147,7 +149,7 @@ class CalculatorContract {
     bool IsOptional() const { return optional_; }
 
    private:
-    GraphServiceBase service_;
+    const GraphServiceBase& service_;
     bool optional_ = false;
   };
 
@@ -156,9 +158,12 @@ class CalculatorContract {
     return it->second;
   }
 
-  const std::map<std::string, GraphServiceRequest>& ServiceRequests() const {
-    return service_requests_;
-  }
+  // A GraphService's key is always a static constant, so we can use string_view
+  // as the key type without lifetime issues.
+  using ServiceReqMap =
+      absl::flat_hash_map<absl::string_view, GraphServiceRequest>;
+
+  const ServiceReqMap& ServiceRequests() const { return service_requests_; }
 
  private:
  template
@@ -180,7 +185,7 @@ class CalculatorContract {
   std::string input_stream_handler_;
   MediaPipeOptions input_stream_handler_options_;
   std::string node_name_;
-  std::map<std::string, GraphServiceRequest> service_requests_;
+  ServiceReqMap service_requests_;
   bool process_timestamps_ = false;
   TimestampDiff timestamp_offset_ = TimestampDiff::Unset();
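With the contract tracking service requests in a flat_hash_map keyed by the service's static key string, a calculator declares its dependencies in its contract function. A minimal sketch, assuming a service constant like the ones defined later in this patch (error handling elided; the Optional() setter is assumed to mirror the IsOptional() getter shown above):

// In a CalculatorBase subclass:
static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
  // Recorded in cc->ServiceRequests(), keyed by kTestService.key.
  cc->UseService(kTestService).Optional();
  return absl::OkStatus();
}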
diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc
index 43c844bc7..4f6755364 100644
--- a/mediapipe/framework/calculator_graph.cc
+++ b/mediapipe/framework/calculator_graph.cc
@@ -226,6 +226,16 @@ absl::Status CalculatorGraph::InitializeStreams() {
   return absl::OkStatus();
 }
 
+// Hack for backwards compatibility with ancient GPU calculators. Can it
+// be retired yet?
+static void MaybeFixupLegacyGpuNodeContract(CalculatorNode& node) {
+#if !MEDIAPIPE_DISABLE_GPU
+  if (node.Contract().InputSidePackets().HasTag(kGpuSharedTagName)) {
+    const_cast<CalculatorContract&>(node.Contract()).UseService(kGpuService);
+  }
+#endif  // !MEDIAPIPE_DISABLE_GPU
+}
+
 absl::Status CalculatorGraph::InitializeCalculatorNodes() {
   // Check if the user has specified a maximum queue size for an input stream.
   max_queue_size_ = validated_graph_->Config().max_queue_size();
@@ -246,6 +256,7 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
         validated_graph_.get(), node_ref, input_stream_managers_.get(),
         output_stream_managers_.get(), output_side_packets_.get(),
         &buffer_size_hint, profiler_);
+    MaybeFixupLegacyGpuNodeContract(*nodes_.back());
     if (buffer_size_hint > 0) {
       max_queue_size_ = std::max(max_queue_size_, buffer_size_hint);
     }
@@ -283,6 +294,7 @@ absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
         validated_graph_.get(), node_ref, input_stream_managers_.get(),
         output_stream_managers_.get(), output_side_packets_.get(),
         &buffer_size_hint, profiler_);
+    MaybeFixupLegacyGpuNodeContract(*nodes_.back());
     if (!result.ok()) {
       // Collect as many errors as we can before failing.
       errors.push_back(result);
@@ -495,9 +507,8 @@ absl::StatusOr<Packet> CalculatorGraph::GetOutputSidePacket(
            << "\" because it doesn't exist.";
   }
   Packet output_packet;
-  if (scheduler_.IsTerminated()) {
-    // Side-packets from calculators can be retrieved only after the graph is
-    // done.
+  if (!output_side_packets_[side_packet_index].GetPacket().IsEmpty() ||
+      scheduler_.IsTerminated()) {
     output_packet = output_side_packets_[side_packet_index].GetPacket();
   }
   if (output_packet.IsEmpty()) {
@@ -546,6 +557,7 @@ absl::Status CalculatorGraph::StartRun(
 #if !MEDIAPIPE_DISABLE_GPU
 absl::Status CalculatorGraph::SetGpuResources(
     std::shared_ptr<::mediapipe::GpuResources> resources) {
+  RET_CHECK_NE(resources, nullptr);
   auto gpu_service = service_manager_.GetServiceObject(kGpuService);
   RET_CHECK_EQ(gpu_service, nullptr)
       << "The GPU resources have already been configured.";
@@ -557,68 +569,89 @@ std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
   return service_manager_.GetServiceObject(kGpuService);
 }
 
-absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
+static Packet GetLegacyGpuSharedSidePacket(
     const std::map<std::string, Packet>& side_packets) {
-  std::map<std::string, Packet> additional_side_packets;
-  bool update_sp = false;
-  bool uses_gpu = false;
-  for (const auto& node : nodes_) {
-    if (node->UsesGpu()) {
-      uses_gpu = true;
-      break;
-    }
+  auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
+  if (legacy_sp_iter == side_packets.end()) return {};
+  // Note that, because of b/116875321, the legacy side packet may be set but
+  // empty. But it's ok, because here we return an empty packet to indicate the
+  // missing case anyway.
+  return legacy_sp_iter->second;
+}
+
+absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket(
+    Packet legacy_sp) {
+  if (legacy_sp.IsEmpty()) return absl::OkStatus();
+  auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
+  if (gpu_resources) {
+    LOG(WARNING)
+        << "::mediapipe::GpuSharedData provided as a side packet while the "
+        << "graph already had one; ignoring side packet";
+    return absl::OkStatus();
   }
-  if (uses_gpu) {
-    auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
+  gpu_resources = legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources;
+  return service_manager_.SetServiceObject(kGpuService, gpu_resources);
+}
 
-    auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
-    // Workaround for b/116875321: CalculatorRunner provides an empty packet,
-    // instead of just leaving it unset.
-    bool has_legacy_sp = legacy_sp_iter != side_packets.end() &&
-                         !legacy_sp_iter->second.IsEmpty();
-
-    if (gpu_resources) {
-      if (has_legacy_sp) {
-        LOG(WARNING)
-            << "::mediapipe::GpuSharedData provided as a side packet while the "
-            << "graph already had one; ignoring side packet";
-      }
-      update_sp = true;
-    } else {
-      if (has_legacy_sp) {
-        gpu_resources =
-            legacy_sp_iter->second.Get<::mediapipe::GpuSharedData*>()
-                ->gpu_resources;
-      } else {
-        ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create());
-        update_sp = true;
-      }
-      MP_RETURN_IF_ERROR(
-          service_manager_.SetServiceObject(kGpuService, gpu_resources));
-    }
-
-    // Create or replace the legacy side packet if needed.
-    if (update_sp) {
-      legacy_gpu_shared_.reset(new ::mediapipe::GpuSharedData(gpu_resources));
-      additional_side_packets[kGpuSharedSidePacketName] =
-          MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
-    }
-
-    // Set up executors.
-    for (auto& node : nodes_) {
-      if (node->UsesGpu()) {
-        MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
-      }
-    }
-    for (const auto& name_executor : gpu_resources->GetGpuExecutors()) {
-      MP_RETURN_IF_ERROR(
-          SetExecutorInternal(name_executor.first, name_executor.second));
-    }
+std::map<std::string, Packet> CalculatorGraph::MaybeCreateLegacyGpuSidePacket(
+    Packet legacy_sp) {
+  std::map<std::string, Packet> additional_side_packets;
+  auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
+  if (gpu_resources &&
+      (legacy_sp.IsEmpty() ||
+       legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources !=
+           gpu_resources)) {
+    legacy_gpu_shared_ =
+        absl::make_unique<::mediapipe::GpuSharedData>(gpu_resources);
+    additional_side_packets[kGpuSharedSidePacketName] =
+        MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
   }
   return additional_side_packets;
 }
+
+static bool UsesGpu(const CalculatorNode& node) {
+  return node.Contract().ServiceRequests().contains(kGpuService.key);
+}
+
+absl::Status CalculatorGraph::PrepareGpu() {
+  auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
+  if (!gpu_resources) return absl::OkStatus();
+  // Set up executors.
+  for (auto& node : nodes_) {
+    if (UsesGpu(*node)) {
+      MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
+    }
+  }
+  for (const auto& name_executor : gpu_resources->GetGpuExecutors()) {
+    MP_RETURN_IF_ERROR(
+        SetExecutorInternal(name_executor.first, name_executor.second));
+  }
+  return absl::OkStatus();
+}
 #endif  // !MEDIAPIPE_DISABLE_GPU
 
+absl::Status CalculatorGraph::PrepareServices() {
+  for (const auto& node : nodes_) {
+    for (const auto& [key, request] : node->Contract().ServiceRequests()) {
+      auto packet = service_manager_.GetServicePacket(request.Service());
+      if (!packet.IsEmpty()) continue;
+      auto packet_or = request.Service().CreateDefaultObject();
+      if (packet_or.ok()) {
+        MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
+            request.Service(), std::move(packet_or).value()));
+      } else if (request.IsOptional()) {
+        continue;
+      } else {
+        return absl::InternalError(absl::StrCat(
+            "Service \"", request.Service().key, "\", required by node ",
+            node->DebugName(), ", was not provided and cannot be created: ",
+            std::move(packet_or).status().message()));
+      }
+    }
+  }
  return absl::OkStatus();
+}
+
 absl::Status CalculatorGraph::PrepareForRun(
     const std::map<std::string, Packet>& extra_side_packets,
     const std::map<std::string, Packet>& stream_headers) {
@@ -637,7 +670,13 @@ absl::Status CalculatorGraph::PrepareForRun(
   std::map<std::string, Packet> additional_side_packets;
 #if !MEDIAPIPE_DISABLE_GPU
-  ASSIGN_OR_RETURN(additional_side_packets, PrepareGpu(extra_side_packets));
+  auto legacy_sp = GetLegacyGpuSharedSidePacket(extra_side_packets);
+  MP_RETURN_IF_ERROR(MaybeSetUpGpuServiceFromLegacySidePacket(legacy_sp));
+#endif  // !MEDIAPIPE_DISABLE_GPU
+  MP_RETURN_IF_ERROR(PrepareServices());
+#if !MEDIAPIPE_DISABLE_GPU
+  MP_RETURN_IF_ERROR(PrepareGpu());
+  additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
 #endif  // !MEDIAPIPE_DISABLE_GPU
 
   const std::map<std::string, Packet>* input_side_packets;
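PrepareForRun() now funnels service setup through PrepareServices(): each service a node requested is looked up in the service manager, and if absent the framework asks the service to create a default object, failing only for non-optional requests. The GPU path uses the same mechanism plus the legacy side-packet shims above. A minimal sketch of providing the GPU service explicitly, assuming a valid CalculatorGraphConfig named config (error handling elided):

mediapipe::CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(config));
// Explicit setup; if omitted, PrepareServices()/PrepareGpu() fall back to
// creating the service object on demand when a node requests it.
ASSIGN_OR_RETURN(auto gpu_resources, ::mediapipe::GpuResources::Create());
MP_RETURN_IF_ERROR(graph.SetGpuResources(std::move(gpu_resources)));
MP_RETURN_IF_ERROR(graph.StartRun({}));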
diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h
index 3478375e4..406317fb9 100644
--- a/mediapipe/framework/calculator_graph.h
+++ b/mediapipe/framework/calculator_graph.h
@@ -165,10 +165,13 @@ class CalculatorGraph {
   StatusOrPoller AddOutputStreamPoller(const std::string& stream_name,
                                        bool observe_timestamp_bounds = false);
 
-  // Gets output side packet by name after the graph is done. However, base
-  // packets (generated by PacketGenerators) can be retrieved before
-  // graph is done. Returns error if the graph is still running (for non-base
-  // packets) or the output side packet is not found or empty.
+  // Gets output side packet by name. The output side packet can be successfully
+  // retrieved in one of the following situations:
+  // - The graph is done.
+  // - The output side packet has been generated by a calculator and the graph
+  //   is currently idle.
+  // - The side packet is a base packet generated by a PacketGenerator.
+  // Returns error if the output side packet is not found or empty.
   absl::StatusOr<Packet> GetOutputSidePacket(const std::string& packet_name);
 
   // Runs the graph after adding the given extra input side packets. All
@@ -367,13 +370,8 @@ class CalculatorGraph {
   std::shared_ptr<GpuResources> GetGpuResources() const;
 
   absl::Status SetGpuResources(std::shared_ptr<GpuResources> resources);
-
-  // Helper for PrepareForRun. If it returns a non-empty map, those packets
-  // must be added to the existing side packets, replacing existing values
-  // that have the same key.
-  absl::StatusOr<std::map<std::string, Packet>> PrepareGpu(
-      const std::map<std::string, Packet>& side_packets);
 #endif  // !MEDIAPIPE_DISABLE_GPU
+
   template <typename T>
   absl::Status SetServiceObject(const GraphService<T>& service,
                                 std::shared_ptr<T> object) {
@@ -495,6 +493,18 @@ class CalculatorGraph {
       const std::map<std::string, Packet>& extra_side_packets,
       const std::map<std::string, Packet>& stream_headers);
 
+  absl::Status PrepareServices();
+
+#if !MEDIAPIPE_DISABLE_GPU
+  absl::Status MaybeSetUpGpuServiceFromLegacySidePacket(Packet legacy_sp);
+  // Helper for PrepareForRun. If it returns a non-empty map, those packets
+  // must be added to the existing side packets, replacing existing values
+  // that have the same key.
+  std::map<std::string, Packet> MaybeCreateLegacyGpuSidePacket(
+      Packet legacy_sp);
+  absl::Status PrepareGpu();
+#endif  // !MEDIAPIPE_DISABLE_GPU
+
   // Cleans up any remaining state after the run and returns any errors that may
   // have occurred during the run. Called after the scheduler has terminated.
   absl::Status FinishRun();
diff --git a/mediapipe/framework/calculator_graph_side_packet_test.cc b/mediapipe/framework/calculator_graph_side_packet_test.cc
index e187a1a6d..57fcff866 100644
--- a/mediapipe/framework/calculator_graph_side_packet_test.cc
+++ b/mediapipe/framework/calculator_graph_side_packet_test.cc
@@ -732,11 +732,12 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
   status_or_packet = graph.GetOutputSidePacket("unknown");
   EXPECT_FALSE(status_or_packet.ok());
   EXPECT_EQ(absl::StatusCode::kNotFound, status_or_packet.status().code());
-  // Should return UNAVAILABLE before graph is done for valid non-base
-  // packets.
+  // Should return the packet after the graph becomes idle.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
   status_or_packet = graph.GetOutputSidePacket("num_of_packets");
-  EXPECT_FALSE(status_or_packet.ok());
-  EXPECT_EQ(absl::StatusCode::kUnavailable, status_or_packet.status().code());
+  MP_ASSERT_OK(status_or_packet);
+  EXPECT_EQ(max_count, status_or_packet.value().Get<int>());
+  EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp());
   // Should still return a base packet even before graph is done.
   status_or_packet = graph.GetOutputSidePacket("output_uint64");
   MP_ASSERT_OK(status_or_packet);
@@ -896,5 +897,23 @@ TEST(CalculatorGraph, GeneratorAfterCalculatorProcess) {
   }
 }
 
+TEST(CalculatorGraph, GetOutputSidePacketAfterCalculatorIsOpened) {
+  CalculatorGraph graph;
+  CalculatorGraphConfig config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        node {
+          calculator: "IntegerOutputSidePacketCalculator"
+          output_side_packet: "offset"
+        }
+      )pb");
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.StartRun({}));
+  // Must be called to ensure that the calculator is opened.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  absl::StatusOr<Packet> status_or_packet = graph.GetOutputSidePacket("offset");
+  MP_ASSERT_OK(status_or_packet);
+  EXPECT_EQ(1, status_or_packet.value().Get<int>());
+}
+
 }  // namespace
 }  // namespace mediapipe
diff --git a/mediapipe/framework/calculator_node.cc b/mediapipe/framework/calculator_node.cc
index 763c5d07b..f6a1c7dbf 100644
--- a/mediapipe/framework/calculator_node.cc
+++ b/mediapipe/framework/calculator_node.cc
@@ -46,7 +46,6 @@
 #include "mediapipe/framework/tool/status_util.h"
 #include "mediapipe/framework/tool/tag_map.h"
 #include "mediapipe/framework/tool/validate_name.h"
-#include "mediapipe/gpu/graph_support.h"
 
 namespace mediapipe {
 
@@ -155,11 +154,6 @@ absl::Status CalculatorNode::Initialize(
 
   const CalculatorContract& contract = node_type_info_->Contract();
 
-  uses_gpu_ =
-      node_type_info_->InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
-      ContainsKey(node_type_info_->Contract().ServiceRequests(),
-                  kGpuService.key);
-
   // TODO Propagate types between calculators when SetAny is used.
 
   MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
@@ -397,7 +391,7 @@ absl::Status CalculatorNode::PrepareForRun(
                                     std::move(schedule_callback), error_callback);
   output_stream_handler_->PrepareForRun(error_callback);
 
-  const auto& contract = node_type_info_->Contract();
+  const auto& contract = Contract();
   input_side_packet_types_ = RemoveOmittedPacketTypes(
       contract.InputSidePackets(), all_side_packets, validated_graph_);
   MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(
diff --git a/mediapipe/framework/calculator_node.h b/mediapipe/framework/calculator_node.h
index 8ecf72cfc..4ba5027b5 100644
--- a/mediapipe/framework/calculator_node.h
+++ b/mediapipe/framework/calculator_node.h
@@ -195,9 +195,6 @@ class CalculatorNode {
   // Called by SchedulerQueue when a node is opened.
   void NodeOpened() ABSL_LOCKS_EXCLUDED(status_mutex_);
 
-  // Returns whether this is a GPU calculator node.
-  bool UsesGpu() const { return uses_gpu_; }
-
   // Returns the scheduler queue the node is assigned to.
   internal::SchedulerQueue* GetSchedulerQueue() const {
     return scheduler_queue_;
@@ -234,6 +231,12 @@ class CalculatorNode {
     return *calculator_state_;
   }
 
+  // Returns the node's contract.
+  // Must not be called before the CalculatorNode is initialized.
+  const CalculatorContract& Contract() const {
+    return node_type_info_->Contract();
+  }
+
 private:
   // Sets up the output side packets from the main flat array.
   absl::Status InitializeOutputSidePackets(
@@ -363,9 +366,6 @@ class CalculatorNode {
 
   std::unique_ptr<OutputStreamHandler> output_stream_handler_;
 
-  // Whether this is a GPU calculator.
-  bool uses_gpu_ = false;
-
   // True if CleanupAfterRun() needs to call CloseNode().
   bool needs_to_close_ = false;
diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index 9deff6542..236269260 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -187,6 +187,21 @@ cc_library(
     ],
 )
 
+config_setting(
+    name = "opencv",
+    define_values = {
+        "use_opencv": "true",
+    },
+)
+
+config_setting(
+    name = "portable_opencv",
+    define_values = {
+        "use_portable_opencv": "true",
+        "use_opencv": "false",
+    },
+)
+
 cc_library(
     name = "location",
     srcs = ["location.cc"],
@@ -194,6 +209,8 @@ cc_library(
     defines = select({
         "//conditions:default": [],
         "//mediapipe:android": ["MEDIAPIPE_ANDROID_OPENCV"],
+        ":portable_opencv": ["MEDIAPIPE_ANDROID_OPENCV"],
+        ":opencv": [],
     }),
     visibility = ["//visibility:public"],
     deps = [
diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h
index 453f2c659..adb7dca6e 100644
--- a/mediapipe/framework/formats/tensor.h
+++ b/mediapipe/framework/formats/tensor.h
@@ -76,7 +76,7 @@ class Tensor {
 
  public:
   // No resources are allocated here.
-  enum class ElementType { kNone, kFloat16, kFloat32, kUInt8 };
+  enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8 };
   struct Shape {
     Shape() = default;
     Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
@@ -217,6 +217,8 @@ class Tensor {
         return sizeof(float);
       case ElementType::kUInt8:
         return 1;
+      case ElementType::kInt8:
+        return 1;
     }
   }
   int bytes() const { return shape_.num_elements() * element_size(); }
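The new kInt8 entry mirrors kUInt8 with a 1-byte element size, so the bytes() arithmetic is unchanged. A minimal sketch of allocating and writing such a tensor, assuming the Tensor CPU-view API (the buffer<T>() accessor name is an assumption based on the existing API, not introduced by this patch):

// A 1x16 signed 8-bit tensor: element_size() == 1, so bytes() == 16.
mediapipe::Tensor tensor(mediapipe::Tensor::ElementType::kInt8,
                         mediapipe::Tensor::Shape({1, 16}));
auto view = tensor.GetCpuWriteView();
int8_t* data = view.buffer<int8_t>();
data[0] = -128;  // full signed range is now representable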
diff --git a/mediapipe/framework/graph_service.h b/mediapipe/framework/graph_service.h
index 920603929..c4c64a852 100644
--- a/mediapipe/framework/graph_service.h
+++ b/mediapipe/framework/graph_service.h
@@ -16,6 +16,12 @@
 #define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_
 
 #include <memory>
+#include <type_traits>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/status.h"
 
 namespace mediapipe {
 
@@ -27,18 +33,74 @@ namespace mediapipe {
 // IMPORTANT: this is an experimental API. Get in touch with the MediaPipe team
 // if you want to use it. In most cases, you should use a side packet instead.
-struct GraphServiceBase {
+class GraphServiceBase {
+ public:
+  // TODO: fix services for which default init is broken, remove
+  // this setting.
+  enum DefaultInitSupport {
+    kAllowDefaultInitialization,
+    kDisallowDefaultInitialization
+  };
+
   constexpr GraphServiceBase(const char* key) : key(key) {}
 
+  virtual ~GraphServiceBase() = default;
+  inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
+    return DefaultInitializationUnsupported();
+  }
+
   const char* key;
+
+ protected:
+  absl::Status DefaultInitializationUnsupported() const {
+    return absl::UnimplementedError(absl::StrCat(
+        "Graph service '", key, "' does not support default initialization"));
+  }
 };
 
 template <typename T>
-struct GraphService : public GraphServiceBase {
+class GraphService : public GraphServiceBase {
+ public:
   using type = T;
   using packet_type = std::shared_ptr<T>;
 
-  constexpr GraphService(const char* key) : GraphServiceBase(key) {}
+  constexpr GraphService(const char* my_key, DefaultInitSupport default_init =
+                                                 kDisallowDefaultInitialization)
+      : GraphServiceBase(my_key), default_init_(default_init) {}
+
+  absl::StatusOr<Packet> CreateDefaultObject() const override {
+    if (default_init_ != kAllowDefaultInitialization) {
+      return DefaultInitializationUnsupported();
+    }
+    auto packet_or = CreateDefaultObjectInternal();
+    if (packet_or.ok()) {
+      return MakePacket<std::shared_ptr<T>>(std::move(packet_or).value());
+    } else {
+      return packet_or.status();
+    }
+  }
+
+ private:
+  absl::StatusOr<std::shared_ptr<T>> CreateDefaultObjectInternal() const {
+    auto call_create = [](auto x) -> decltype(decltype(x)::type::Create()) {
+      return decltype(x)::type::Create();
+    };
+    if constexpr (std::is_invocable_r_v<absl::StatusOr<std::shared_ptr<T>>,
+                                        decltype(call_create), type_tag<T>>) {
+      return T::Create();
+    }
+    if constexpr (std::is_default_constructible_v<T>) {
+      return std::make_shared<T>();
+    }
+    return DefaultInitializationUnsupported();
+  }
+
+  template <typename U>
+  struct type_tag {
+    using type = U;
+  };
+
+  DefaultInitSupport default_init_;
 };
 
 template
diff --git a/mediapipe/framework/graph_service_manager.h b/mediapipe/framework/graph_service_manager.h
index a8b9cc1fb..301f17cb6 100644
--- a/mediapipe/framework/graph_service_manager.h
+++ b/mediapipe/framework/graph_service_manager.h
@@ -35,6 +35,8 @@ class GraphServiceManager {
   Packet GetServicePacket(const GraphServiceBase& service) const;
 
   std::map<std::string, Packet> service_packets_;
+
+  friend class CalculatorGraph;
 };
 
 }  // namespace mediapipe
diff --git a/mediapipe/framework/graph_service_manager_test.cc b/mediapipe/framework/graph_service_manager_test.cc
index f38148006..1895a6f70 100644
--- a/mediapipe/framework/graph_service_manager_test.cc
+++ b/mediapipe/framework/graph_service_manager_test.cc
@@ -6,11 +6,13 @@
 #include "mediapipe/framework/port/status_matchers.h"
 
 namespace mediapipe {
+namespace {
+const GraphService<int> kIntService("mediapipe::IntService");
+}  // namespace
 
 TEST(GraphServiceManager, SetGetServiceObject) {
   GraphServiceManager service_manager;
 
-  constexpr GraphService<int> kIntService("mediapipe::IntService");
   EXPECT_EQ(service_manager.GetServiceObject(kIntService), nullptr);
 
   MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
@@ -22,8 +24,6 @@ TEST(GraphServiceManager, SetServicePacket) {
   GraphServiceManager service_manager;
 
-  constexpr GraphService<int> kIntService("mediapipe::IntService");
-
   MP_EXPECT_OK(service_manager.SetServicePacket(
       kIntService,
       mediapipe::MakePacket<std::shared_ptr<int>>(std::make_shared<int>(100))));
@@ -36,8 +36,6 @@ TEST(GraphServiceManager, ServicePackets) {
 
   EXPECT_TRUE(service_manager.ServicePackets().empty());
 
-  constexpr GraphService<int> kIntService("mediapipe::IntService");
-
   MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
                                                 std::make_shared<int>(100)));
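CreateDefaultObjectInternal() above probes, in order, a static Create() factory and then a default constructor. A minimal sketch of a service type that opts in through the factory route, with hypothetical names mirroring the NeedsCreateMethod test type later in this patch:

class MyServiceObject {
 public:
  static absl::StatusOr<std::shared_ptr<MyServiceObject>> Create() {
    return std::shared_ptr<MyServiceObject>(new MyServiceObject());
  }

 private:
  MyServiceObject() = default;  // not default-constructible from outside
};

const mediapipe::GraphService<MyServiceObject> kMyService(
    "my_service", mediapipe::GraphServiceBase::kAllowDefaultInitialization);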
diff --git a/mediapipe/framework/graph_service_test.cc b/mediapipe/framework/graph_service_test.cc
index 3226ecbaf..bd9b1af66 100644
--- a/mediapipe/framework/graph_service_test.cc
+++ b/mediapipe/framework/graph_service_test.cc
@@ -150,5 +150,12 @@ TEST_F(GraphServiceTest, OptionalIsAvailable) {
   EXPECT_EQ(PacketValues<int>(output_packets_), (std::vector<int>{108}));
 }
 
+TEST_F(GraphServiceTest, CreateDefault) {
+  EXPECT_FALSE(kTestService.CreateDefaultObject().ok());
+  MP_EXPECT_OK(kAnotherService.CreateDefaultObject());
+  EXPECT_FALSE(kNoDefaultService.CreateDefaultObject().ok());
+  MP_EXPECT_OK(kNeedsCreateService.CreateDefaultObject());
+}
+
 }  // namespace
 }  // namespace mediapipe
diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc
index 27cba711f..d1dffa414 100644
--- a/mediapipe/framework/input_stream_handler.cc
+++ b/mediapipe/framework/input_stream_handler.cc
@@ -50,15 +50,18 @@ absl::Status InputStreamHandler::SetupInputShards(
   return absl::OkStatus();
 }
 
-std::vector<std::pair<std::string, int>>
+std::vector<std::tuple<std::string, int, int, Timestamp>>
 InputStreamHandler::GetMonitoringInfo() {
-  std::vector<std::pair<std::string, int>> monitoring_info_vector;
+  std::vector<std::tuple<std::string, int, int, Timestamp>>
+      monitoring_info_vector;
   for (auto& stream : input_stream_managers_) {
     if (!stream) {
       continue;
     }
     monitoring_info_vector.emplace_back(
-        std::pair<std::string, int>(stream->Name(), stream->QueueSize()));
+        std::tuple<std::string, int, int, Timestamp>(
+            stream->Name(), stream->QueueSize(), stream->NumPacketsAdded(),
+            stream->MinTimestampOrBound(nullptr)));
   }
   return monitoring_info_vector;
 }
diff --git a/mediapipe/framework/input_stream_handler.h b/mediapipe/framework/input_stream_handler.h
index 798f89f36..e306a55a8 100644
--- a/mediapipe/framework/input_stream_handler.h
+++ b/mediapipe/framework/input_stream_handler.h
@@ -94,7 +94,7 @@ class InputStreamHandler {
 
   // Returns a vector of pairs of stream name and queue size for monitoring
   // purpose.
-  std::vector<std::pair<std::string, int>> GetMonitoringInfo();
+  std::vector<std::tuple<std::string, int, int, Timestamp>> GetMonitoringInfo();
 
   // Resets the input stream handler and its underlying input streams for
   // another run of the graph.
diff --git a/mediapipe/framework/input_stream_manager.cc b/mediapipe/framework/input_stream_manager.cc
index 5b1917138..f47259877 100644
--- a/mediapipe/framework/input_stream_manager.cc
+++ b/mediapipe/framework/input_stream_manager.cc
@@ -329,6 +329,11 @@ Packet InputStreamManager::PopQueueHead(bool* stream_is_done) {
   return packet;
 }
 
+int InputStreamManager::NumPacketsAdded() const {
+  absl::MutexLock lock(&stream_mutex_);
+  return num_packets_added_;
+}
+
 int InputStreamManager::QueueSize() const {
   absl::MutexLock lock(&stream_mutex_);
   return static_cast<int>(queue_.size());
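GetMonitoringInfo() now surfaces, per input stream, the name, the current queue size, the NumPacketsAdded() total implemented just above, and the minimum timestamp or bound. A minimal consumption sketch, assuming an InputStreamHandler* named input_stream_handler and the tuple layout from this diff (C++17 structured bindings; mediapipe::Timestamp is streamable):

for (const auto& [name, queue_size, num_added, min_ts_or_bound] :
     input_stream_handler->GetMonitoringInfo()) {
  LOG(INFO) << name << ": queue_size=" << queue_size
            << " packets_added=" << num_added
            << " min_timestamp_or_bound=" << min_ts_or_bound;
}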
diff --git a/mediapipe/framework/input_stream_manager.h b/mediapipe/framework/input_stream_manager.h
index 042ef8d83..f269e8ed9 100644
--- a/mediapipe/framework/input_stream_manager.h
+++ b/mediapipe/framework/input_stream_manager.h
@@ -87,12 +87,14 @@ class InputStreamManager {
   // Timestamp::PostStream(), the packet must be the only packet in the
   // stream.
   // Violation of any of these conditions causes an error status.
-  absl::Status AddPackets(const std::list<Packet>& container, bool* notify);
+  absl::Status AddPackets(const std::list<Packet>& container, bool* notify)
+      ABSL_LOCKS_EXCLUDED(stream_mutex_);
 
   // Move a list of timestamped packets. Sets "notify" to true if the queue
   // becomes non-empty. Does nothing if the input stream is closed. After the
   // move, all packets in the container must be empty.
-  absl::Status MovePackets(std::list<Packet>* container, bool* notify);
+  absl::Status MovePackets(std::list<Packet>* container, bool* notify)
+      ABSL_LOCKS_EXCLUDED(stream_mutex_);
 
   // Closes the input stream. This function can be called multiple times.
   void Close() ABSL_LOCKS_EXCLUDED(stream_mutex_);
@@ -140,6 +142,9 @@ class InputStreamManager {
   // Timestamp::Done() after the pop.
   Packet PopQueueHead(bool* stream_is_done) ABSL_LOCKS_EXCLUDED(stream_mutex_);
 
+  // Returns the number of packets added to this stream so far.
+  int NumPacketsAdded() const ABSL_LOCKS_EXCLUDED(stream_mutex_);
+
   // Returns the number of packets in the queue.
   int QueueSize() const ABSL_LOCKS_EXCLUDED(stream_mutex_);
 
diff --git a/mediapipe/framework/input_stream_manager_test.cc b/mediapipe/framework/input_stream_manager_test.cc
index f1c1185f1..db2b3f2a3 100644
--- a/mediapipe/framework/input_stream_manager_test.cc
+++ b/mediapipe/framework/input_stream_manager_test.cc
@@ -767,6 +767,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
   EXPECT_EQ(3, num_packets_dropped_);
   EXPECT_TRUE(input_stream_manager_->IsEmpty());
   EXPECT_FALSE(stream_is_done_);
+  EXPECT_EQ(3, input_stream_manager_->NumPacketsAdded());
 
   packets.clear();
   packets.push_back(MakePacket<std::string>("packet 4").At(Timestamp(60)));
@@ -776,6 +777,7 @@
                input_stream_manager_->AddPackets(packets, &notify_));  // Notification
   EXPECT_FALSE(input_stream_manager_->IsEmpty());
   EXPECT_TRUE(notify_);
+  EXPECT_EQ(5, input_stream_manager_->NumPacketsAdded());
 
   expected_queue_becomes_full_count_ = 2;
   expected_queue_becomes_not_full_count_ = 1;
diff --git a/mediapipe/framework/mediapipe_cc_test.bzl b/mediapipe/framework/mediapipe_cc_test.bzl
index c953004d9..6ccbebb0c 100644
--- a/mediapipe/framework/mediapipe_cc_test.bzl
+++ b/mediapipe/framework/mediapipe_cc_test.bzl
@@ -12,6 +12,8 @@ def mediapipe_cc_test(
         timeout = None,
         args = [],
         additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS,
+        platforms = ["linux", "android", "ios", "wasm"],
+        exclude_platforms = None,
         # ios_unit_test arguments
         ios_minimum_os_version = "9.0",
         # android_cc_test arguments
diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD
index 9e33052ce..87944d80f 100644
--- a/mediapipe/framework/port/BUILD
+++ b/mediapipe/framework/port/BUILD
@@ -412,8 +412,7 @@ cc_library(
     name = "status_matchers",
     testonly = 1,
     hdrs = ["status_matchers.h"],
-    # Use this library through "mediapipe/framework/port:gtest_main".
-    visibility = ["//mediapipe/framework/port:__pkg__"],
+    visibility = ["//visibility:private"],
     deps = [
         ":status",
         "@com_google_googletest//:gtest",
diff --git a/mediapipe/framework/test_service.cc b/mediapipe/framework/test_service.cc
index 79bbc4340..4bafaf28c 100644
--- a/mediapipe/framework/test_service.cc
+++ b/mediapipe/framework/test_service.cc
@@ -16,8 +16,14 @@
 
 namespace mediapipe {
 
-const GraphService<TestServiceObject> kTestService("test_service");
-const GraphService<int> kAnotherService("another_service");
+const GraphService<TestServiceObject> kTestService(
+    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
+const GraphService<int> kAnotherService(
+    "another_service", GraphServiceBase::kAllowDefaultInitialization);
+const GraphService<NoDefaultConstructor> kNoDefaultService(
+    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
+const GraphService<NeedsCreateMethod> kNeedsCreateService(
+    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
 
 absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
   cc->Inputs().Index(0).Set<int>();
diff --git a/mediapipe/framework/test_service.h b/mediapipe/framework/test_service.h
index e726f7c15..2ff5a384a 100644
--- a/mediapipe/framework/test_service.h
+++ b/mediapipe/framework/test_service.h
@@ -16,6 +16,7 @@
 #define MEDIAPIPE_FRAMEWORK_TEST_SERVICE_H_
 
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/graph_service.h"
 
 namespace mediapipe {
 
@@ -24,6 +25,23 @@ using TestServiceObject = std::map<std::string, int>;
 extern const GraphService<TestServiceObject> kTestService;
 extern const GraphService<int> kAnotherService;
 
+class NoDefaultConstructor {
+ public:
+  NoDefaultConstructor() = delete;
+};
+extern const GraphService<NoDefaultConstructor> kNoDefaultService;
+
+class NeedsCreateMethod {
+ public:
+  static absl::StatusOr<std::shared_ptr<NeedsCreateMethod>> Create() {
+    return std::shared_ptr<NeedsCreateMethod>(new NeedsCreateMethod());
+  }
+
+ private:
+  NeedsCreateMethod() = default;
+};
+extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
+
 // Use a service.
 class TestServiceCalculator : public CalculatorBase {
  public:
diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD
index 1d8b6a88c..d44c8fe26 100644
--- a/mediapipe/framework/tool/BUILD
+++ b/mediapipe/framework/tool/BUILD
@@ -134,7 +134,7 @@ cc_library(
     name = "name_util",
     srcs = ["name_util.cc"],
     hdrs = ["name_util.h"],
-    visibility = ["//mediapipe/framework:mediapipe_internal"],
+    visibility = ["//visibility:public"],
     deps = [
         ":validate_name",
         "//mediapipe/framework:calculator_cc_proto",
diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc
index 77e2d6fbd..c77aed377 100644
--- a/mediapipe/framework/tool/test_util.cc
+++ b/mediapipe/framework/tool/test_util.cc
@@ -225,7 +225,7 @@ std::string GetTestOutputsDir() {
   return output_dir;
 }
 
-std::string GetTestDataDir(const std::string& package_base_path) {
+std::string GetTestDataDir(absl::string_view package_base_path) {
   return file::JoinPath(GetTestRootDir(), package_base_path, "testdata/");
 }
 
@@ -270,7 +270,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
       format, width, height, width * output_channels, data, stbi_image_free);
 }
 
-std::unique_ptr<ImageFrame> LoadTestPng(const std::string& path,
+std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
                                         ImageFormat::Format format) {
   return nullptr;
 }
diff --git a/mediapipe/framework/tool/test_util.h b/mediapipe/framework/tool/test_util.h
index bf6569bb0..ae3de3706 100644
--- a/mediapipe/framework/tool/test_util.h
+++ b/mediapipe/framework/tool/test_util.h
@@ -63,7 +63,7 @@ std::string GetTestFilePath(absl::string_view relative_path);
 // directory.
 // This handles the different paths where test data ends up when using
 // ion_cc_test on various platforms.
-std::string GetTestDataDir(const std::string& package_base_path);
+std::string GetTestDataDir(absl::string_view package_base_path);
 
 // Loads a binary graph from path. Returns true iff successful.
 bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path);
@@ -75,7 +75,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
 // Loads a PNG image from path using the given ImageFormat. Returns nullptr in
 // case of failure.
 std::unique_ptr<ImageFrame> LoadTestPng(
-    const std::string& path, ImageFormat::Format format = ImageFormat::SRGBA);
+    absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
 
 // Returns the luminance image of |original_image|.
 // The format of |original_image| must be sRGB or sRGBA.
diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD
index 3782e1eee..8c9c433b0 100644
--- a/mediapipe/gpu/BUILD
+++ b/mediapipe/gpu/BUILD
@@ -38,14 +38,19 @@ cc_library(
     srcs = ["gpu_service.cc"],
     hdrs = ["gpu_service.h"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:graph_service"],
+    deps = ["//mediapipe/framework:graph_service"] + select({
+        "//conditions:default": [
+            ":gpu_shared_data_internal",
+        ],
+        "//mediapipe/gpu:disable_gpu": [],
+    }),
 )
 
 cc_library(
     name = "graph_support",
     hdrs = ["graph_support.h"],
     visibility = ["//visibility:public"],
-    deps = [":gpu_service"],
+    deps = ["//mediapipe/framework:graph_service"],
 )
 
 GL_BASE_LINK_OPTS = select({
@@ -366,7 +371,6 @@ objc_library(
     hdrs = ["pixel_buffer_pool_util.h"],
     copts = [
         "-Wno-shorten-64-to-32",
-        "-std=c++17",
     ],
     sdk_frameworks = [
         "Accelerate",
@@ -389,7 +393,6 @@ objc_library(
     copts = [
         "-x objective-c++",
         "-Wno-shorten-64-to-32",
-        "-std=c++17",
     ],
     features = ["-layering_check"],
     sdk_frameworks = [
@@ -425,7 +428,6 @@ objc_library(
     copts = [
         "-x objective-c++",
        "-Wno-shorten-64-to-32",
-        "-std=c++17",
     ],
     sdk_frameworks = [
         "CoreVideo",
@@ -691,7 +693,6 @@ objc_library(
     name = "gl_calculator_helper_ios",
     copts = [
         "-Wno-shorten-64-to-32",
-        "-std=c++17",
     ],
     visibility = ["//visibility:public"],
     deps = [
@@ -707,7 +708,6 @@ objc_library(
     hdrs = ["MPPMetalHelper.h"],
     copts = [
         "-Wno-shorten-64-to-32",
-        "-std=c++17",
     ],
     features = ["-layering_check"],
     sdk_frameworks = [
@@ -801,7 +801,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":gl_calculator_helper",
-        ":gpu_buffer_storage_image_frame",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:status",
@@ -927,7 +926,6 @@ mediapipe_cc_proto_library(
 objc_library(
     name = "metal_copy_calculator",
     srcs = ["MetalCopyCalculator.mm"],
-    copts = ["-std=c++17"],
     features = ["-layering_check"],
     sdk_frameworks = [
         "CoreVideo",
@@ -946,7 +944,6 @@ objc_library(
 objc_library(
     name = "metal_rgb_weight_calculator",
     srcs = ["MetalRgbWeightCalculator.mm"],
-    copts = ["-std=c++17"],
     features = ["-layering_check"],
     sdk_frameworks = [
         "CoreVideo",
@@ -964,7 +961,6 @@ objc_library(
 objc_library(
     name = "metal_sobel_calculator",
     srcs = ["MetalSobelCalculator.mm"],
-    copts = ["-std=c++17"],
     features = ["-layering_check"],
     sdk_frameworks = [
         "CoreVideo",
@@ -982,7 +978,6 @@ objc_library(
 objc_library(
     name = "metal_sobel_compute_calculator",
     srcs = ["MetalSobelComputeCalculator.mm"],
-    copts = ["-std=c++17"],
     features = ["-layering_check"],
     sdk_frameworks = [
         "CoreVideo",
@@ -1018,7 +1013,6 @@ objc_library(
 objc_library(
     name = "mps_threshold_calculator",
["MPSThresholdCalculator.mm"], - copts = ["-std=c++17"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -1053,7 +1047,6 @@ objc_library( ], copts = [ "-Wno-shorten-64-to-32", - "-std=c++17", ], data = [ "//mediapipe/objc:testdata/googlelogo_color_272x92dp.png", diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 25e969d2f..179c35150 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -23,6 +23,7 @@ #include "absl/base/dynamic_annotations.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" @@ -358,6 +359,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { GlContext::GlContext() {} GlContext::~GlContext() { + destructing_ = true; // Note: on Apple platforms, this object contains Objective-C objects. // The destructor will release them, but ARC must be on. #ifdef __OBJC__ @@ -366,11 +368,16 @@ GlContext::~GlContext() { #endif #endif // __OBJC__ + auto clear_attachments = [this] { + attachments_.clear(); + if (profiling_helper_) { + profiling_helper_->LogAllTimestamps(); + } + }; + if (thread_) { - auto status = thread_->Run([this] { - if (profiling_helper_) { - profiling_helper_->LogAllTimestamps(); - } + auto status = thread_->Run([this, clear_attachments] { + clear_attachments(); return ExitContext(nullptr); }); LOG_IF(ERROR, !status.ok()) @@ -378,6 +385,17 @@ GlContext::~GlContext() { if (thread_->IsCurrentThread()) { thread_.release()->SelfDestruct(); } + } else { + if (IsCurrent()) { + clear_attachments(); + } else { + ContextBinding saved_context; + auto status = SwitchContextAndRun([&clear_attachments] { + clear_attachments(); + return absl::OkStatus(); + }); + LOG_IF(ERROR, !status.ok()) << status; + } } DestroyContext(); } @@ -501,6 +519,14 @@ absl::Status GlContext::SwitchContext(ContextBinding* saved_context, } } +GlContext::ContextBinding GlContext::ThisContextBinding() { + GlContext::ContextBinding result = ThisContextBindingPlatform(); + if (!destructing_) { + result.context_object = shared_from_this(); + } + return result; +} + absl::Status GlContext::EnterContext(ContextBinding* saved_context) { DCHECK(HasContext()); return SwitchContext(saved_context, ThisContextBinding()); diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 9e798f98a..7f1fbbbdc 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -21,6 +21,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/executor.h" #include "mediapipe/framework/mediapipe_profiling.h" @@ -285,6 +286,48 @@ class GlContext : public std::enable_shared_from_this { // Sets default texture filtering parameters. void SetStandardTextureParams(GLenum target, GLint internal_format); + template + using AttachmentPtr = std::unique_ptr>; + + template + static std::enable_if_t::value, AttachmentPtr> + MakeAttachmentPtr(Args&&... 
args) { + return {new T(std::forward<Args>(args)...), + [](void* ptr) { delete static_cast<T*>(ptr); }}; + } + + class AttachmentBase {}; + + template <class T> + class Attachment : public AttachmentBase { + public: + using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>; + Attachment(FactoryT factory) : factory_(factory) {} + + Attachment(const Attachment&) = delete; + Attachment(Attachment&&) = delete; + Attachment& operator=(const Attachment&) = delete; + Attachment& operator=(Attachment&&) = delete; + + T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); } + + const FactoryT& factory() const { return factory_; } + + private: + FactoryT factory_; + }; + + // TODO: const result? + template <class T> + T& GetCachedAttachment(const Attachment<T>& attachment) { + DCHECK(IsCurrent()); + AttachmentPtr<void>& entry = attachments_[&attachment]; + if (entry == nullptr) { + entry = attachment.factory()(*this); + } + return *static_cast<T*>(entry.get()); + } + // These are used for testing specific SyncToken implementations. Do not use // outside of tests. enum class SyncTokenTypeForTest { @@ -387,6 +430,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> { // A binding that can be used to make this GlContext current. ContextBinding ThisContextBinding(); + // Fill in platform-specific fields. Must _not_ set context_obj. + ContextBinding ThisContextBindingPlatform(); // Fills in a ContextBinding with platform-specific information about which // context is current on this thread. static void GetCurrentContextBinding(ContextBinding* binding); @@ -409,6 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> { // better mechanism? bool can_linear_filter_float_textures_; + absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_; + // Number of glFinish calls completed on the GL thread. // Changes should be guarded by mutex_. However, we use simple atomic // loads for efficiency on the fast path. @@ -428,6 +475,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> { absl::CondVar wait_for_gl_finish_cv_ ABSL_GUARDED_BY(mutex_); std::unique_ptr<mediapipe::GlProfilingHelper> profiling_helper_ = nullptr; + + bool destructing_ = false; }; // For backward compatibility. TODO: migrate remaining callers.
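Note on the new attachment API above: each GlContext keeps a lazily created, per-context cache of objects, where the factory runs at most once per context and ~GlContext destroys all attachments with the context current (via clear_attachments()). A minimal usage sketch follows; the kCachedTexture name and factory body are illustrative assumptions, not part of this patch:

static const mediapipe::GlContext::Attachment<GLuint> kCachedTexture(
    [](mediapipe::GlContext& ctx) -> mediapipe::GlContext::AttachmentPtr<GLuint> {
      // Runs at most once per context, with the context current.
      auto tex = mediapipe::GlContext::MakeAttachmentPtr<GLuint>(0);
      glGenTextures(1, tex.get());
      return tex;
    });

void BindCachedTexture(mediapipe::GlContext& ctx) {
  // Get() must be called while `ctx` is current; it returns the cached
  // per-context object, invoking the factory only on first use.
  GLuint tex = kCachedTexture.Get(ctx);
  glBindTexture(GL_TEXTURE_2D, tex);
}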
diff --git a/mediapipe/gpu/gl_context_eagl.cc b/mediapipe/gpu/gl_context_eagl.cc index 2811ea0b4..865813c21 100644 --- a/mediapipe/gpu/gl_context_eagl.cc +++ b/mediapipe/gpu/gl_context_eagl.cc @@ -84,9 +84,8 @@ void GlContext::DestroyContext() { } } -GlContext::ContextBinding GlContext::ThisContextBinding() { +GlContext::ContextBinding GlContext::ThisContextBindingPlatform() { GlContext::ContextBinding result; - result.context_object = shared_from_this(); result.context = context_; return result; } diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 44ddd9314..75eeeb936 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -269,9 +269,8 @@ void GlContext::DestroyContext() { #endif // __ANDROID__ } -GlContext::ContextBinding GlContext::ThisContextBinding() { +GlContext::ContextBinding GlContext::ThisContextBindingPlatform() { GlContext::ContextBinding result; - result.context_object = shared_from_this(); result.display = display_; result.draw_surface = surface_; result.read_surface = surface_; diff --git a/mediapipe/gpu/gl_context_nsgl.cc b/mediapipe/gpu/gl_context_nsgl.cc index d9a261e5b..dda74f0ce 100644 --- a/mediapipe/gpu/gl_context_nsgl.cc +++ b/mediapipe/gpu/gl_context_nsgl.cc @@ -134,9 +134,8 @@ void GlContext::DestroyContext() { } } -GlContext::ContextBinding GlContext::ThisContextBinding() { +GlContext::ContextBinding GlContext::ThisContextBindingPlatform() { GlContext::ContextBinding result; - result.context_object = shared_from_this(); result.context = context_; return result; } diff --git a/mediapipe/gpu/gl_context_webgl.cc b/mediapipe/gpu/gl_context_webgl.cc index 01fc12d6d..b1f5295c9 100644 --- a/mediapipe/gpu/gl_context_webgl.cc +++ b/mediapipe/gpu/gl_context_webgl.cc @@ -173,9 +173,8 @@ void GlContext::DestroyContext() { } } -GlContext::ContextBinding GlContext::ThisContextBinding() { +GlContext::ContextBinding GlContext::ThisContextBindingPlatform() { GlContext::ContextBinding result; - result.context_object = shared_from_this(); result.context = context_; return result; } diff --git a/mediapipe/gpu/gl_quad_renderer.cc b/mediapipe/gpu/gl_quad_renderer.cc index c25a37e48..309b83238 100644 --- a/mediapipe/gpu/gl_quad_renderer.cc +++ b/mediapipe/gpu/gl_quad_renderer.cc @@ -111,7 +111,7 @@ absl::Status QuadRenderer::GlRender(float frame_width, float frame_height, FrameScaleMode scale_mode, FrameRotation rotation, bool flip_horizontal, bool flip_vertical, - bool flip_texture) { + bool flip_texture) const { RET_CHECK(program_) << "Must setup the program before rendering."; glUseProgram(program_); diff --git a/mediapipe/gpu/gl_quad_renderer.h b/mediapipe/gpu/gl_quad_renderer.h index 4ef6c7669..3d0e9d642 100644 --- a/mediapipe/gpu/gl_quad_renderer.h +++ b/mediapipe/gpu/gl_quad_renderer.h @@ -72,7 +72,7 @@ class QuadRenderer { absl::Status GlRender(float frame_width, float frame_height, float view_width, float view_height, FrameScaleMode scale_mode, FrameRotation rotation, bool flip_horizontal, - bool flip_vertical, bool flip_texture); + bool flip_vertical, bool flip_texture) const; // Deletes the rendering program. Must be called within the GL context where // it was created.
void GlTeardown(); diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 41c98ba43..1dcd58e63 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -144,7 +144,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, }}, {GpuBufferFormat::kRGBAFloat128, { - {GL_RGBA, GL_RGBA, GL_FLOAT, 1}, + {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1}, }}, }}; diff --git a/mediapipe/gpu/gpu_service.cc b/mediapipe/gpu/gpu_service.cc index f280a58d0..53a0e0f47 100644 --- a/mediapipe/gpu/gpu_service.cc +++ b/mediapipe/gpu/gpu_service.cc @@ -16,6 +16,7 @@ namespace mediapipe { -const GraphService<GpuResources> kGpuService("kGpuService"); +const GraphService<GpuResources> kGpuService( + "kGpuService", GraphServiceBase::kAllowDefaultInitialization); } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_service.h b/mediapipe/gpu/gpu_service.h index a610a275f..aa5990036 100644 --- a/mediapipe/gpu/gpu_service.h +++ b/mediapipe/gpu/gpu_service.h @@ -17,9 +17,18 @@ #include "mediapipe/framework/graph_service.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_shared_data_internal.h" +#endif // !MEDIAPIPE_DISABLE_GPU + namespace mediapipe { -class GpuResources; +#if MEDIAPIPE_DISABLE_GPU +class GpuResources { + GpuResources() = delete; +}; +#endif // MEDIAPIPE_DISABLE_GPU + extern const GraphService<GpuResources> kGpuService; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 7aa622d24..a8bf0c3a3 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -105,7 +105,7 @@ GpuResources::~GpuResources() { } absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { - CHECK(node->UsesGpu()); + CHECK(ContainsKey(node->Contract().ServiceRequests(), kGpuService.key)); std::string node_id = node->GetCalculatorState().NodeName(); std::string node_type = node->GetCalculatorState().CalculatorType(); std::string context_key; diff --git a/mediapipe/gpu/graph_support.h b/mediapipe/gpu/graph_support.h index e3e531e74..6541771c8 100644 --- a/mediapipe/gpu/graph_support.h +++ b/mediapipe/gpu/graph_support.h @@ -16,10 +16,14 @@ #ifndef MEDIAPIPE_GPU_GRAPH_SUPPORT_H_ #define MEDIAPIPE_GPU_GRAPH_SUPPORT_H_ -#include "mediapipe/gpu/gpu_service.h" +#include "mediapipe/framework/graph_service.h" namespace mediapipe { +// Forward declaration to avoid depending on GpuResources here.
+class GpuResources; +extern const GraphService<GpuResources> kGpuService; + static constexpr char kGpuSharedTagName[] = "GPU_SHARED"; static constexpr char kGpuSharedSidePacketName[] = "gpu_shared"; static constexpr char kGpuExecutorName[] = "__gpu"; diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 11f3ae58c..2a8331db8 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -16,7 +16,10 @@ #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" + +#ifdef __APPLE__ +#include "mediapipe/objc/util.h" +#endif namespace mediapipe { @@ -31,7 +34,9 @@ class ImageFrameToGpuBufferCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override; private: +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlCalculatorHelper helper_; +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); @@ -51,25 +56,28 @@ absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(TimestampDiff(0)); +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { - auto image_frame = std::const_pointer_cast<ImageFrame>( - mediapipe::SharedPtrWithPacket<ImageFrame>( - cc->Inputs().Index(0).Value())); - auto gpu_buffer = MakePacket<GpuBuffer>( - std::make_shared<mediapipe::GpuBufferStorageImageFrame>( - std::move(image_frame))) - .At(cc->InputTimestamp()); - // Request GPU access to ensure the data is available to the GPU. - // TODO: have a better way to do this, or defer until later. - helper_.RunInGlContext([&gpu_buffer] { - auto view = gpu_buffer.Get<GpuBuffer>().GetReadView<mediapipe::GlTextureView>(0); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + CFHolder<CVPixelBufferRef> buffer; + MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( + cc->Inputs().Index(0).Value(), &buffer)); + cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp()); +#else + const auto& input = cc->Inputs().Index(0).Get<ImageFrame>(); + helper_.RunInGlContext([this, &input, &cc]() { + auto src = helper_.CreateSourceTexture(input); + auto output = src.GetFrame<GpuBuffer>(); + glFlush(); + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + src.Release(); }); - cc->Outputs().Index(0).AddPacket(std::move(gpu_buffer)); - +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } diff --git a/mediapipe/java/com/google/mediapipe/components/BUILD b/mediapipe/java/com/google/mediapipe/components/BUILD index 4471a0c56..a1ec17548 100644 --- a/mediapipe/java/com/google/mediapipe/components/BUILD +++ b/mediapipe/java/com/google/mediapipe/components/BUILD @@ -62,17 +62,27 @@ android_library( ], ) +# Interfaces for common Audio Consumer and Producers in MediaPipe.
android_library( - name = "android_microphone_helper", + name = "android_audio_components", srcs = [ "AudioDataConsumer.java", "AudioDataProcessor.java", "AudioDataProducer.java", + ], + visibility = ["//visibility:public"], + deps = ["@maven//:com_google_guava_guava"], +) + +# MicrophoneHelper that provides access to audio data from a microphone +android_library( + name = "android_microphone_helper", + srcs = [ "MicrophoneHelper.java", ], visibility = ["//visibility:public"], deps = [ - "@maven//:com_google_code_findbugs_jsr305", + ":android_audio_components", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java b/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java index cee0c2770..70234170d 100644 --- a/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java +++ b/mediapipe/java/com/google/mediapipe/components/CameraXPreviewHelper.java @@ -172,6 +172,25 @@ public class CameraXPreviewHelper extends CameraHelper { startCamera(activity, (LifecycleOwner) activity, cameraFacing, targetSize); } + /** + * Initializes the camera and sets it up for accessing frames. This constructor also enables the + * image capture use case from {@link CameraX}. + * + * @param imageCaptureBuilder Builder for an {@link ImageCapture}, this builder must contain the + * desired configuration options for the image capture being built (e.g. target resolution). + * @param targetSize the preview size to use. If set to {@code null}, the helper will default to + * 1280 * 720. + */ + public void startCamera( + Activity activity, + @Nonnull ImageCapture.Builder imageCaptureBuilder, + CameraFacing cameraFacing, + @Nullable SurfaceTexture surfaceTexture, + @Nullable Size targetSize) { + this.imageCaptureBuilder = imageCaptureBuilder; + startCamera(activity, (LifecycleOwner) activity, cameraFacing, surfaceTexture, targetSize); + } + /** * Initializes the camera and sets it up for accessing frames. * diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java index f4aa330dd..42fcdb4d8 100644 --- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java @@ -115,12 +115,33 @@ public class ExternalTextureConverter implements TextureFrameProducer { * frames when all frames in the pool are in-use, but they are only added back to the pool upon * release if the size allows so. * + * Please note, while this property allows the buffer pool to grow temporarily if needed, there is + * a separate bufferPoolMaxSize property that strictly enforces that the buffer pool does not grow + * beyond that size; incoming frames are dropped instead. + * * @param bufferPoolSize the number of camera frames that can enter processing simultaneously. */ public void setBufferPoolSize(int bufferPoolSize) { thread.setBufferPoolSize(bufferPoolSize); } + /** + * Sets the buffer pool max size. Setting to <= 0 effectively clears this property. + * + * If set (i.e. > 0), the value should be >= bufferPoolSize. While the API allows for setting a + lower value without throwing an exception, internally the higher of the two values is used for + enforcing the buffer pool max size. + * + * When set, no TextureFrames are created beyond the specified size. New incoming + frames will be dropped. + * + * When un-set (i.e.
<= 0), new TextureFrames are temporarily allocated even if bufferPoolSize is + reached. However, they are not added back to the buffer pool upon release. + */ + public void setBufferPoolMaxSize(int bufferPoolMaxSize) { + thread.setBufferPoolMaxSize(bufferPoolMaxSize); + } + /** * Sets vertical flipping of the texture, useful for conversion between coordinate systems with * top-left v.s. bottom-left origins. This should be called before {@link @@ -260,7 +281,8 @@ public class ExternalTextureConverter implements TextureFrameProducer { private final Queue<PoolTextureFrame> framesAvailable = new ArrayDeque<>(); private int framesInUse = 0; - private int framesToKeep; + private int bufferPoolSize; + private int bufferPoolMaxSize; private ExternalTextureRenderer renderer = null; private long nextFrameTimestampOffset = 0; @@ -291,13 +313,17 @@ public RenderThread(EGLContext parentContext, int numBuffers) { super(parentContext); - framesToKeep = numBuffers; + bufferPoolSize = numBuffers; renderer = new ExternalTextureRenderer(); consumers = new ArrayList<>(); } public void setBufferPoolSize(int bufferPoolSize) { - this.framesToKeep = bufferPoolSize; + this.bufferPoolSize = bufferPoolSize; + } + + public void setBufferPoolMaxSize(int bufferPoolMaxSize) { + this.bufferPoolMaxSize = bufferPoolMaxSize; } public void setFlipY(boolean flip) { @@ -406,11 +432,13 @@ boolean frameUpdated = false; for (TextureFrameConsumer consumer : consumers) { AppTextureFrame outputFrame = nextOutputFrame(); + if (outputFrame == null) { + break; + } // TODO: Switch to ref-counted single copy instead of making additional // copies blitting to separate textures each time. updateOutputFrame(outputFrame); frameUpdated = true; - if (consumer != null) { if (Log.isLoggable(TAG, Log.VERBOSE)) { Log.v( @@ -425,14 +453,10 @@ consumer.onNewFrame(outputFrame); } } - if (!frameUpdated) { // Need to update the frame even if there are no consumers. - AppTextureFrame outputFrame = nextOutputFrame(); - // TODO: Switch to ref-counted single copy instead of making additional - // copies blitting to separate textures each time. - updateOutputFrame(outputFrame); - // Release immediately as this is not sent to a consumer so no release() would be - // called otherwise. - outputFrame.release(); + if (!frameUpdated) { + // Progress the SurfaceTexture BufferQueue even if we didn't update the outputFrame, + // which could be either because there are no consumers or bufferPoolMaxSize is reached. + surfaceTexture.updateTexImage(); + } } } finally { @@ -469,6 +493,13 @@ PoolTextureFrame outputFrame; synchronized (this) { outputFrame = framesAvailable.poll(); + // Don't create new frame if bufferPoolMaxSize is set (i.e. > 0) and reached. + if (outputFrame == null && bufferPoolMaxSize > 0 + && framesInUse >= max(bufferPoolMaxSize, bufferPoolSize)) { + Log.d(TAG, "Enforcing buffer pool max size.
FramesInUse: " + + framesInUse + " >= " + bufferPoolMaxSize); + return null; + } framesInUse++; } if (outputFrame == null) { @@ -492,7 +523,7 @@ protected synchronized void poolFrameReleased(PoolTextureFrame frame) { framesAvailable.offer(frame); framesInUse--; - int keep = max(framesToKeep - framesInUse, 0); + int keep = max(bufferPoolSize - framesInUse, 0); while (framesAvailable.size() > keep) { PoolTextureFrame textureFrameToRemove = framesAvailable.remove(); handler.post(() -> teardownFrame(textureFrameToRemove)); diff --git a/mediapipe/java/com/google/mediapipe/components/MicrophoneHelper.java b/mediapipe/java/com/google/mediapipe/components/MicrophoneHelper.java index 4775bd7ee..fc99115c4 100644 --- a/mediapipe/java/com/google/mediapipe/components/MicrophoneHelper.java +++ b/mediapipe/java/com/google/mediapipe/components/MicrophoneHelper.java @@ -14,6 +14,8 @@ package com.google.mediapipe.components; +import static java.lang.Math.max; + import android.media.AudioFormat; import android.media.AudioRecord; import android.media.AudioTimestamp; @@ -24,6 +26,7 @@ import android.util.Log; import com.google.common.base.Preconditions; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.CopyOnWriteArraySet; /** Provides access to audio data from a microphone. */ public class MicrophoneHelper implements AudioDataProducer { @@ -88,10 +91,10 @@ // sent to the listener of this class. private boolean recording = false; - // The consumer is provided with the data read on every AudioRecord.read() call. If the consumer - called stopRecording() while a call to AudioRecord.read() was blocked, the class will discard + // The consumers are provided with the data read on every AudioRecord.read() call. If the consumer + called stopMicrophone() while a call to AudioRecord.read() was blocked, the class will discard // the data read after recording stopped. - private AudioDataConsumer consumer; + private final CopyOnWriteArraySet<AudioDataConsumer> consumers = new CopyOnWriteArraySet<>(); // TODO: Add a constructor that takes an AudioFormat. @@ -105,7 +108,6 @@ public MicrophoneHelper(int sampleRateInHz, int channelConfig) { this.sampleRateInHz = sampleRateInHz; this.channelConfig = channelConfig; - // Number of channels of audio source, depending on channelConfig. final int numChannels = channelConfig == AudioFormat.CHANNEL_IN_STEREO ? 2 : 1; @@ -140,7 +142,7 @@ (int) Math.ceil(1.0 * bytesPerFrame * sampleRateInHz * micros / MICROS_PER_SECOND); // The size of the internal buffer should be greater than the size of the audio packet read // and sent to the AudioDataConsumer so that AudioRecord does not run out of buffer space while the data is consumed. - audioRecordBufferSize = Math.max(audioPacketBufferSize, minBufferSize) * BUFFER_SIZE_MULTIPLIER; + audioRecordBufferSize = max(audioPacketBufferSize, minBufferSize) * BUFFER_SIZE_MULTIPLIER; } private void setupAudioRecord() { @@ -210,8 +212,10 @@ // Confirm that the consumer is still interested in receiving audio data and // stopMicrophone() wasn't called. If the consumer called stopMicrophone(), discard // the data read in the latest AudioRecord.read(...) function call.
- if (recording && consumer != null) { - consumer.onNewAudioData(audioData, timestampMicros, audioFormat); + if (recording) { + for (AudioDataConsumer consumer : consumers) { + consumer.onNewAudioData(audioData, timestampMicros, audioFormat); + } } } }, @@ -389,8 +393,24 @@ audioRecord.release(); } + /** + * Clears all existing consumers and sets the given consumer as the sole consumer. + */ @Override public void setAudioConsumer(AudioDataConsumer consumer) { - this.consumer = consumer; + consumers.clear(); + consumers.add(consumer); + } + + public void addAudioConsumer(AudioDataConsumer consumer) { + consumers.add(consumer); + } + + public void removeAudioConsumer(AudioDataConsumer consumer) { + consumers.remove(consumer); + } + + public void removeAllAudioConsumers() { + consumers.clear(); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 1e4d74fab..6a67c01cb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -449,11 +449,13 @@ absl::Status Graph::StartRunningGraph(JNIEnv* env) { } absl::Status status; #if !MEDIAPIPE_DISABLE_GPU - status = running_graph_->SetGpuResources(gpu_resources_); - if (!status.ok()) { - LOG(ERROR) << status.message(); - running_graph_.reset(nullptr); - return status; + if (gpu_resources_) { + status = running_graph_->SetGpuResources(gpu_resources_); + if (!status.ok()) { + LOG(ERROR) << status.message(); + running_graph_.reset(nullptr); + return status; + } } #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index 8ea37d9c5..6797b4b20 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -273,6 +273,9 @@ void RegisterPacketGetterNatives(JNIEnv *env) { AddJNINativeMethod(&packet_getter_methods, packet_getter, "nativeGetFloat32Vector", "(J)[F", (void *)&PACKET_GETTER_METHOD(nativeGetFloat32Vector)); + AddJNINativeMethod(&packet_getter_methods, packet_getter, + "nativeGetProtoVector", "(J)[[B", + (void *)&PACKET_GETTER_METHOD(nativeGetProtoVector)); RegisterNativesVector(env, packet_getter_class, packet_getter_methods); env->DeleteLocalRef(packet_getter_class); } diff --git a/mediapipe/modules/face_geometry/README.md b/mediapipe/modules/face_geometry/README.md index 8427ea63c..649d0a853 100644 --- a/mediapipe/modules/face_geometry/README.md +++ b/mediapipe/modules/face_geometry/README.md @@ -1,20 +1,20 @@ -# face_geometry +# face_transform Protos|Details :--- | :--- [`face_geometry.Environment`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/environment.proto)| Describes an environment; includes the camera frame origin point location as well as virtual camera parameters. -[`face_geometry.GeometryPipelineMetadata`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata.proto)| Describes metadata needed to estimate face geometry based on the face landmark module result.
-[`face_geometry.FaceGeometry`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/face_geometry.proto)| Describes geometry data for a single face; includes a face mesh surface and a face pose in a given environment. -[`face_geometry.Mesh3d`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/mesh_3d.proto)| Describes a 3D mesh surface. +[`face_geometry.GeometryPipelineMetadata`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata.proto)| Describes metadata needed to estimate face 3D transform based on the face landmark module result. +[`face_geometry.FaceGeometry`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/face_geometry.proto)| Describes 3D transform data for a single face; includes a face mesh surface and a face pose in a given environment. +[`face_geometry.Mesh3d`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/mesh_3d.proto)| Describes a triangular 3D mesh surface. Calculators|Details :--- | :--- [`FaceGeometryEnvGeneratorCalculator`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/env_generator_calculator.cc)| Generates an environment that describes a virtual scene. -[`FaceGeometryPipelineCalculator`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc)| Extracts face geometry for multiple faces from a vector of landmark lists. +[`FaceGeometryPipelineCalculator`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc)| Extracts face 3D transform for multiple faces from a vector of landmark lists. [`FaceGeometryEffectRendererCalculator`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/effect_renderer_calculator.cc)| Renders a face effect. Subgraphs|Details :--- | :--- -[`FaceGeometryFromDetection`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt)| Extracts geometry from face detection for multiple faces. -[`FaceGeometryFromLandmarks`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt)| Extracts geometry from face landmarks for multiple faces. -[`FaceGeometry`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry.pbtxt)| Extracts geometry from face landmarks for multiple faces. Deprecated, please use `FaceGeometryFromLandmarks` in the new code. +[`FaceGeometryFromDetection`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt)| Extracts 3D transform from face detection for multiple faces. +[`FaceGeometryFromLandmarks`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt)| Extracts 3D transform from face landmarks for multiple faces. +[`FaceGeometry`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry.pbtxt)| Extracts 3D transform from face landmarks for multiple faces. Deprecated; please use `FaceGeometryFromLandmarks` in new code.
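Each entry on the MULTI_FACE_GEOMETRY streams produced by these subgraphs is a face_geometry::FaceGeometry proto carrying a face mesh plus a pose transform matrix. A consumer sketch follows, assuming (as in the proto definitions) that the pose matrix is a 4x4 column-major MatrixData; the function name is illustrative:

#include "mediapipe/framework/port/logging.h"
#include "mediapipe/modules/face_geometry/protos/face_geometry.pb.h"

// Reads the translation column out of the 4x4 pose transform matrix.
void LogFacePoseTranslation(
    const mediapipe::face_geometry::FaceGeometry& geometry) {
  const auto& m = geometry.pose_transform_matrix();
  if (m.packed_data_size() == 16) {  // 4x4 matrix, column-major packing
    LOG(INFO) << "face translation: " << m.packed_data(12) << ", "
              << m.packed_data(13) << ", " << m.packed_data(14);
  }
}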
diff --git a/mediapipe/modules/face_geometry/face_geometry.pbtxt b/mediapipe/modules/face_geometry/face_geometry.pbtxt index 76228d4b1..33e31d27e 100644 --- a/mediapipe/modules/face_geometry/face_geometry.pbtxt +++ b/mediapipe/modules/face_geometry/face_geometry.pbtxt @@ -1,4 +1,4 @@ -# MediaPipe graph to extract geometry from face landmarks for multiple faces. +# MediaPipe graph to extract 3D transform from face landmarks for multiple faces. # # It is required that "geometry_pipeline_metadata.binarypb" is available at # "mediapipe/modules/face_geometry/data/geometry_pipeline_metadata.binarypb" @@ -28,11 +28,11 @@ input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" # (face_geometry::Environment) input_side_packet: "ENVIRONMENT:environment" -# A list of geometry data for each detected face. +# A list of 3D transform data for each detected face. # (std::vector<face_geometry::FaceGeometry>) output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" -# Extracts face geometry for multiple faces from a vector of face landmark +# Extracts face 3D transform for multiple faces from a vector of face landmark # lists. node { calculator: "FaceGeometryPipelineCalculator" diff --git a/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt b/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt index f570286aa..b9ee5a49e 100644 --- a/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt +++ b/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt @@ -1,4 +1,4 @@ -# MediaPipe graph to extract geometry from face detection for multiple faces. +# MediaPipe graph to extract 3D transform from face detection for multiple faces. # # It is required that "geometry_pipeline_metadata_detection.binarypb" is # available at @@ -34,7 +34,7 @@ input_stream: "MULTI_FACE_DETECTION:multi_face_detection" # (face_geometry::Environment) input_side_packet: "ENVIRONMENT:environment" -# A list of geometry data for each detected face. +# A list of 3D transform data for each detected face. # (std::vector<face_geometry::FaceGeometry>) # # NOTE: the triangular topology of the face meshes is only useful when derived # defined here only to comply with the API. It should be considered as # a placeholder and/or for debugging purposes. # -# Use the face geometry derived from the face detection landmarks +# Use the face 3D transform derived from the face detection landmarks # (keypoints) for the face pose transformation matrix, not the mesh. output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" @@ -71,7 +71,7 @@ node { output_stream: "ITERABLE:multi_face_landmarks" } -# Extracts face geometry for multiple faces from a vector of face detection +# Extracts face 3D transform for multiple faces from a vector of face detection # landmark lists. node { calculator: "FaceGeometryPipelineCalculator" diff --git a/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt b/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt index 329147663..ffad46365 100644 --- a/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt +++ b/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt @@ -1,4 +1,4 @@ -# MediaPipe graph to extract geometry from face landmarks for multiple faces. +# MediaPipe graph to extract 3D transform from face landmarks for multiple faces.
# # It is required that "geometry_pipeline_metadata_from_landmark.binarypb" is # available at @@ -34,11 +34,11 @@ input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" # (face_geometry::Environment) input_side_packet: "ENVIRONMENT:environment" -# A list of geometry data for each detected face. +# A list of 3D transform data for each detected face. # (std::vector<face_geometry::FaceGeometry>) output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" -# Extracts face geometry for multiple faces from a vector of face landmark +# Extracts face 3D transform for multiple faces from a vector of face landmark # lists. node { calculator: "FaceGeometryPipelineCalculator" diff --git a/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt index 4604fc753..ad38678f5 100644 --- a/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt @@ -71,7 +71,7 @@ node { # supports custom ops needed by the model used in this graph. node { calculator: "TfLiteCustomOpResolverCalculator" - output_side_packet: "op_resolver" + output_side_packet: "OP_RESOLVER:op_resolver" } # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a @@ -81,7 +81,7 @@ node { calculator: "InferenceCalculator" input_stream: "TENSORS:input_tensors" input_side_packet: "MODEL:model" - input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + input_side_packet: "OP_RESOLVER:op_resolver" output_stream: "TENSORS:output_tensors" options: { [mediapipe.InferenceCalculatorOptions.ext] { diff --git a/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt index 854ceaff6..49e597e0c 100644 --- a/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt @@ -72,7 +72,7 @@ node { # supports custom ops needed by the model used in this graph. node { calculator: "TfLiteCustomOpResolverCalculator" - output_side_packet: "op_resolver" + output_side_packet: "OP_RESOLVER:op_resolver" } # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a @@ -82,7 +82,7 @@ node { calculator: "InferenceCalculator" input_stream: "TENSORS:input_tensors" input_side_packet: "MODEL:model" - input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + input_side_packet: "OP_RESOLVER:op_resolver" output_stream: "TENSORS:output_tensors" options: { [mediapipe.InferenceCalculatorOptions.ext] { diff --git a/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt b/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt index 32b3927d3..2b1a34ef0 100644 --- a/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt +++ b/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt @@ -41,7 +41,7 @@ node { # supports custom ops needed by the model used in this graph. node { calculator: "TfLiteCustomOpResolverCalculator" - output_side_packet: "opresolver" + output_side_packet: "OP_RESOLVER:opresolver" } # Loads the palm detection TF Lite model.
@@ -58,7 +58,7 @@ node { calculator: "InferenceCalculator" input_stream: "TENSORS:input_tensor" output_stream: "TENSORS:detection_tensors" - input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" + input_side_packet: "OP_RESOLVER:opresolver" input_side_packet: "MODEL:model" options: { [mediapipe.InferenceCalculatorOptions.ext] { diff --git a/mediapipe/modules/palm_detection/palm_detection_gpu.pbtxt b/mediapipe/modules/palm_detection/palm_detection_gpu.pbtxt index 73e4127f1..c8498b544 100644 --- a/mediapipe/modules/palm_detection/palm_detection_gpu.pbtxt +++ b/mediapipe/modules/palm_detection/palm_detection_gpu.pbtxt @@ -42,7 +42,7 @@ node { # supports custom ops needed by the model used in this graph. node { calculator: "TfLiteCustomOpResolverCalculator" - output_side_packet: "opresolver" + output_side_packet: "OP_RESOLVER:opresolver" options: { [mediapipe.TfLiteCustomOpResolverCalculatorOptions.ext] { use_gpu: true @@ -64,7 +64,7 @@ node { calculator: "InferenceCalculator" input_stream: "TENSORS:input_tensor" output_stream: "TENSORS:detection_tensors" - input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" + input_side_packet: "OP_RESOLVER:opresolver" input_side_packet: "MODEL:model" options: { [mediapipe.InferenceCalculatorOptions.ext] { diff --git a/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt b/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt index 591824851..b52e9a34a 100644 --- a/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt +++ b/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt @@ -77,7 +77,7 @@ node { # supports custom ops needed by the model used in this graph. node { calculator: "TfLiteCustomOpResolverCalculator" - output_side_packet: "op_resolver" + output_side_packet: "OP_RESOLVER:op_resolver" } # Loads the selfie segmentation TF Lite model. @@ -93,7 +93,7 @@ node { input_stream: "TENSORS:input_tensors" output_stream: "TENSORS:output_tensors" input_side_packet: "MODEL:model" - input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + input_side_packet: "OP_RESOLVER:op_resolver" options: { [mediapipe.InferenceCalculatorOptions.ext] { delegate { diff --git a/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu.pbtxt b/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu.pbtxt index 5f9e55eb7..740962bc7 100644 --- a/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu.pbtxt +++ b/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu.pbtxt @@ -79,7 +79,7 @@ node { # supports custom ops needed by the model used in this graph. node { calculator: "TfLiteCustomOpResolverCalculator" - output_side_packet: "op_resolver" + output_side_packet: "OP_RESOLVER:op_resolver" options: { [mediapipe.TfLiteCustomOpResolverCalculatorOptions.ext] { use_gpu: true @@ -100,7 +100,7 @@ node { input_stream: "TENSORS:input_tensors" output_stream: "TENSORS:output_tensors" input_side_packet: "MODEL:model" - input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" + input_side_packet: "OP_RESOLVER:op_resolver" } # Retrieves the size of the input image. diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 75d74f06f..4f90c6712 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -67,7 +67,6 @@ objc_library( hdrs = MEDIAPIPE_IOS_HDRS, copts = [ "-Wno-shorten-64-to-32", - "-std=c++17", ], sdk_frameworks = [ # Needed for OpenCV. 
@@ -146,7 +145,6 @@ objc_library( ], copts = [ "-Wno-shorten-64-to-32", - "-std=c++17", ], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ @@ -186,7 +184,6 @@ objc_library( ], copts = [ "-Wno-shorten-64-to-32", - "-std=c++17", ], sdk_frameworks = [ "CoreVideo", @@ -209,7 +206,6 @@ objc_library( ], copts = [ "-Wno-shorten-64-to-32", - "-std=c++17", ], sdk_frameworks = [ "Accelerate", @@ -241,7 +237,6 @@ objc_library( ], copts = [ "-Wno-shorten-64-to-32", - "-std=c++17", ], data = [ "testdata/googlelogo_color_272x92dp.png", diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index d0c77b7df..651eb2ca6 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -14,6 +14,8 @@ #include "mediapipe/framework/formats/image.h" +#include <memory> + #include "mediapipe/python/pybind/image_frame_util.h" #include "mediapipe/python/pybind/util.h" #include "pybind11/stl.h" @@ -84,8 +86,8 @@ void ImageSubmodule(pybind11::module* module) { "uint8 image data should be one of the GRAY8, " "SRGB, and SRGBA MediaPipe image formats."); } - return Image(std::make_shared<ImageFrame>( - std::move(*CreateImageFrame(format, data).release()))); + return Image(std::shared_ptr<ImageFrame>( + CreateImageFrame(format, data))); }), R"doc(For uint8 data type, valid ImageFormat are GRAY8, SGRB, and SRGBA.)doc", py::arg("image_format"), py::arg("data").noconvert()) @@ -100,8 +102,8 @@ void ImageSubmodule(pybind11::module* module) { "uint16 image data should be one of the GRAY16, " "SRGB48, and SRGBA64 MediaPipe image formats."); } - return Image(std::make_shared<ImageFrame>( - std::move(*CreateImageFrame(format, data).release()))); + return Image(std::shared_ptr<ImageFrame>( + CreateImageFrame(format, data))); }), R"doc(For uint16 data type, valid ImageFormat are GRAY16, SRGB48, and SRGBA64.)doc", py::arg("image_format"), py::arg("data").noconvert()) @@ -115,8 +117,8 @@ void ImageSubmodule(pybind11::module* module) { "float image data should be either VEC32F1 or VEC32F2 " "MediaPipe image formats."); } - return Image(std::make_shared<ImageFrame>( - std::move(*CreateImageFrame(format, data).release()))); + return Image(std::shared_ptr<ImageFrame>( + CreateImageFrame(format, data))); }), R"doc(For float data type, valid ImageFormat are VEC32F1 and VEC32F2.)doc", py::arg("image_format"), py::arg("data").noconvert()); diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index ef7b70194..bc2767f8f 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -55,17 +55,17 @@ Packet CreateImagePacket(mediapipe::ImageFormat::Format format, if (format == mediapipe::ImageFormat::SRGB || format == mediapipe::ImageFormat::SRGBA || format == mediapipe::ImageFormat::GRAY8) { - return MakePacket<Image>(std::make_shared<ImageFrame>( - std::move(*CreateImageFrame(format, data, copy).release()))); + return MakePacket<Image>(std::shared_ptr<ImageFrame>( + CreateImageFrame(format, data, copy))); } else if (format == mediapipe::ImageFormat::GRAY16 || format == mediapipe::ImageFormat::SRGB48 || format == mediapipe::ImageFormat::SRGBA64) { - return MakePacket<Image>(std::make_shared<ImageFrame>( - std::move(*CreateImageFrame(format, data, copy).release()))); + return MakePacket<Image>(std::shared_ptr<ImageFrame>( + CreateImageFrame(format, data, copy))); } else if (format == mediapipe::ImageFormat::VEC32F1 || format == mediapipe::ImageFormat::VEC32F2) { - return MakePacket<Image>(std::make_shared<ImageFrame>( - std::move(*CreateImageFrame(format, data, copy).release()))); + return MakePacket<Image>(std::shared_ptr<ImageFrame>( + CreateImageFrame(format, data, copy))); }
throw RaisePyError(PyExc_RuntimeError, absl::StrCat("Unsupported ImageFormat: ", format).c_str()); @@ -633,8 +633,9 @@ void InternalPacketCreators(pybind11::module* m) { // both GPU and CPU can process it. image_frame_copy->CopyFrom(*image.GetImageFrameSharedPtr(), ImageFrame::kGlDefaultAlignmentBoundary); - return MakePacket<Image>(std::make_shared<ImageFrame>( - std::move(*image_frame_copy.release()))); + std::shared_ptr<ImageFrame> shared_image_frame = + std::move(image_frame_copy); + return MakePacket<Image>(shared_image_frame); }, py::arg("image").noconvert(), py::return_value_policy::move); diff --git a/mediapipe/python/solution_base.py b/mediapipe/python/solution_base.py index d4b9a943a..b33d116ac 100644 --- a/mediapipe/python/solution_base.py +++ b/mediapipe/python/solution_base.py @@ -94,6 +94,7 @@ class PacketDataType(enum.Enum): BOOL = 'bool' BOOL_LIST = 'bool_list' INT = 'int' + INT_LIST = 'int_list' FLOAT = 'float' FLOAT_LIST = 'float_list' AUDIO = 'matrix' @@ -123,6 +124,12 @@ NAME_TO_TYPE: Mapping[str, 'PacketDataType'] = { PacketDataType.BOOL_LIST, 'int': PacketDataType.INT, + '::std::vector<int>': + PacketDataType.INT_LIST, + 'int64': + PacketDataType.INT, + '::std::vector<int64>': + PacketDataType.INT_LIST, 'float': PacketDataType.FLOAT, '::std::vector<float>': diff --git a/mediapipe/python/solutions/holistic.py b/mediapipe/python/solutions/holistic.py index b58e3eaa7..c58901fcd 100644 --- a/mediapipe/python/solutions/holistic.py +++ b/mediapipe/python/solutions/holistic.py @@ -158,10 +158,10 @@ class Holistic(SolutionBase): """ results = super().process(input_data={'image': image}) - if results.pose_landmarks: - for landmark in results.pose_landmarks.landmark: + if results.pose_landmarks: # pytype: disable=attribute-error + for landmark in results.pose_landmarks.landmark: # pytype: disable=attribute-error landmark.ClearField('presence') - if results.pose_world_landmarks: - for landmark in results.pose_world_landmarks.landmark: + if results.pose_world_landmarks: # pytype: disable=attribute-error + for landmark in results.pose_world_landmarks.landmark: # pytype: disable=attribute-error landmark.ClearField('presence') return results diff --git a/mediapipe/python/solutions/objectron.py b/mediapipe/python/solutions/objectron.py index 28cc026aa..ea7981f1a 100644 --- a/mediapipe/python/solutions/objectron.py +++ b/mediapipe/python/solutions/objectron.py @@ -258,10 +258,10 @@ class Objectron(SolutionBase): """ results = super().process(input_data={'image': image}) - if results.detected_objects: - results.detected_objects = self._convert_format(results.detected_objects) + if results.detected_objects: # pytype: disable=attribute-error + results.detected_objects = self._convert_format(results.detected_objects) # type: ignore else: - results.detected_objects = None + results.detected_objects = None # pytype: disable=not-writable return results def _convert_format( diff --git a/mediapipe/python/solutions/pose.py b/mediapipe/python/solutions/pose.py index d4b499faa..74a52d611 100644 --- a/mediapipe/python/solutions/pose.py +++ b/mediapipe/python/solutions/pose.py @@ -183,10 +183,10 @@ class Pose(SolutionBase): """ results = super().process(input_data={'image': image}) - if results.pose_landmarks: - for landmark in results.pose_landmarks.landmark: + if results.pose_landmarks: # pytype: disable=attribute-error + for landmark in results.pose_landmarks.landmark: # pytype: disable=attribute-error landmark.ClearField('presence') - if results.pose_world_landmarks: - for landmark in results.pose_world_landmarks.landmark: + if
results.pose_world_landmarks: # pytype: disable=attribute-error + for landmark in results.pose_world_landmarks.landmark: # pytype: disable=attribute-error landmark.ClearField('presence') return results diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index b5f52450f..12a34a4f5 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -35,6 +35,12 @@ mediapipe_proto_library( visibility = ["//visibility:public"], ) +mediapipe_proto_library( + name = "label_map_proto", + srcs = ["label_map.proto"], + visibility = ["//visibility:public"], +) + mediapipe_proto_library( name = "render_data_proto", srcs = ["render_data.proto"], @@ -124,6 +130,19 @@ cc_library( ], ) +cc_library( + name = "label_map_util", + srcs = ["label_map_util.cc"], + hdrs = ["label_map_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":label_map_cc_proto", + "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "annotation_renderer", srcs = ["annotation_renderer.cc"], diff --git a/mediapipe/util/label_map.proto b/mediapipe/util/label_map.proto new file mode 100644 index 000000000..5d1123fb2 --- /dev/null +++ b/mediapipe/util/label_map.proto @@ -0,0 +1,40 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +// Mapping a numerical class index output to a Knowledge Graph entity +// ID or any other string label representing this class. Optionally it is +// possible to specify an additional display name (in a given language) which is +// typically used for display purposes. +message LabelMapItem { + // Label name. + // E.g. name = "/m/02xwb" + optional string name = 1; + + // Display name. + // E.g. display_name = "Fruit" + optional string display_name = 2; + + // Optional list of children (e.g. subcategories) used to represent a + // hierarchy. + repeated string child_name = 3; +} + +// Mapping from index to a label map item. +message LabelMap { + map<int64, LabelMapItem> index_to_item = 1; +} diff --git a/mediapipe/util/label_map_util.cc b/mediapipe/util/label_map_util.cc new file mode 100644 index 000000000..849cf4299 --- /dev/null +++ b/mediapipe/util/label_map_util.cc @@ -0,0 +1,78 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "mediapipe/util/label_map_util.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/util/label_map.pb.h" + +namespace mediapipe { + +absl::StatusOr BuildLabelMapFromFiles( + absl::string_view labels_file_contents, + absl::string_view display_names_file) { + if (labels_file_contents.empty()) { + return absl::InvalidArgumentError("Expected non-empty labels file."); + } + std::vector labels = + absl::StrSplit(labels_file_contents, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. In such a situation, StrSplit() will + // produce a vector with an empty string as final element. Also note that in + // case `labels_file_contents` is entirely empty, StrSplit() will produce a + // vector with one single empty substring, so there's no out-of-range risk + // here. + if (labels[labels.size() - 1].empty()) { + labels.pop_back(); + } + + std::vector label_map_items; + label_map_items.reserve(labels.size()); + for (int i = 0; i < labels.size(); ++i) { + LabelMapItem item; + item.set_name(std::string(labels[i])); + label_map_items.emplace_back(item); + } + + if (!display_names_file.empty()) { + std::vector display_names = + absl::StrSplit(display_names_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. See above. + if (display_names[display_names.size() - 1].empty()) { + display_names.pop_back(); + } + if (display_names.size() != labels.size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Mismatch between number of labels (%d) and display names (%d).", + labels.size(), display_names.size())); + } + for (int i = 0; i < display_names.size(); ++i) { + label_map_items[i].set_display_name(display_names[i]); + } + } + LabelMap label_map; + for (int i = 0; i < label_map_items.size(); ++i) { + (*label_map.mutable_index_to_item())[i] = label_map_items[i]; + } + return label_map; +} + +} // namespace mediapipe diff --git a/mediapipe/util/label_map_util.h b/mediapipe/util/label_map_util.h new file mode 100644 index 000000000..75a5f7e75 --- /dev/null +++ b/mediapipe/util/label_map_util.h @@ -0,0 +1,34 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_LABEL_MAP_UTIL_H_ +#define MEDIAPIPE_UTIL_LABEL_MAP_UTIL_H_ + +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/util/label_map.pb.h" + +namespace mediapipe { + +// Builds a label map from labels and (optional) display names file contents, +// both expected to contain one label per line. +// Returns an error e.g. if there's a mismatch between the number of labels and +// display names. 
+absl::StatusOr<LabelMap> BuildLabelMapFromFiles( + absl::string_view labels_file_contents, + absl::string_view display_names_file); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_LABEL_MAP_UTIL_H_ diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index a4852b804..9d37b60a0 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -50,6 +50,17 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], + deps = [ + "@org_tensorflow//tensorflow/lite:minimal_logging", + "@org_tensorflow//tensorflow/lite:stateful_error_reporter", + "@org_tensorflow//tensorflow/lite/core/api:error_reporter", + ], +) + cc_library( name = "op_resolver", srcs = ["op_resolver.cc"], @@ -108,6 +119,7 @@ cc_library( name = "tflite_model_loader", srcs = ["tflite_model_loader.cc"], hdrs = ["tflite_model_loader.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:packet", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/util/tflite/error_reporter.cc b/mediapipe/util/tflite/error_reporter.cc new file mode 100644 index 000000000..ce92ae180 --- /dev/null +++ b/mediapipe/util/tflite/error_reporter.cc @@ -0,0 +1,49 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/tflite/error_reporter.h" + +#include <cstdarg> +#include <cstdio> +#include <cstring> +#include <string> + +#include "tensorflow/lite/minimal_logging.h" + +namespace mediapipe { +namespace util { +namespace tflite { + +ErrorReporter::ErrorReporter() { + message_[0] = '\0'; + previous_message_[0] = '\0'; +} + +int ErrorReporter::Report(const char* format, va_list args) { + std::strcpy(previous_message_, message_); // NOLINT + message_[0] = '\0'; + int num_characters = vsnprintf(message_, kBufferSize, format, args); + // To mimic tflite::StderrReporter. + ::tflite::logging_internal::MinimalLogger::Log(::tflite::TFLITE_LOG_ERROR, + "%s", message_); + return num_characters; +} + +std::string ErrorReporter::message() { return message_; } + +std::string ErrorReporter::previous_message() { return previous_message_; } + +} // namespace tflite +} // namespace util +} // namespace mediapipe diff --git a/mediapipe/util/tflite/error_reporter.h b/mediapipe/util/tflite/error_reporter.h new file mode 100644 index 000000000..245e5dc2e --- /dev/null +++ b/mediapipe/util/tflite/error_reporter.h @@ -0,0 +1,52 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/mediapipe/util/tflite/error_reporter.h b/mediapipe/util/tflite/error_reporter.h
new file mode 100644
index 000000000..245e5dc2e
--- /dev/null
+++ b/mediapipe/util/tflite/error_reporter.h
@@ -0,0 +1,52 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_UTIL_TFLITE_ERROR_REPORTER_H_
+#define MEDIAPIPE_UTIL_TFLITE_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/stateful_error_reporter.h"
+
+namespace mediapipe {
+namespace util {
+namespace tflite {
+
+// An ErrorReporter that logs to stderr and captures the last two messages.
+class ErrorReporter : public ::tflite::StatefulErrorReporter {
+ public:
+  ErrorReporter();
+
+  // Declaring the va_list overload of Report() below hides the variadic
+  // Report() inherited from tflite::ErrorReporter; this using-declaration
+  // makes it visible again.
+  // See https://isocpp.org/wiki/faq/strange-inheritance#hiding-rule.
+  using ::tflite::ErrorReporter::Report;
+
+  int Report(const char* format, std::va_list args) override;
+
+  std::string message() override;
+  std::string previous_message();
+
+ private:
+  static constexpr int kBufferSize = 1024;
+  char message_[kBufferSize];
+  char previous_message_[kBufferSize];
+};
+
+}  // namespace tflite
+}  // namespace util
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_UTIL_TFLITE_ERROR_REPORTER_H_
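One plausible way to wire the reporter into model loading, sketched here for illustration; the model path is a placeholder, and the propagation into a status object is left to the caller:

#include <iostream>
#include <memory>

#include "mediapipe/util/tflite/error_reporter.h"
#include "tensorflow/lite/model.h"

int main() {
  mediapipe::util::tflite::ErrorReporter error_reporter;
  // FlatBufferModel reports load/parse failures through the supplied
  // reporter instead of TFLite's default stderr-only reporter.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile("/path/to/model.tflite",
                                             &error_reporter);
  if (model == nullptr) {
    // The captured text can now be surfaced, e.g. in an absl::Status.
    std::cerr << error_reporter.message() << std::endl;
    return 1;
  }
  return 0;
}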
diff --git a/setup.py b/setup.py
index a0368d59c..ef7794e92 100644
--- a/setup.py
+++ b/setup.py
@@ -412,7 +412,8 @@ class RemoveGenerated(setuptools.Command):
       os.remove(MP_THIRD_PARTY_BUILD)
       shutil.move(_get_backup_file(MP_THIRD_PARTY_BUILD), MP_THIRD_PARTY_BUILD)
     for init_py in DIR_INIT_PY_FILES:
-      os.remove(init_py)
+      if os.path.exists(init_py):
+        os.remove(init_py)
 
 setuptools.setup(
diff --git a/setup_android_sdk_and_ndk.sh b/setup_android_sdk_and_ndk.sh
index d09960684..c16021eda 100644
--- a/setup_android_sdk_and_ndk.sh
+++ b/setup_android_sdk_and_ndk.sh
@@ -89,5 +89,5 @@ fi
 echo "Set android_ndk_repository and android_sdk_repository in WORKSPACE"
 workspace_file="$( cd "$(dirname "$0")" ; pwd -P )"/WORKSPACE
 echo "android_sdk_repository(name = \"androidsdk\", path = \"${android_sdk_path}\")" >> $workspace_file
-echo "android_ndk_repository(name = \"androidndk\", path = \"${android_ndk_path}/android-ndk-${ndk_version}\")" >> $workspace_file
+echo "android_ndk_repository(name = \"androidndk\", api_level=21, path = \"${android_ndk_path}/android-ndk-${ndk_version}\")" >> $workspace_file
 echo "Done"
diff --git a/third_party/opencv_macos.BUILD b/third_party/opencv_macos.BUILD
index 546249754..3c17155d2 100644
--- a/third_party/opencv_macos.BUILD
+++ b/third_party/opencv_macos.BUILD
@@ -5,22 +5,28 @@ licenses(["notice"])  # BSD license
 
 exports_files(["LICENSE"])
 
+load("@bazel_skylib//lib:paths.bzl", "paths")
+
+# The path to OpenCV is a combination of the path set for "macos_opencv"
+# in the WORKSPACE file and the prefix here.
+PREFIX = "opt/opencv@3"
+
 cc_library(
     name = "opencv",
     srcs = glob(
         [
-            "lib/libopencv_core.dylib",
-            "lib/libopencv_calib3d.dylib",
-            "lib/libopencv_features2d.dylib",
-            "lib/libopencv_highgui.dylib",
-            "lib/libopencv_imgcodecs.dylib",
-            "lib/libopencv_imgproc.dylib",
-            "lib/libopencv_video.dylib",
-            "lib/libopencv_videoio.dylib",
+            paths.join(PREFIX, "lib/libopencv_core.dylib"),
+            paths.join(PREFIX, "lib/libopencv_calib3d.dylib"),
+            paths.join(PREFIX, "lib/libopencv_features2d.dylib"),
+            paths.join(PREFIX, "lib/libopencv_highgui.dylib"),
+            paths.join(PREFIX, "lib/libopencv_imgcodecs.dylib"),
+            paths.join(PREFIX, "lib/libopencv_imgproc.dylib"),
+            paths.join(PREFIX, "lib/libopencv_video.dylib"),
+            paths.join(PREFIX, "lib/libopencv_videoio.dylib"),
         ],
     ),
-    hdrs = glob(["include/opencv2/**/*.h*"]),
-    includes = ["include/"],
+    hdrs = glob([paths.join(PREFIX, "include/opencv2/**/*.h*")]),
+    includes = [paths.join(PREFIX, "include/")],
     linkstatic = 1,
     visibility = ["//visibility:public"],
 )