diff --git a/BUILD b/BUILD
index da185b6e0..f225f24e3 100644
--- a/BUILD
+++ b/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2019-2020 The MediaPipe Authors.
+# Copyright 2019 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/README.md b/README.md
index 6e4aa7a2b..66323f988 100644
--- a/README.md
+++ b/README.md
@@ -101,7 +101,7 @@ run code search using
 
 ## Videos
 
-*   [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
+*   [YouTube Channel](https://www.youtube.com/c/MediaPipe)
 
 ## Events
 
@@ -123,7 +123,7 @@ run code search using
 *   [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
     MediaPipe related frameworks, libraries and software
-*   [Slack community](https://mediapipe.slack.com) for MediaPipe users
+*   [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
 *   [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
     community discussion around MediaPipe
diff --git a/WORKSPACE b/WORKSPACE
index 8b148fd4a..8210d786e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -37,10 +37,19 @@ http_archive(
 )
 
 # GoogleTest/GoogleMock framework. Used by most unit-tests.
+# Last updated 2020-06-30.
 http_archive(
-    name = "com_google_googletest",
-    urls = ["https://github.com/google/googletest/archive/master.zip"],
-    strip_prefix = "googletest-master",
+    name = "com_google_googletest",
+    urls = ["https://github.com/google/googletest/archive/aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e.zip"],
+    patches = [
+        # fix for https://github.com/google/googletest/issues/2817
+        "@//third_party:com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff"
+    ],
+    patch_args = [
+        "-p1",
+    ],
+    strip_prefix = "googletest-aee0f9d9b5b87796ee8a0ab26b7587ec30e8858e",
+    sha256 = "04a1751f94244307cebe695a69cc945f9387a80b0ef1af21394a490697c5c895",
 )
 
 # Google Benchmark library.
diff --git a/build_ios_examples.sh b/build_ios_examples.sh
new file mode 100644
index 000000000..15725acc9
--- /dev/null
+++ b/build_ios_examples.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+# Copyright 2020 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+#
+# Script to build all MediaPipe iOS example apps.
+#
+# To build all apps and store them in out_dir:
+#   $ ./build_ios_examples.sh -d out_dir
+#   Omitting -d and the associated directory saves all generated IPAs in the
+#   current directory.
+#   $ ./build_ios_examples.sh -d out_dir --nostrip
+#   Same as above except that the symbols are not stripped.
+
+set -e
+
+out_dir="."
+strip=true
+app_dir="mediapipe/examples/ios"
+bin_dir="bazel-bin"
+declare -a default_bazel_flags=(build -c opt --config=ios_arm64)
+
+while [[ -n $1 ]]; do
+  case $1 in
+    -d)
+      shift
+      out_dir=$1
+      ;;
+    --nostrip)
+      strip=false
+      ;;
+    *)
+      echo "Unsupported input argument $1."
+      exit 1
+      ;;
+  esac
+  shift
+done
+
+echo "app_dir: $app_dir"
+echo "out_dir: $out_dir"
+echo "strip: $strip"
+
+declare -a bazel_flags
+
+apps="${app_dir}/*"
+for app in ${apps}; do
+  if [[ -d "${app}" ]]; then
+    target_name=${app##*/}
+    target="${app}:${target_name}"
+
+    echo "=== Target: ${target}"
+
+    bazel_flags=("${default_bazel_flags[@]}")
+    bazel_flags+=(${target})
+    if [[ $strip == true ]]; then
+      bazel_flags+=(--linkopt=-s)
+    fi
+
+    bazel "${bazel_flags[@]}"
+    cp -f "${bin_dir}/${app}/"*".ipa" "${out_dir}"
+  fi
+done
diff --git a/docs/framework_concepts/gpu.md b/docs/framework_concepts/gpu.md
index 06355ac44..77d566e8d 100644
--- a/docs/framework_concepts/gpu.md
+++ b/docs/framework_concepts/gpu.md
@@ -149,15 +149,15 @@ When possible, these calculators use platform-specific functionality to share da
 The below diagram shows the data flow in a mobile application that captures video from the camera, runs it through a MediaPipe graph, and renders the output on the screen in real time. The dashed line indicates which parts are inside the MediaPipe graph proper. This application runs a Canny edge-detection filter on the CPU using OpenCV, and overlays it on top of the original video using the GPU.
 
-| ![How GPU calculators interact](../images/gpu_example_graph.png)              |
-| :--------------------------------------------------------------------------: |
-| *Video frames from the camera are fed into the graph as `GpuBuffer` packets. |
-: The input stream is accessed by two calculators in parallel.                 :
-: `GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`,  :
-: which is then sent through a grayscale converter and a canny filter (both    :
-: based on OpenCV and running on the CPU), whose output is then converted into :
-: a `GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator,      :
-: takes as input both the original `GpuBuffer` and the one coming out of the   :
-: edge detector, and overlays them using a shader. The output is then sent     :
-: back to the application using a callback calculator, and the application     :
-: renders the image to the screen using OpenGL.*                               :
+![How GPU calculators interact](../images/gpu_example_graph.png)
+
+Video frames from the camera are fed into the graph as `GpuBuffer` packets. The
+input stream is accessed by two calculators in parallel.
+`GpuBufferToImageFrameCalculator` converts the buffer into an `ImageFrame`,
+which is then sent through a grayscale converter and a canny filter (both based
+on OpenCV and running on the CPU), whose output is then converted into a
+`GpuBuffer` again. A multi-input GPU calculator, GlOverlayCalculator, takes as
+input both the original `GpuBuffer` and the one coming out of the edge detector,
+and overlays them using a shader. The output is then sent back to the
+application using a callback calculator, and the application renders the image
+to the screen using OpenGL.
diff --git a/docs/getting_started/building_examples.md b/docs/getting_started/building_examples.md
index 3f818f7b6..2c3b6e77c 100644
--- a/docs/getting_started/building_examples.md
+++ b/docs/getting_started/building_examples.md
@@ -184,12 +184,8 @@ app:
 
 ### Prerequisite
 
-1.  Install [Xcode](https://developer.apple.com/xcode/) and the Command Line
-    Tools.
-
-    Follow Apple's instructions to obtain the required development certificates
-    and provisioning profiles for your iOS device. Install the Command Line
-    Tools by
+1.  Install [Xcode](https://developer.apple.com/xcode/), and additionally
+    install the Command Line Tools by:
 
     ```bash
     xcode-select --install
    ```
@@ -209,26 +205,31 @@ app:
     pip3 install --user six
     ```
 
-4.  Clone the MediaPipe repository.
+4.  Follow
+    [Apple's instructions](https://developer.apple.com/support/certificates/) to
+    obtain the required development certificates and provisioning profiles for
+    your iOS device.
+
+    Tip: You can use the following command to see the provisioning profiles you
+    have previously downloaded using Xcode: `open
+    ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
+    and download a profile on
+    [Apple's developer site](https://developer.apple.com/account/resources/).
+
+5.  Clone the MediaPipe repository.
 
     ```bash
     git clone https://github.com/google/mediapipe.git
     ```
 
-5.  Symlink or copy your provisioning profile to
-    `mediapipe/mediapipe/provisioning_profile.mobileprovision`.
+6.  In the cloned MediaPipe repository, symlink or copy your provisioning
+    profile to `mediapipe/provisioning_profile.mobileprovision`, e.g.,
 
     ```bash
     cd mediapipe
     ln -s ~/Downloads/MyProvisioningProfile.mobileprovision mediapipe/provisioning_profile.mobileprovision
     ```
 
-    Tip: You can use this command to see the provisioning profiles you have
-    previously downloaded using Xcode: `open
-    ~/Library/MobileDevice/"Provisioning Profiles"`. If there are none, generate
-    and download a profile on
-    [Apple's developer site](https://developer.apple.com/account/resources/).
-
 ### Option 1: Build with Bazel in Command Line
 
 1.  Modify the `bundle_id` field of the app's `ios_application` build target to
@@ -246,6 +247,10 @@ app:
 
     You may see a permission request from `codesign` in order to sign the app.
 
+    Tip: You can run this
+    [script](https://github.com/google/mediapipe/blob/master/build_ios_examples.sh)
+    to build all MediaPipe iOS example apps.
+
 3.  In Xcode, open the `Devices and Simulators` window (command-shift-2).
 
 4.  Make sure your device is connected. You will see a list of installed apps.
diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md
index 0457b7e85..7374e244b 100644
--- a/docs/getting_started/install.md
+++ b/docs/getting_started/install.md
@@ -44,6 +44,18 @@ apps, see these [instructions](./building_examples.md#ios).
     [Bazel documentation](https://docs.bazel.build/versions/master/install-ubuntu.html)
     to install Bazel 2.0 or higher.
 
+    For Nvidia Jetson and Raspberry Pi devices with ARM Ubuntu, Bazel needs to
+    be built from source.
+
+    ```bash
+    # For Bazel 3.0.0
+    wget https://github.com/bazelbuild/bazel/releases/download/3.0.0/bazel-3.0.0-dist.zip
+    sudo apt-get install build-essential openjdk-8-jdk python zip unzip
+    unzip bazel-3.0.0-dist.zip
+    env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk" bash ./compile.sh
+    sudo cp output/bazel /usr/local/bin/
+    ```
+
 3.  Install OpenCV and FFmpeg.
 
     Option 1. Use package manager tool to install the pre-compiled OpenCV
@@ -58,6 +70,14 @@ apps, see these [instructions](./building_examples.md#ios).
                             libopencv-imgproc-dev libopencv-video-dev
     ```
 
+    [`opencv_linux.BUILD`] is configured for x86_64 by default. For Nvidia
+    Jetson and Raspberry Pi devices with ARM Ubuntu, the lib paths need to be
+    modified.
+
+    ```bash
+    sed -i "s/x86_64-linux-gnu/aarch64-linux-gnu/g" third_party/opencv_linux.BUILD
+    ```
+
     Option 2. Run [`setup_opencv.sh`] to automatically build OpenCV from source
     and modify MediaPipe's OpenCV config.
 
@@ -493,14 +513,14 @@ cameras. Alternatively, you use a video file as input.
    ```bash
    username@DESKTOP-TMVLBJ1:~$ curl -sLO --retry 5 --retry-max-time 10 \
-   https://storage.googleapis.com/bazel/2.0.0/release/bazel-2.0.0-installer-linux-x86_64.sh && \
-   sudo mkdir -p /usr/local/bazel/2.0.0 && \
-   chmod 755 bazel-2.0.0-installer-linux-x86_64.sh && \
-   sudo ./bazel-2.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/2.0.0 && \
-   source /usr/local/bazel/2.0.0/lib/bazel/bin/bazel-complete.bash
+   https://storage.googleapis.com/bazel/3.0.0/release/bazel-3.0.0-installer-linux-x86_64.sh && \
+   sudo mkdir -p /usr/local/bazel/3.0.0 && \
+   chmod 755 bazel-3.0.0-installer-linux-x86_64.sh && \
+   sudo ./bazel-3.0.0-installer-linux-x86_64.sh --prefix=/usr/local/bazel/3.0.0 && \
+   source /usr/local/bazel/3.0.0/lib/bazel/bin/bazel-complete.bash
 
-   username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/2.0.0/lib/bazel/bin/bazel version && \
-   alias bazel='/usr/local/bazel/2.0.0/lib/bazel/bin/bazel'
+   username@DESKTOP-TMVLBJ1:~$ /usr/local/bazel/3.0.0/lib/bazel/bin/bazel version && \
+   alias bazel='/usr/local/bazel/3.0.0/lib/bazel/bin/bazel'
    ```
 
 6.  Checkout MediaPipe repository.
diff --git a/docs/index.md b/docs/index.md
index 6d5777e9c..39ea05b42 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -101,7 +101,7 @@ run code search using
 
 ## Videos
 
-*   [YouTube Channel](https://www.youtube.com/channel/UCObqmpuSMx-usADtL_qdMAw)
+*   [YouTube Channel](https://www.youtube.com/c/MediaPipe)
 
 ## Events
 
@@ -123,7 +123,7 @@ run code search using
 *   [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome
     MediaPipe related frameworks, libraries and software
-*   [Slack community](https://mediapipe.slack.com) for MediaPipe users
+*   [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users
 *   [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General
     community discussion around MediaPipe
diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md
index 04b6bc695..4ba33f861 100644
--- a/docs/solutions/hands.md
+++ b/docs/solutions/hands.md
@@ -1,6 +1,6 @@
 ---
 layout: default
-title: Hand
+title: Hands
 parent: Solutions
 nav_order: 3
 ---
@@ -219,9 +219,13 @@ Please refer to [these instructions](../index.md#mediapipe-on-the-web).
 ## Resources
 
-*   Google AI Blog: [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
-*   TensorFlow Blog: [Face and hand tracking in the browser with MediaPipe and
-    TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
+*   Google AI Blog:
+    [On-Device, Real-Time Hand Tracking with MediaPipe](https://ai.googleblog.com/2019/08/on-device-real-time-hand-tracking-with.html)
+*   TensorFlow Blog:
+    [Face and hand tracking in the browser with MediaPipe and TensorFlow.js](https://blog.tensorflow.org/2020/03/face-and-hand-tracking-in-browser-with-mediapipe-and-tensorflowjs.html)
+*   Paper:
+    [MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214)
+    ([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk))
 *   Palm detection model:
     [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite),
     [TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md
index 083c443c3..c142bfdf9 100644
--- a/docs/solutions/objectron.md
+++ b/docs/solutions/objectron.md
@@ -188,5 +188,8 @@ to visualize its associated subgraphs, please see
     [Real-Time 3D Object Detection on Mobile Devices with MediaPipe](https://ai.googleblog.com/2020/03/real-time-3d-object-detection-on-mobile.html)
 *   Paper: [MobilePose: Real-Time Pose Estimation for Unseen Objects with Weak
     Shape Supervision](https://arxiv.org/abs/2003.03522)
+*   Paper:
+    [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8)
+    ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0))
 *   [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite)
 *   [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite)
diff --git a/docs/tools/tracing_and_profiling.md b/docs/tools/tracing_and_profiling.md
index 5018dfa5c..472e52a7d 100644
--- a/docs/tools/tracing_and_profiling.md
+++ b/docs/tools/tracing_and_profiling.md
@@ -21,7 +21,16 @@ available on Linux, Android, or iOS.
 
 ## Enabling tracing and profiling
 
-To enable tracing/profiling of a mediapipe graph, the `CalculatorGraphConfig` (in
+To enable tracing and profiling of a mediapipe graph:
+
+  1. The profiling library must be linked to the framework.
+  2. Tracing and profiling must be enabled in the graph configuration.
+
+The profiling library is linked to the framework by default. If needed,
+the profiling library can be omitted from the framework using the bazel
+command line option: `--define MEDIAPIPE_PROFILING=0`.
+
+To enable tracing and profiling, the `CalculatorGraphConfig` (in
 [calculator.proto](https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto))
 representing the graph must have a `profiler_config` message at its root.
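As an illustrative sketch of such a root-level `profiler_config` (the trivial graph, stream names, and `PassThroughCalculator` node here are placeholder assumptions, not part of this change; `ParseTextProtoOrDie` comes from `mediapipe/framework/port/parse_text_proto.h`):

```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: a minimal graph whose root config carries a profiler_config
// message. The profiler fields are the ones this page discusses; the node
// and stream names are placeholders for illustration.
mediapipe::CalculatorGraphConfig MakeProfiledGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "in"
    output_stream: "out"
    node {
      calculator: "PassThroughCalculator"
      input_stream: "in"
      output_stream: "out"
    }
    profiler_config {
      trace_enabled: true            # emit timeline trace events
      enable_profiler: true          # collect per-calculator timing
      trace_log_interval_count: 200  # flush the trace log periodically
    }
  )pb");
}
```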
Here is a simple setup that turns on a few extra options:
diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc
index 8d3d26f2d..b366caf7a 100644
--- a/mediapipe/calculators/image/bilateral_filter_calculator.cc
+++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc
@@ -107,7 +107,7 @@ class BilateralFilterCalculator : public CalculatorBase {
   GLuint program_ = 0;
   GLuint vao_;
   GLuint vbo_[2];  // vertex storage
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // !MEDIAPIPE_DISABLE_GPU
 };
 REGISTER_CALCULATOR(BilateralFilterCalculator);
diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc
index 75ed96e15..cb5f6419e 100644
--- a/mediapipe/calculators/image/image_transformation_calculator.cc
+++ b/mediapipe/calculators/image/image_transformation_calculator.cc
@@ -386,45 +386,47 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
   const int input_width = input_mat.cols;
   const int input_height = input_mat.rows;
-  if (!output_height_ || !output_width_) {
-    output_height_ = input_height;
-    output_width_ = input_width;
-  }
+  int output_width;
+  int output_height;
+  ComputeOutputDimensions(input_width, input_height, &output_width,
+                          &output_height);
 
-  cv::Mat scaled_mat;
-  int output_width = output_width_;
-  int output_height = output_height_;
-  if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
-    int scale_flag =
-        input_mat.cols > output_width_ && input_mat.rows > output_height_
-            ? cv::INTER_AREA
-            : cv::INTER_LINEAR;
-    cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_),
-               0, 0, scale_flag);
-  } else {
-    const float scale =
-        std::min(static_cast<float>(output_width_) / input_width,
-                 static_cast<float>(output_height_) / input_height);
-    const int target_width = std::round(input_width * scale);
-    const int target_height = std::round(input_height * scale);
-    int scale_flag = scale < 1.0f ? cv::INTER_AREA : cv::INTER_LINEAR;
-    if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
-      cv::Mat intermediate_mat;
-      cv::resize(input_mat, intermediate_mat,
-                 cv::Size(target_width, target_height), 0, 0, scale_flag);
-      const int top = (output_height_ - target_height) / 2;
-      const int bottom = output_height_ - target_height - top;
-      const int left = (output_width_ - target_width) / 2;
-      const int right = output_width_ - target_width - left;
-      cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left, right,
-                         options_.constant_padding() ? cv::BORDER_CONSTANT
-                                                     : cv::BORDER_REPLICATE);
-    } else {
-      cv::resize(input_mat, scaled_mat, cv::Size(target_width, target_height),
+  if (output_width_ > 0 && output_height_ > 0) {
+    cv::Mat scaled_mat;
+    if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
+      int scale_flag =
+          input_mat.cols > output_width_ && input_mat.rows > output_height_
+              ? cv::INTER_AREA
+              : cv::INTER_LINEAR;
+      cv::resize(input_mat, scaled_mat,
+                 cv::Size(output_width_, output_height_), 0, 0, scale_flag);
+    } else {
+      const float scale =
+          std::min(static_cast<float>(output_width_) / input_width,
+                   static_cast<float>(output_height_) / input_height);
+      const int target_width = std::round(input_width * scale);
+      const int target_height = std::round(input_height * scale);
+      int scale_flag = scale < 1.0f ? cv::INTER_AREA : cv::INTER_LINEAR;
+      if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
+        cv::Mat intermediate_mat;
+        cv::resize(input_mat, intermediate_mat,
+                   cv::Size(target_width, target_height), 0, 0, scale_flag);
+        const int top = (output_height_ - target_height) / 2;
+        const int bottom = output_height_ - target_height - top;
+        const int left = (output_width_ - target_width) / 2;
+        const int right = output_width_ - target_width - left;
+        cv::copyMakeBorder(intermediate_mat, scaled_mat, top, bottom, left,
+                           right,
+                           options_.constant_padding() ? cv::BORDER_CONSTANT
+                                                       : cv::BORDER_REPLICATE);
+      } else {
+        cv::resize(input_mat, scaled_mat,
+                   cv::Size(target_width, target_height), 0, 0, scale_flag);
+        output_width = target_width;
+        output_height = target_height;
+      }
     }
+    input_mat = scaled_mat;
   }
 
   if (cc->Outputs().HasTag("LETTERBOX_PADDING")) {
@@ -437,10 +439,33 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
   }
 
   cv::Mat rotated_mat;
-  const int angle = RotationModeToDegrees(rotation_);
-  cv::Point2f src_center(scaled_mat.cols / 2.0, scaled_mat.rows / 2.0);
-  cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
-  cv::warpAffine(scaled_mat, rotated_mat, rotation_mat, scaled_mat.size());
+  cv::Size rotated_size(output_width, output_height);
+  if (input_mat.size() == rotated_size) {
+    const int angle = RotationModeToDegrees(rotation_);
+    cv::Point2f src_center(input_mat.cols / 2.0, input_mat.rows / 2.0);
+    cv::Mat rotation_mat = cv::getRotationMatrix2D(src_center, angle, 1.0);
+    cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
+  } else {
+    switch (rotation_) {
+      case mediapipe::RotationMode_Mode_UNKNOWN:
+      case mediapipe::RotationMode_Mode_ROTATION_0:
+        LOG(ERROR) << "Not rotating image.";
+        rotated_mat = input_mat;
+        break;
+      case mediapipe::RotationMode_Mode_ROTATION_90:
+        LOG(ERROR) << "Rotating image by 90 degrees ccw.";
+        cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
+        break;
+      case mediapipe::RotationMode_Mode_ROTATION_180:
+        LOG(ERROR) << "Rotating image by 180 degrees.";
+        cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
+        break;
+      case mediapipe::RotationMode_Mode_ROTATION_270:
+        LOG(ERROR) << "Rotating image by 90 degrees cw.";
+        cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
+        break;
+    }
+  }
 
   cv::Mat flipped_mat;
   if (flip_horizontally_ || flip_vertically_) {
@@ -498,7 +523,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
       renderer = yuv_renderer_.get();
       src1 = gpu_helper_.CreateSourceTexture(input, 0);
     } else  // NOLINT(readability/braces)
-#endif  // iOS
+#endif  // iOS
     {
       src1 = gpu_helper_.CreateSourceTexture(input);
 #if defined(TEXTURE_EXTERNAL_OES)
@@ -510,7 +535,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
       }
       renderer = ext_rgb_renderer_.get();
     } else  // NOLINT(readability/braces)
-#endif  // TEXTURE_EXTERNAL_OES
+#endif  // TEXTURE_EXTERNAL_OES
     {
       if (!rgb_renderer_) {
         rgb_renderer_ = absl::make_unique<QuadRenderer>();
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
index aeb69822b..6e1a29e59 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
@@ -139,7 +139,6 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
         static_cast<::mediapipe::StatusCode>(status.code()),
         status.ToString());
   }
-
   auto session = absl::make_unique<TensorFlowSession>();
   session->session = std::move(saved_model->session);
diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD
index 169ef23f5..f1101a009 100644
--- a/mediapipe/calculators/tflite/BUILD
+++ b/mediapipe/calculators/tflite/BUILD
@@ -14,6 +14,7 @@
 #
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+load("@bazel_skylib//lib:selects.bzl", "selects")
 
 licenses(["notice"])  # Apache 2.0
 
@@ -202,6 +203,13 @@ cc_library(
     alwayslink = 1,
 )
 
+selects.config_setting_group(
+    name = "gpu_inference_disabled",
+    match_any = [
+        "//mediapipe/gpu:disable_gpu",
+    ],
+)
+
 cc_library(
     name = "tflite_inference_calculator",
     srcs = ["tflite_inference_calculator.cc"],
@@ -226,13 +234,14 @@ cc_library(
         "@com_google_absl//absl/memory",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/util:resource_util",
+        "//mediapipe/util/tflite:config",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
        "//mediapipe/framework/port:ret_check",
-    ] + select({
-        "//mediapipe/gpu:disable_gpu": [],
+    ] + selects.with_or({
+        ":gpu_inference_disabled": [],
         "//mediapipe:ios": [
             "//mediapipe/gpu:MPPMetalHelper",
             "//mediapipe/gpu:MPPMetalUtil",
@@ -285,6 +294,7 @@ cc_library(
     }),
     visibility = ["//visibility:public"],
     deps = [
+        "//mediapipe/util/tflite:config",
        ":util",
        ":tflite_converter_calculator_cc_proto",
        "//mediapipe/util:resource_util",
@@ -295,23 +305,26 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-    ] + select({
-        "//mediapipe/gpu:disable_gpu": [],
+    ] + selects.with_or({
+        ":gpu_inference_disabled": [],
         "//mediapipe:ios": [
             "//mediapipe/gpu:MPPMetalUtil",
-            "//mediapipe/gpu:gpu_buffer",
             "//mediapipe/gpu:MPPMetalHelper",
             "//mediapipe/objc:mediapipe_framework_ios",
             "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
         ],
         "//conditions:default": [
-            "//mediapipe/gpu:gpu_buffer",
             "//mediapipe/gpu:gl_calculator_helper",
             "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
             "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
             "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
             "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
         ],
+    }) + select({
+        "//mediapipe/gpu:disable_gpu": [],
+        "//conditions:default": [
+            "//mediapipe/gpu:gpu_buffer",
+        ],
     }),
     alwayslink = 1,
 )
@@ -348,8 +361,8 @@ cc_library(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/util:resource_util",
        "@org_tensorflow//tensorflow/lite:framework",
-    ] + select({
-        "//mediapipe/gpu:disable_gpu": [],
+    ] + selects.with_or({
+        ":gpu_inference_disabled": [],
         "//mediapipe:ios": [],
         "//conditions:default": [
             "//mediapipe/gpu:gl_calculator_helper",
@@ -404,6 +417,7 @@ cc_library(
     }),
     visibility = ["//visibility:public"],
     deps = [
+        "//mediapipe/util/tflite:config",
        ":util",
        ":tflite_tensors_to_detections_calculator_cc_proto",
        "//mediapipe/framework/formats:detection_cc_proto",
@@ -415,8 +429,8 @@ cc_library(
        "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "@org_tensorflow//tensorflow/lite:framework",
-    ] + select({
-        "//mediapipe/gpu:disable_gpu": [],
+    ] + selects.with_or({
+        ":gpu_inference_disabled": [],
         "//mediapipe:ios": [
             "//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer", @@ -492,6 +506,8 @@ cc_library( alwayslink = 1, ) +# To run this with native GPU on Linux, use: +# bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local cc_test( name = "tflite_inference_calculator_test", srcs = ["tflite_inference_calculator_test.cc"], diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index 76bac09e4..6a3011141 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -22,19 +22,23 @@ #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" +#include "mediapipe/util/tflite/config.h" #include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" -#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) -#include "mediapipe/gpu/gl_calculator_helper.h" +#ifndef MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" +#endif // MEDIAPIPE_DISABLE_GPU + +#if MEDIAPIPE_TFLITE_GL_INFERENCE +#include "mediapipe/gpu/gl_calculator_helper.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE -#if defined(MEDIAPIPE_IOS) +#if MEDIAPIPE_TFLITE_METAL_INFERENCE #import #import #import @@ -43,13 +47,7 @@ #include "mediapipe/gpu/MPPMetalUtil.h" #include "mediapipe/gpu/gpu_buffer.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h" -#endif // iOS - -#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) -typedef ::tflite::gpu::gl::GlBuffer GpuTensor; -#elif defined(MEDIAPIPE_IOS) -typedef id GpuTensor; -#endif +#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. 
@@ -73,7 +71,7 @@ constexpr char kMatrixTag[] = "MATRIX";
 namespace mediapipe {
 
 namespace {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
 using ::tflite::gpu::gl::GlProgram;
 using ::tflite::gpu::gl::GlShader;
@@ -83,13 +81,13 @@ struct GPUData {
   GlShader shader;
   GlProgram program;
 };
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
 struct GPUData {
   int elements = 1;
   GpuTensor buffer;
   id<MTLComputePipelineState> pipeline_state;
 };
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
 }  // namespace
 
@@ -157,13 +155,13 @@ class TfLiteConverterCalculator : public CalculatorBase {
 
   std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   mediapipe::GlCalculatorHelper gpu_helper_;
   std::unique_ptr<GPUData> gpu_data_out_;
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   MPPMetalHelper* gpu_helper_ = nullptr;
   std::unique_ptr<GPUData> gpu_data_out_;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
   bool initialized_ = false;
   bool use_gpu_ = false;
@@ -178,6 +176,18 @@ class TfLiteConverterCalculator : public CalculatorBase {
 };
 REGISTER_CALCULATOR(TfLiteConverterCalculator);
 
+namespace {
+template <class CC>
+bool ShouldUseGpu(CC* cc) {
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
+  return cc->Inputs().HasTag(kGpuBufferTag) ||
+         cc->Outputs().HasTag(kTensorsGpuTag);
+#else
+  return false;
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
+}
+}  // namespace
+
 ::mediapipe::Status TfLiteConverterCalculator::GetContract(
     CalculatorContract* cc) {
   // Confirm only one of the input streams is present.
@@ -189,37 +199,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
   RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^
             cc->Outputs().HasTag(kTensorsGpuTag));
 
-  bool use_gpu = false;
-
   if (cc->Inputs().HasTag(kImageFrameTag)) {
     cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
   }
   if (cc->Inputs().HasTag(kMatrixTag)) {
     cc->Inputs().Tag(kMatrixTag).Set<Matrix>();
   }
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
+#ifndef MEDIAPIPE_DISABLE_GPU
   if (cc->Inputs().HasTag(kGpuBufferTag)) {
     cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
-    use_gpu |= true;
   }
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // MEDIAPIPE_DISABLE_GPU
 
   if (cc->Outputs().HasTag(kTensorsTag)) {
     cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
   }
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
   if (cc->Outputs().HasTag(kTensorsGpuTag)) {
     cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
-    use_gpu |= true;
   }
-#endif  // !MEDIAPIPE_DISABLE_GPU
 
-  if (use_gpu) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+  if (ShouldUseGpu(cc)) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   }
 
   // Assign this calculator's default InputStreamHandler.
@@ -233,14 +237,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
 
   MP_RETURN_IF_ERROR(LoadOptions(cc));
 
-  if (cc->Inputs().HasTag(kGpuBufferTag) ||
-      cc->Outputs().HasTag(kGpuBufferTag)) {
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-    use_gpu_ = true;
-#else
-    RET_CHECK_FAIL() << "GPU processing not enabled.";
-#endif
-  }
+  use_gpu_ = ShouldUseGpu(cc);
 
   if (use_gpu_) {
     // Cannot mix CPU/GPU streams.
     RET_CHECK(cc->Inputs().HasTag(kGpuBufferTag) &&
               cc->Outputs().HasTag(kTensorsGpuTag));
     // Cannot use quantization.
    use_quantized_tensors_ = false;
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
     RET_CHECK(gpu_helper_);
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   } else {
     interpreter_ = absl::make_unique<tflite::Interpreter>();
     interpreter_->AddTensors(1);
@@ -282,12 +279,12 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
 }
 
 ::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+  interpreter_.reset();
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
-#endif
-#if defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   gpu_data_out_.reset();
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   return ::mediapipe::OkStatus();
 }
 
@@ -318,8 +315,14 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
     RET_CHECK(format != mediapipe::ImageFormat::VEC32F1)
         << "Only 8-bit input images are supported for quantization.";
     quant.type = kTfLiteAffineQuantization;
-    quant.params = nullptr;
-    // Optional: Set 'quant' quantization params here if needed.
+    auto quant_params = static_cast<TfLiteAffineQuantization*>(
+        malloc(sizeof(TfLiteAffineQuantization)));
+    quant_params->scale = TfLiteFloatArrayCreate(1);
+    quant_params->scale->data[0] = 1.0;
+    quant_params->zero_point = TfLiteIntArrayCreate(1);
+    quant_params->zero_point->data[0] = 0;
+    quant_params->quantized_dimension = 0;
+    quant.params = quant_params;
     interpreter_->SetTensorParametersReadWrite(0, kTfLiteUInt8, "",
                                                {channels_preserved}, quant);
   } else {
@@ -414,7 +417,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
 
 ::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
     CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   // GpuBuffer to tflite::gpu::GlBuffer conversion.
   const auto& input =
       cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@@ -451,7 +454,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
   cc->Outputs()
       .Tag(kTensorsGpuTag)
       .Add(output_tensors.release(), cc->InputTimestamp());
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   // GpuBuffer to id<MTLBuffer> conversion.
   const auto& input =
       cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@@ -490,13 +493,13 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
       .Add(output_tensors.release(), cc->InputTimestamp());
 #else
   RET_CHECK_FAIL() << "GPU processing is not enabled.";
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
   return ::mediapipe::OkStatus();
 }
 
 ::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
   // Get input image sizes.
   const auto& input =
       cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
@@ -512,9 +515,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
     RET_CHECK_FAIL() << "Unsupported GPU input format.";
   if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
     RET_CHECK_FAIL() << "Num input channels is less than desired output.";
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
         // Device memory.
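One detail worth noting in the quantization change above: the `TfLiteAffineQuantization` struct is heap-allocated with `malloc` because TfLite takes ownership of `quant.params` for `kTfLiteAffineQuantization` tensors and frees it together with the tensor. A standalone sketch of the same identity-quantization setup (illustrative only, not the calculator's code):

```c++
#include "tensorflow/lite/c/common.h"

// Sketch: identity quantization params (scale 1.0, zero point 0) for an
// 8-bit tensor, mirroring the converter change above. The struct and its
// arrays must be heap-allocated because TfLite frees quant.params when the
// tensor is destroyed.
TfLiteQuantization MakeIdentityQuantization() {
  TfLiteQuantization quant;
  quant.type = kTfLiteAffineQuantization;
  auto* params = static_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  params->scale = TfLiteFloatArrayCreate(1);
  params->scale->data[0] = 1.0f;
  params->zero_point = TfLiteIntArrayCreate(1);
  params->zero_point->data[0] = 0;
  params->quantized_dimension = 0;
  quant.params = params;
  return quant;
}
```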
@@ -559,7 +562,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
 
         return ::mediapipe::OkStatus();
       }));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   RET_CHECK(include_alpha)
       << "iOS GPU inference currently accepts only RGBA input.";
 
@@ -616,7 +619,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
   RET_CHECK(gpu_data_out_->pipeline_state != nil)
       << "Couldn't create pipeline state "
       << [[error localizedDescription] UTF8String];
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
   return ::mediapipe::OkStatus();
 }
diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
index 96f48da4d..cd881102d 100644
--- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
@@ -22,6 +22,7 @@
 #include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/util/tflite/config.h"
 
 #if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
 #include "mediapipe/util/cpu_util.h"
@@ -33,7 +34,7 @@
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "mediapipe/gpu/gpu_buffer.h"
 #include "mediapipe/util/tflite/tflite_gpu_runner.h"
@@ -42,9 +43,9 @@
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
-#endif  // !MEDIAPIPE_DISABLE_GL_COMPUTE
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
-#if defined(MEDIAPIPE_IOS)
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
 #import <CoreVideo/CoreVideo.h>
 #import <Metal/Metal.h>
 #import <MetalKit/MetalKit.h>
@@ -56,7 +57,7 @@
 #include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
 #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
 #include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
-#endif  // iOS
+#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE
 
 #if !defined(MEDIAPIPE_EDGE_TPU)
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
@@ -71,12 +72,6 @@ int NumGroups(const int size, const int group_size) {  // NOLINT
   return (size + group_size - 1) / group_size;
 }
 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
-#elif defined(MEDIAPIPE_IOS)
-typedef id<MTLBuffer> GpuTensor;
-#endif
-
 // Round up n to next multiple of m.
 size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; }  // NOLINT
 
@@ -112,13 +107,13 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
 // * Aux
 namespace mediapipe {
 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 using ::tflite::gpu::gl::CopyBuffer;
 using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
 using ::tflite::gpu::gl::GlBuffer;
 #endif
 
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
 namespace {
 struct GPUData {
   int elements = 1;
   GpuTensor buffer;
   ::tflite::gpu::BHWC shape;
 };
 }  // namespace
-#endif
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
 
 // Returns number of threads to configure XNNPACK delegate with.
 // (Equal to user provided value if specified. Otherwise, it returns number of
@@ -152,7 +147,7 @@ int GetXnnpackNumThreads(
 
 // Creates an interpreter with given model and calls invoke().
 // Optionally run inference on CPU/GPU.
 //
-// This calculator is designed to be used with the TfLiteConverterCalcualtor,
+// This calculator is designed to be used with the TfLiteConverterCalculator,
 // to get the appropriate inputs.
 //
 // When the input tensors are on CPU, gpu inference is optional and can be
@@ -183,7 +178,6 @@
 //   options: {
 //     [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
 //       model_path: "modelname.tflite"
-//       delegate { gpu {} }
 //     }
 //   }
 // }
@@ -192,11 +186,12 @@
 //
 // node {
 //   calculator: "TfLiteInferenceCalculator"
-//   input_stream: "TENSORS:tensor_image"
+//   input_stream: "TENSORS_GPU:tensor_image"
 //   input_side_packet: "MODEL:model"
-//   output_stream: "TENSORS:tensors"
+//   output_stream: "TENSORS_GPU:tensors"
 //   options: {
 //     [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
+//       model_path: "modelname.tflite"
 //       delegate { gpu {} }
 //     }
 //   }
@@ -228,24 +223,45 @@ class TfLiteInferenceCalculator : public CalculatorBase {
   ::mediapipe::Status LoadModel(CalculatorContext* cc);
   ::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
   ::mediapipe::Status LoadDelegate(CalculatorContext* cc);
-  ::mediapipe::Status InitTFLiteGPURunner();
+  ::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
+  ::mediapipe::Status ProcessInputsCpu(
+      CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
+  ::mediapipe::Status ProcessOutputsCpu(
+      CalculatorContext* cc,
+      std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
+  ::mediapipe::Status ProcessInputsGpu(
+      CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
+  ::mediapipe::Status ProcessOutputsGpu(
+      CalculatorContext* cc,
+      std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
+      std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
+
+  ::mediapipe::Status RunInContextIfNeeded(
+      std::function<::mediapipe::Status(void)> f) {
+    if (gpu_inference_) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+      return gpu_helper_.RunInGlContext(std::move(f));
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
+    }
+    return f();
+  }
 
   Packet model_packet_;
   std::unique_ptr<tflite::Interpreter> interpreter_;
   TfLiteDelegatePtr delegate_;
 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   mediapipe::GlCalculatorHelper gpu_helper_;
   std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
   std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
   std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   MPPMetalHelper* gpu_helper_ = nullptr;
   std::vector<std::unique_ptr<GPUData>> gpu_data_in_;
   std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
   id<MTLComputePipelineState> fp32_to_fp16_program_;
   TFLBufferConvert* converter_from_BPHWC4_ = nil;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
 #if defined(MEDIAPIPE_EDGE_TPU)
   std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_ =
@@ -263,6 +279,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
 
 // Calculator Core Section
 
+namespace {
+template <class CC>
+bool ShouldUseGpu(CC* cc) {
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
+  const auto& options =
+      cc->template Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
+  return options.use_gpu() ||
+         (options.has_delegate() && options.delegate().has_gpu()) ||
+         cc->Inputs().HasTag(kTensorsGpuTag) ||
+         cc->Outputs().HasTag(kTensorsGpuTag);
+#else
+  return false;
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
+}
+}  // namespace
+
 ::mediapipe::Status TfLiteInferenceCalculator::GetContract(
     CalculatorContract* cc) {
   RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^
@@ -276,32 +308,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
             cc->InputSidePackets().HasTag("MODEL"))
       << "Either model as side packet or model path in options is required.";
 
-  bool use_gpu =
-      options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
-
   if (cc->Inputs().HasTag(kTensorsTag))
     cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-  if (cc->Inputs().HasTag(kTensorsGpuTag)) {
-    RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
-        << "GPU input is compatible with GPU delegate only.";
-
-    cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
-    use_gpu |= true;
-  }
-#endif  // !MEDIAPIPE_DISABLE_GPU
-
   if (cc->Outputs().HasTag(kTensorsTag))
     cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-  if (cc->Outputs().HasTag(kTensorsGpuTag)) {
-    RET_CHECK(!options.has_delegate() || options.delegate().has_gpu())
-        << "GPU output is compatible with GPU delegate only.";
+  if (cc->Inputs().HasTag(kTensorsGpuTag))
+    cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
+  if (cc->Outputs().HasTag(kTensorsGpuTag))
     cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
-    use_gpu |= true;
-  }
-#endif  // !MEDIAPIPE_DISABLE_GPU
 
   if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
     cc->InputSidePackets()
@@ -312,10 +327,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
     cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
   }
 
-  if (use_gpu) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+  if (ShouldUseGpu(cc)) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
 #endif
   }
@@ -331,123 +346,181 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
   const auto& options =
       cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
 
-  gpu_inference_ = options.use_gpu();
-
-  if (cc->Inputs().HasTag(kTensorsGpuTag)) {
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-    gpu_input_ = true;
-    gpu_inference_ = true;  // Inference must be on GPU also.
-#else
-    RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
-        << "GPU processing not enabled.";
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  }
+  gpu_inference_ = ShouldUseGpu(cc);
+  gpu_input_ = cc->Inputs().HasTag(kTensorsGpuTag);
+  gpu_output_ = cc->Outputs().HasTag(kTensorsGpuTag);
 
-  if (cc->Outputs().HasTag(kTensorsGpuTag)) {
-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
-    gpu_output_ = true;
-    RET_CHECK(cc->Inputs().HasTag(kTensorsGpuTag))
-        << "GPU output must also have GPU Input.";
-#else
-    RET_CHECK(!cc->Inputs().HasTag(kTensorsGpuTag))
-        << "GPU processing not enabled.";
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  }
-
-  use_advanced_gpu_api_ = false;
-  if (use_advanced_gpu_api_ && !(gpu_input_ && gpu_output_)) {
-    LOG(WARNING)
-        << "Cannot use advanced GPU APIs, both inputs and outputs must "
-           "be GPU buffers. Falling back to the default TFLite API.";
+  use_advanced_gpu_api_ = MEDIAPIPE_TFLITE_GL_INFERENCE &&
+                          options.has_delegate() &&
+                          options.delegate().has_gpu() &&
+                          options.delegate().gpu().use_advanced_gpu_api();
+  if (use_advanced_gpu_api_ && !gpu_input_) {
+    LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers. "
+                    "Falling back to the default TFLite API.";
    use_advanced_gpu_api_ = false;
   }
+  CHECK(!use_advanced_gpu_api_ || gpu_inference_);
 
   MP_RETURN_IF_ERROR(LoadModel(cc));
 
   if (gpu_inference_) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
-#elif defined(MEDIAPIPE_IOS)
-    gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
-    RET_CHECK(gpu_helper_);
-#endif
-
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
     MP_RETURN_IF_ERROR(
         gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
-          return use_advanced_gpu_api_ ? InitTFLiteGPURunner()
+          return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
                                        : LoadDelegate(cc);
         }));
-    if (use_advanced_gpu_api_) return ::mediapipe::OkStatus();
-#else
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
+    gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
+    RET_CHECK(gpu_helper_);
     MP_RETURN_IF_ERROR(LoadDelegate(cc));
 #endif
   } else {
-#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID)
+    // TODO: why only on these platforms?
+    // It seems that the XNNPACK delegate fails to load on Linux.
+#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_ANDROID) || \
+    defined(MEDIAPIPE_IOS)
     MP_RETURN_IF_ERROR(LoadDelegate(cc));
-#endif  // __EMSCRIPTEN__ || ANDROID
+#endif  // __EMSCRIPTEN__ || MEDIAPIPE_ANDROID || MEDIAPIPE_IOS
   }
 
   return ::mediapipe::OkStatus();
 }
 
 ::mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) {
-  // 0. Declare outputs
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) || defined(MEDIAPIPE_IOS)
-  auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
-#endif
-  auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
+  return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
+    // 0. Declare outputs
+    auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
+    auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
 
-  // 1. Receive pre-processed tensor inputs.
-  if (use_advanced_gpu_api_ && gpu_output_) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-    if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
-      return ::mediapipe::OkStatus();
+    // 1. Receive pre-processed tensor inputs.
+    if (gpu_input_) {
+      MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get()));
+    } else {
+      MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get()));
     }
+
+    // 2. Run inference.
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+    if (gpu_inference_ && use_advanced_gpu_api_) {
+      RET_CHECK(tflite_gpu_runner_->Invoke().ok());
+    } else {
+      RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+    }
+#else
+    RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
+
+    // 3. Output processed tensors.
+    if (gpu_output_ || use_advanced_gpu_api_) {
+      MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu),
+                                           std::move(output_tensors_gpu)));
+    } else {
+      MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu)));
+    }
+
+    return ::mediapipe::OkStatus();
+  });
+}
+
+::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
+  return RunInContextIfNeeded([this]() -> ::mediapipe::Status {
+    if (delegate_) {
+      interpreter_ = nullptr;
+      delegate_ = nullptr;
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
+      if (gpu_inference_) {
+        for (int i = 0; i < gpu_data_in_.size(); ++i) {
+          gpu_data_in_[i].reset();
+        }
+        for (int i = 0; i < gpu_data_out_.size(); ++i) {
+          gpu_data_out_[i].reset();
+        }
+      }
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED
+    }
+#if defined(MEDIAPIPE_EDGE_TPU)
+    edgetpu_context_.reset();
+#endif
+    return ::mediapipe::OkStatus();
+  });
+}
+
+// Calculator Auxiliary Section
+
+::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
+    CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
+  if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
+    return ::mediapipe::OkStatus();
+  }
+  // Read CPU input into tensors.
+  const auto& input_tensors =
+      cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
+  RET_CHECK_GT(input_tensors.size(), 0);
+  for (int i = 0; i < input_tensors.size(); ++i) {
+    const TfLiteTensor* input_tensor = &input_tensors[i];
+    RET_CHECK(input_tensor->data.raw);
+    if (use_quantized_tensors_) {
+      const uint8* input_tensor_buffer = input_tensor->data.uint8;
+      uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
+      std::memcpy(local_tensor_buffer, input_tensor_buffer,
+                  input_tensor->bytes);
+    } else {
+      const float* input_tensor_buffer = input_tensor->data.f;
+      float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
+      std::memcpy(local_tensor_buffer, input_tensor_buffer,
+                  input_tensor->bytes);
+    }
+  }
+
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu(
+    CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu) {
+  if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
+    return ::mediapipe::OkStatus();
+  }
+  if (use_advanced_gpu_api_) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     const auto& input_tensors =
         cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
     RET_CHECK(!input_tensors.empty());
-    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
-        [this, &input_tensors, &output_tensors_gpu]() -> ::mediapipe::Status {
-          for (int i = 0; i < input_tensors.size(); ++i) {
-            MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
-                input_tensors[i].id(), i));
-          }
-          // Allocate output tensor.
-          output_tensors_gpu->resize(gpu_data_out_.size());
-          for (int i = 0; i < gpu_data_out_.size(); ++i) {
-            GpuTensor& tensor = output_tensors_gpu->at(i);
-            RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer(
-                gpu_data_out_[i]->elements, &tensor));
-            MP_RETURN_IF_ERROR(
-                tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
-          }
-          return ::mediapipe::OkStatus();
-        }));
-#endif
+    for (int i = 0; i < input_tensors.size(); ++i) {
+      MP_RETURN_IF_ERROR(
+          tflite_gpu_runner_->BindSSBOToInputTensor(input_tensors[i].id(), i));
+    }
+    if (gpu_output_) {
+      // Allocate new output tensor.
+      output_tensors_gpu->resize(gpu_data_out_.size());
+      for (int i = 0; i < gpu_data_out_.size(); ++i) {
+        GpuTensor& tensor = output_tensors_gpu->at(i);
+        RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer(
+            gpu_data_out_[i]->elements, &tensor));
+        MP_RETURN_IF_ERROR(
+            tflite_gpu_runner_->BindSSBOToOutputTensor(tensor.id(), i));
+      }
+    } else {
+      // Re-use internal output tensor.
+      for (int i = 0; i < gpu_data_out_.size(); ++i) {
+        MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
+            gpu_data_out_[i]->buffer.id(), i));
+      }
+    }
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   } else if (gpu_input_) {
     // Read GPU input into SSBO.
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-    if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
-      return ::mediapipe::OkStatus();
-    }
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     const auto& input_tensors =
         cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
     RET_CHECK_GT(input_tensors.size(), 0);
-    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
-        [this, &input_tensors]() -> ::mediapipe::Status {
-          // Explicit copy input.
-          gpu_data_in_.resize(input_tensors.size());
-          for (int i = 0; i < input_tensors.size(); ++i) {
-            RET_CHECK_CALL(
-                CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
-          }
-
-          return ::mediapipe::OkStatus();
-        }));
-#elif defined(MEDIAPIPE_IOS)
-    if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) {
-      return ::mediapipe::OkStatus();
-    }
+    // Explicit copy input.
+    gpu_data_in_.resize(input_tensors.size());
+    for (int i = 0; i < input_tensors.size(); ++i) {
+      RET_CHECK_CALL(CopyBuffer(input_tensors[i], gpu_data_in_[i]->buffer));
+    }
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     const auto& input_tensors =
         cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
     RET_CHECK_GT(input_tensors.size(), 0);
@@ -470,79 +543,70 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
     }
     [compute_encoder endEncoding];
     [command_buffer commit];
-#else
-    RET_CHECK_FAIL() << "GPU processing not enabled.";
-#endif
-  } else {
-    if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
-      return ::mediapipe::OkStatus();
-    }
-    // Read CPU input into tensors.
-    const auto& input_tensors =
-        cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
-    RET_CHECK_GT(input_tensors.size(), 0);
-    for (int i = 0; i < input_tensors.size(); ++i) {
-      const TfLiteTensor* input_tensor = &input_tensors[i];
-      RET_CHECK(input_tensor->data.raw);
-      if (use_quantized_tensors_) {
-        const uint8* input_tensor_buffer = input_tensor->data.uint8;
-        uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
-        std::memcpy(local_tensor_buffer, input_tensor_buffer,
-                    input_tensor->bytes);
-      } else {
-        const float* input_tensor_buffer = input_tensor->data.f;
-        float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
-        std::memcpy(local_tensor_buffer, input_tensor_buffer,
-                    input_tensor->bytes);
-      }
-    }
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   }
 
-  // 2. Run inference.
-  if (gpu_inference_) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-    MP_RETURN_IF_ERROR(
-        gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
-          if (use_advanced_gpu_api_) {
-            RET_CHECK(tflite_gpu_runner_->Invoke().ok());
-          } else {
-            RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
-          }
-          return ::mediapipe::OkStatus();
-        }));
-#elif defined(MEDIAPIPE_IOS)
-    RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
-#endif
-  } else {
-    RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
-  }
+  return ::mediapipe::OkStatus();
+}
 
-  // 3. Output processed tensors.
+::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
+    CalculatorContext* cc,
+    std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
+  // Output result tensors (CPU).
+  const auto& tensor_indexes = interpreter_->outputs();
+  for (int i = 0; i < tensor_indexes.size(); ++i) {
+    TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
+    output_tensors_cpu->emplace_back(*tensor);
+  }
+  cc->Outputs()
+      .Tag(kTensorsTag)
+      .Add(output_tensors_cpu.release(), cc->InputTimestamp());
+
+  return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
+    CalculatorContext* cc,
+    std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
+    std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
   if (use_advanced_gpu_api_) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-    cc->Outputs()
-        .Tag(kTensorsGpuTag)
-        .Add(output_tensors_gpu.release(), cc->InputTimestamp());
-#endif
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+    if (gpu_output_) {
+      // Send out pre-allocated tensors.
+      cc->Outputs()
+          .Tag(kTensorsGpuTag)
+          .Add(output_tensors_gpu.release(), cc->InputTimestamp());
+    } else {
+      // Download to CPU for output.
+      const auto& tensor_indexes = interpreter_->inputs();
+      for (int i = 0; i < tensor_indexes.size(); ++i) {
+        TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
+        std::vector<float> gpu_data(tensor->bytes / sizeof(float));
+        RET_CHECK_CALL(gpu_data_out_[i]->buffer.Read(
+            absl::MakeSpan(tensor->data.f, tensor->bytes)));
+        output_tensors_cpu->emplace_back(*tensor);
+      }
+      // Output result tensors (CPU).
+      cc->Outputs()
+          .Tag(kTensorsTag)
+          .Add(output_tensors_cpu.release(), cc->InputTimestamp());
+    }
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   } else if (gpu_output_) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     // Output result tensors (GPU).
-    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
-        [this, &output_tensors_gpu]() -> ::mediapipe::Status {
-          output_tensors_gpu->resize(gpu_data_out_.size());
-          for (int i = 0; i < gpu_data_out_.size(); ++i) {
-            GpuTensor& tensor = output_tensors_gpu->at(i);
-            // Allocate output tensor.
-            RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer(
-                gpu_data_out_[i]->elements, &tensor));
-            RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
-          }
-          return ::mediapipe::OkStatus();
-        }));
+    output_tensors_gpu->resize(gpu_data_out_.size());
+    for (int i = 0; i < gpu_data_out_.size(); ++i) {
+      GpuTensor& tensor = output_tensors_gpu->at(i);
+      // Allocate output tensor.
+      RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer(
+          gpu_data_out_[i]->elements, &tensor));
+      RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
+    }
     cc->Outputs()
         .Tag(kTensorsGpuTag)
         .Add(output_tensors_gpu.release(), cc->InputTimestamp());
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     // Output result tensors (GPU).
     output_tensors_gpu->resize(gpu_data_out_.size());
     id<MTLDevice> device = gpu_helper_.mtlDevice;
@@ -566,68 +630,58 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
     cc->Outputs()
         .Tag(kTensorsGpuTag)
        .Add(output_tensors_gpu.release(), cc->InputTimestamp());
-#else
-    RET_CHECK_FAIL() << "GPU processing not enabled.";
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  } else {
-    // Output result tensors (CPU).
-    const auto& tensor_indexes = interpreter_->outputs();
-    for (int i = 0; i < tensor_indexes.size(); ++i) {
-      TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
-      output_tensors_cpu->emplace_back(*tensor);
-    }
-    cc->Outputs()
-        .Tag(kTensorsTag)
-        .Add(output_tensors_cpu.release(), cc->InputTimestamp());
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   }

   return ::mediapipe::OkStatus();
 }

-::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
-  if (delegate_) {
-    if (gpu_inference_) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-      MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
-        interpreter_ = nullptr;
-        delegate_ = nullptr;
-        for (int i = 0; i < gpu_data_in_.size(); ++i) {
-          gpu_data_in_[i].reset();
-        }
-        for (int i = 0; i < gpu_data_out_.size(); ++i) {
-          gpu_data_out_[i].reset();
-        }
-        return ::mediapipe::OkStatus();
-      }));
-#elif defined(MEDIAPIPE_IOS)
-      interpreter_ = nullptr;
-      delegate_ = nullptr;
-      for (int i = 0; i < gpu_data_in_.size(); ++i) {
-        gpu_data_in_[i].reset();
-      }
-      for (int i = 0; i < gpu_data_out_.size(); ++i) {
-        gpu_data_out_[i].reset();
-      }
-#endif
-    } else {
-      interpreter_ = nullptr;
-      delegate_ = nullptr;
-    }
+::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
+    CalculatorContext* cc) {
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
+  const auto& model = *model_packet_.Get<TfLiteModelPtr>();
+  tflite::ops::builtin::BuiltinOpResolver op_resolver;
+  if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
+    op_resolver = cc->InputSidePackets()
+                      .Tag("CUSTOM_OP_RESOLVER")
+                      .Get<tflite::ops::builtin::BuiltinOpResolver>();
   }
-#if defined(MEDIAPIPE_EDGE_TPU)
-  edgetpu_context_.reset();
-#endif
-  return ::mediapipe::OkStatus();
-}

-// Calculator Auxiliary Section
+  // Create runner
+  tflite::gpu::InferenceOptions options;
+  options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
+  options.priority2 = tflite::gpu::InferencePriority::AUTO;
+  options.priority3 = tflite::gpu::InferencePriority::AUTO;
+  options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
+  tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
+  RET_CHECK_CALL(tflite_gpu_runner_->InitializeWithModel(model, op_resolver));
+
+  // Allocate interpreter memory for cpu output.
+  if (!gpu_output_) {
+    interpreter_ = absl::make_unique<tflite::Interpreter>();
+    const int num_outputs = tflite_gpu_runner_->GetOutputShapes().size();
+    interpreter_->AddTensors(num_outputs);
+    std::vector<int> indices(num_outputs);
+    for (int i = 0; i < num_outputs; ++i) indices[i] = i;
+    // There is no ResizeOutputTensor(), so we use 'inputs' space instead.
+    interpreter_->SetInputs(indices);
+    TfLiteQuantization quant;
+    quant.type = kTfLiteNoQuantization;
+    quant.params = nullptr;
+    for (int i = 0; i < num_outputs; ++i) {
+      auto shape = tflite_gpu_runner_->GetOutputShapes()[i];
+      const int tensor_idx = interpreter_->inputs()[i];
+      interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
+                                                 {shape.c}, quant);
+      CHECK(interpreter_->ResizeInputTensor(
+                tensor_idx, {shape.h, shape.w, shape.c}) == kTfLiteOk);
+    }
+    CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
+  }

-::mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner() {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
   // Create and bind OpenGL buffers for outputs.
-  // These buffers are created onve and later their ids are jut passed to the
-  // calculator outputs
-
+  // The buffers are created once and their ids are passed to calculator outputs
   gpu_data_out_.resize(tflite_gpu_runner_->outputs_size());
   for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
     gpu_data_out_[i] = absl::make_unique<GPUData>();
@@ -638,15 +692,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
         gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
   }
   RET_CHECK_CALL(tflite_gpu_runner_->Build());
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
+
   return ::mediapipe::OkStatus();
 }

 ::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
     CalculatorContext* cc) {
+  if (use_advanced_gpu_api_) {
+    // Use InitTFLiteGPURunner for everything.
+    return ::mediapipe::OkStatus();
+  }
+
   ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
   const auto& model = *model_packet_.Get<TfLiteModelPtr>();
-
   tflite::ops::builtin::BuiltinOpResolver op_resolver;
   if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
     op_resolver = cc->InputSidePackets()
@@ -654,19 +713,6 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
                       .Get<tflite::ops::builtin::BuiltinOpResolver>();
   }

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-  if (use_advanced_gpu_api_) {
-    tflite::gpu::InferenceOptions options;
-    options.priority1 = tflite::gpu::InferencePriority::MIN_LATENCY;
-    options.priority2 = tflite::gpu::InferencePriority::AUTO;
-    options.priority3 = tflite::gpu::InferencePriority::AUTO;
-    options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
-    tflite_gpu_runner_ =
-        std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
-    return tflite_gpu_runner_->InitializeWithModel(model, op_resolver);
-  }
-#endif
-
 #if defined(MEDIAPIPE_EDGE_TPU)
   interpreter_ =
       BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
@@ -771,7 +817,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
   return ::mediapipe::OkStatus();
 }

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   // Configure and create the delegate.
   TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
   options.compile_options.precision_loss_allowed = 1;
@@ -832,9 +878,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
   // Must call this last.
   RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
                kTfLiteOk);
-#endif  // OpenGL
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

-#if defined(MEDIAPIPE_IOS)
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
   const int kHalfSize = 2;  // sizeof(half)
   // Configure and create the delegate.
   TFLGpuDelegateOptions options;
@@ -958,7 +1004,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
         "Error initializating output buffer converter");
   }
 }
-#endif  // iOS
+#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE

   return ::mediapipe::OkStatus();
 }
diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto
index d784dc2db..4fc0af932 100644
--- a/mediapipe/calculators/tflite/tflite_inference_calculator.proto
+++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto
@@ -45,6 +45,8 @@ message TfLiteInferenceCalculatorOptions {
   message Gpu {
     // Experimental, Android/Linux only. Use TFLite GPU delegate API2 for
     // the NN inference.
+    // example:
+    //   delegate: { gpu { use_advanced_gpu_api: true } }
     optional bool use_advanced_gpu_api = 1 [default = false];
   }
   // Android only.
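For orientation, the `use_advanced_gpu_api` option documented above is what routes `TfLiteInferenceCalculator` through the new `InitTFLiteGPURunner` path instead of the classic delegate path. A minimal sketch of a node opting in; the stream names and model path here are illustrative, not taken from this patch:

```cpp
// Sketch only: builds the node config via ParseTextProtoOrDie.
// "input_tensors"/"output_tensors" and "model.tflite" are hypothetical.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig::Node MakeTfLiteInferenceNode() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
      R"pb(
        calculator: "TfLiteInferenceCalculator"
        input_stream: "TENSORS_GPU:input_tensors"
        output_stream: "TENSORS_GPU:output_tensors"
        options {
          [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
            model_path: "model.tflite"
            delegate: { gpu { use_advanced_gpu_api: true } }
          }
        }
      )pb");
}
```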
diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
index be6796433..412c07125 100644
--- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
@@ -25,17 +25,18 @@
 #include "mediapipe/framework/formats/location.h"
 #include "mediapipe/framework/formats/object_detection/anchor.pb.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/util/tflite/config.h"
 #include "tensorflow/lite/interpreter.h"

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

-#if defined(MEDIAPIPE_IOS)
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
 #import <CoreVideo/CoreVideo.h>
 #import <Metal/Metal.h>
 #import <MetalKit/MetalKit.h>
@@ -44,7 +45,7 @@
 #include "mediapipe/gpu/MPPMetalUtil.h"
 #include "mediapipe/gpu/gpu_buffer.h"
 #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
-#endif  // iOS
+#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE

 namespace {
 constexpr int kNumInputTensorsWithAnchors = 3;
@@ -56,22 +57,17 @@ constexpr char kTensorsGpuTag[] = "TENSORS_GPU";

 namespace mediapipe {

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
 using ::tflite::gpu::gl::GlShader;
-#endif
-
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
-typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
 typedef ::tflite::gpu::gl::GlProgram GpuProgram;
-#elif defined(MEDIAPIPE_IOS)
-typedef id<MTLBuffer> GpuTensor;
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
 typedef id<MTLComputePipelineState> GpuProgram;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

 namespace {

-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
+#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
 struct GPUData {
   GpuProgram decode_program;
   GpuProgram score_program;
@@ -81,7 +77,7 @@ struct GPUData {
   GpuTensor scored_boxes_buffer;
   GpuTensor raw_scores_buffer;
 };
-#endif
+#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED

 void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
                                std::vector<Anchor>* anchors) {
@@ -181,13 +177,13 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
   std::vector<Anchor> anchors_;
   bool side_packet_anchors_{};

-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   mediapipe::GlCalculatorHelper gpu_helper_;
   std::unique_ptr<GPUData> gpu_data_;
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   MPPMetalHelper* gpu_helper_ = nullptr;
   std::unique_ptr<GPUData> gpu_data_;
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

   bool gpu_input_ = false;
   bool anchors_init_ = false;
@@ -205,12 +201,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
     cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
   }

-#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
   if (cc->Inputs().HasTag(kTensorsGpuTag)) {
     cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
     use_gpu |= true;
   }
-#endif  // !MEDIAPIPE_DISABLE_GPU

   if (cc->Outputs().HasTag("DETECTIONS")) {
     cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
@@ -223,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
   }

   if (use_gpu) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   }

   return ::mediapipe::OkStatus();
@@ -239,12 +233,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

   if (cc->Inputs().HasTag(kTensorsGpuTag)) {
     gpu_input_ = true;
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
     MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
     gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
     RET_CHECK(gpu_helper_);
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   }

   MP_RETURN_IF_ERROR(LoadOptions(cc));
@@ -401,7 +395,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
 }
 ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
     CalculatorContext* cc, std::vector<Detection>* output_detections) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   const auto& input_tensors =
       cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();
   RET_CHECK_GE(input_tensors.size(), 2);
@@ -464,7 +458,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

         return ::mediapipe::OkStatus();
       }));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   const auto& input_tensors =
       cc->Inputs().Tag(kTensorsGpuTag).Get<std::vector<GpuTensor>>();

@@ -546,17 +540,17 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);

 #else
   LOG(ERROR) << "GPU input on non-Android not supported yet.";
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
   return ::mediapipe::OkStatus();
 }

 ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
     CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   gpu_data_.reset();
-#endif
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

   return ::mediapipe::OkStatus();
 }
@@ -705,7 +699,7 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(

 ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
     CalculatorContext* cc) {
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
                                                     -> ::mediapipe::Status {
     gpu_data_ = absl::make_unique<GPUData>();
@@ -918,7 +912,7 @@ void main() {

     return ::mediapipe::OkStatus();
   }));
-#elif defined(MEDIAPIPE_IOS)
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
   gpu_data_ = absl::make_unique<GPUData>();
   id<MTLDevice> device = gpu_helper_.mtlDevice;
@@ -1148,7 +1142,7 @@ kernel void scoreKernel(
     CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
   }
-#endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

   return ::mediapipe::OkStatus();
 }
diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc
index 3c41531e1..bc911efb6 100644
--- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc
@@ -217,11 +217,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
   for (int i = 0; i < output_landmarks.landmark_size(); ++i) {
     const Landmark& landmark = output_landmarks.landmark(i);
     NormalizedLandmark* norm_landmark = output_norm_landmarks.add_landmark();
-    norm_landmark->set_x(static_cast<float>(landmark.x()) /
-                         options_.input_image_width());
-    norm_landmark->set_y(static_cast<float>(landmark.y()) /
-                         options_.input_image_height());
-    norm_landmark->set_z(landmark.z() / options_.normalize_z());
+    norm_landmark->set_x(landmark.x() / options_.input_image_width());
+    norm_landmark->set_y(landmark.y() / options_.input_image_height());
+    // Scale Z coordinate as X + allow additional uniform normalization.
+    norm_landmark->set_z(landmark.z() / options_.input_image_width() /
+                         options_.normalize_z());
     norm_landmark->set_visibility(landmark.visibility());
   }
   cc->Outputs()
diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto
index 3b6716c9c..cbf30c181 100644
--- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto
+++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.proto
@@ -29,7 +29,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
   required int32 num_landmarks = 1;

   // Size of the input image for the model. These options are used only when
-  // normalized landmarks is needed.
+  // normalized landmarks are needed. Z coordinate is scaled as X assuming
+  // a weak perspective projection camera model.
   optional int32 input_image_width = 2;
   optional int32 input_image_height = 3;

@@ -46,6 +47,8 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
   // beforehand.
   optional bool flip_horizontally = 6 [default = false];

-  // A value that z values should be divided by.
+  // A value that Z coordinates should be divided by. This option is used only
+  // when normalized landmarks are needed. It is applied in addition to Z
+  // coordinate being re-scaled as X.
   optional float normalize_z = 5 [default = 1.0];
 }
diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD
index b570e4ca2..7223ad44d 100644
--- a/mediapipe/calculators/util/BUILD
+++ b/mediapipe/calculators/util/BUILD
@@ -376,6 +376,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":timed_box_list_id_to_label_calculator_cc_proto",
+        "@com_google_absl//absl/container:node_hash_map",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:packet",
diff --git a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc
index b79d8e4f0..925272230 100644
--- a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc
+++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc
@@ -122,11 +122,13 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase {
       NormalizedLandmark* new_landmark = output_landmarks.add_landmark();
       const float new_x = (landmark.x() - left) / (1.0f - left_and_right);
       const float new_y = (landmark.y() - top) / (1.0f - top_and_bottom);
+      const float new_z =
+          landmark.z() / (1.0f - left_and_right);  // Scale Z coordinate as X.

       new_landmark->set_x(new_x);
       new_landmark->set_y(new_y);
-      // Keep z-coord as is.
-      new_landmark->set_z(landmark.z());
+      new_landmark->set_z(new_z);
       // Keep visibility as is.
       new_landmark->set_visibility(landmark.visibility());
     }
diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc
index 61986672c..0309c530a 100644
--- a/mediapipe/calculators/util/landmark_projection_calculator.cc
+++ b/mediapipe/calculators/util/landmark_projection_calculator.cc
@@ -123,11 +123,12 @@ class LandmarkProjectionCalculator : public CalculatorBase {
       new_x = new_x * input_rect.width() + input_rect.x_center();
       new_y = new_y * input_rect.height() + input_rect.y_center();
+      const float new_z =
+          landmark.z() * input_rect.width();  // Scale Z coordinate as X.

       new_landmark->set_x(new_x);
       new_landmark->set_y(new_y);
-      // Keep z-coord as is.
-      new_landmark->set_z(landmark.z());
+      new_landmark->set_z(new_z);
       // Keep visibility as is.
       new_landmark->set_visibility(landmark.visibility());
     }
diff --git a/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc b/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc
index c01327b9b..5d81a7af3 100644
--- a/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc
+++ b/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "absl/container/node_hash_map.h"
 #include "mediapipe/calculators/util/timed_box_list_id_to_label_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/packet.h"
@@ -53,7 +54,7 @@ class TimedBoxListIdToLabelCalculator : public CalculatorBase {
   ::mediapipe::Status Process(CalculatorContext* cc) override;

  private:
-  std::unordered_map<int, std::string> label_map_;
+  absl::node_hash_map<int, std::string> label_map_;
 };
 REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator);
diff --git a/mediapipe/examples/android/README.md b/mediapipe/examples/android/README.md
index 136d37a3f..8ce927727 100644
--- a/mediapipe/examples/android/README.md
+++ b/mediapipe/examples/android/README.md
@@ -1,4 +1 @@
-MediaPipe Examples
-==================
-
-This directory contains MediaPipe Android example applications. Please see [src/java/com/google/mediapipe/apps/README.md](src/java/com/google/mediapipe/apps/README.md) for details.
+This directory contains MediaPipe example applications for Android. Please see [Solutions](https://solutions.mediapipe.dev) for details.
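The three landmark changes above (tensors-to-landmarks, letterbox removal, projection) all apply the same convention: Z is rescaled by the width-based factor used for X, per the weak perspective assumption, with `normalize_z` applied on top. A self-contained sketch of the resulting arithmetic, with illustrative numbers not taken from this patch:

```cpp
// Sketch of the combined Z normalization introduced above.
#include <iostream>

int main() {
  const float z = 30.0f;              // raw landmark Z, in model output units
  const int input_image_width = 256;  // options_.input_image_width()
  const float normalize_z = 1.0f;     // options_.normalize_z()
  // Z is divided by the image width (same scale as X), then by normalize_z.
  const float norm_z = z / input_image_width / normalize_z;
  std::cout << norm_z << std::endl;   // prints 0.117188 (30 / 256)
  return 0;
}
```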
diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/METADATA b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/METADATA deleted file mode 100644 index aee0b0fe7..000000000 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/METADATA +++ /dev/null @@ -1,7 +0,0 @@ -tricorder: { - options: { - builder: { - config: "android_arm64" - } - } -} diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD index 00a3efcdf..fb0e6835f 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD @@ -83,7 +83,7 @@ android_binary( manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", manifest_values = { "applicationId": "com.google.mediapipe.apps.objectdetection3d", - "appName": "Object Detection 3D", + "appName": "Objectron", "mainActivity": ".MainActivity", "cameraFacingFront": "False", "binaryGraphName": "object_detection_3d.binarypb", diff --git a/mediapipe/examples/desktop/README.md b/mediapipe/examples/desktop/README.md index 8e36e42eb..6880098ba 100644 --- a/mediapipe/examples/desktop/README.md +++ b/mediapipe/examples/desktop/README.md @@ -1,113 +1 @@ -**Hello World** - -To build the "Hello World" example, use: - -``` -bazel build -c opt mediapipe/examples/desktop/hello_world:hello_world -``` - -and then run it using: - -``` -export GLOG_logtostderr=1 - -bazel-bin/mediapipe/examples/desktop/hello_world/hello_world -``` - -**TFlite Object Detection** - -To build the object detection demo using a TFLite model on desktop, use: - -``` -bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tflite --define MEDIAPIPE_DISABLE_GPU=1 -``` - -and run it using: - -``` -export GLOG_logtostderr=1 - -bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite \ - --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tflite_graph.pbtxt \ - --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file -``` - -**TensorFlow Object Detection** - -To build the object detection demo using a TensorFlow model on desktop, use: - -``` -export GLOG_logtostderr=1 - -bazel build -c opt mediapipe/examples/desktop/object_detection:object_detection_tensorflow \ - --define MEDIAPIPE_DISABLE_GPU=1 -``` - -and run it using: - -``` -export GLOG_logtostderr=1 - -bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tensorflow \ - --calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_tensorflow_graph.pbtxt \ - --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file -``` - -**TFlite Hand Detection** - -To build the hand detection demo using a TFLite model on desktop, use: - -``` -bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1 -``` - -and run it using: - -``` -export GLOG_logtostderr=1 - -bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \ - --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_detection_desktop.pbtxt \ - --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file -``` - -**TFlite Hand Tracking** - -To build the hand 
tracking demo using a TFLite model on desktop, use:
-
-```
-bazel build -c opt mediapipe/examples/desktop/hand_tracking:hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
-```
-
-and run it using:
-
-```
-export GLOG_logtostderr=1
-
-bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_tflite \
-  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop.pbtxt \
-  --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
-```
-
-**TFlite Multi-Hand Tracking**
-
-To build the multi-hand tracking demo using a TFLite model on desktop, use:
-
-```
-bazel build -c opt mediapipe/examples/desktop/multi_hand_tracking:multi_hand_tracking_tflite --define MEDIAPIPE_DISABLE_GPU=1
-```
-
-and run it using:
-
-```
-export GLOG_logtostderr=1
-
-bazel-bin/mediapipe/examples/desktop/multi_hand_tracking/multi_hand_tracking_tflite \
-  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt \
-  --input_side_packets=input_video_path=/path/to/input/file,output_video_path=/path/to/output/file
-```
-
-To change the number of hands to `x` in this application, change:
-
-1. `min_size:x` in `CollectionHasMinSizeCalculatorOptions` in `mediapipe/graphs/hand_tracking/multi_hand_tracking_desktop.pbtxt`.
-2. `max_vec_size:x` in `ClipVectorSizeCalculatorOptions` in `mediapipe/examples/dekstop/hand_tracking/subgraphs/multi_hand_detection_cpu.pbtxt`.
+This directory contains MediaPipe example applications for desktop. Please see [Solutions](https://solutions.mediapipe.dev) for details.
diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD
index 3b1712924..b645dc69f 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/BUILD
+++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD
@@ -62,8 +62,10 @@ cc_library(
         "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
         "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image_frame",
-        "//mediapipe/framework/formats:image_frame_opencv",
+        "//mediapipe/framework/formats:location_data_cc_proto",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
     ],
@@ -126,17 +128,20 @@ cc_test(
         ":content_zooming_calculator",
         ":content_zooming_calculator_cc_proto",
         "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
+        "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/formats:image_frame_opencv",
+        "//mediapipe/framework/formats:location_data_cc_proto",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:benchmark",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:parse_text_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
-        "@com_google_absl//absl/strings",
     ],
 )
diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc
index 38fb72b06..ee403a5d0 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc
@@ -19,16 +19,20 @@
 #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
 #include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image_frame.h"
-#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/status_builder.h"

 constexpr char kVideoFrame[] = "VIDEO";
 constexpr char kVideoSize[] = "VIDEO_SIZE";
-constexpr char kDetectionSet[] = "DETECTIONS";
+constexpr char kSalientRegions[] = "SALIENT_REGIONS";
+constexpr char kDetections[] = "DETECTIONS";
 constexpr char kDetectedBorders[] = "BORDERS";
+constexpr char kCropRect[] = "CROP_RECT";
 // Field-of-view (degrees) of the camera's x-axis (width).
 // TODO: Parameterize FOV based on camera specs.
 constexpr float kWidthFieldOfView = 60;
@@ -37,12 +41,12 @@ namespace mediapipe {
 namespace autoflip {

 // Content zooming calculator zooms in on content when a detection has
-// "only_required" set true. It does this by computing the value of top/bottom
-// borders to remove from the output and sends these to the
-// SceneCroppingCalculator. When more than one detections are received the zoom
-// box is calculated as the union of the detections. Typical applications
-// include mobile makeover and autofliplive face reframing. Currently only
-// supports y-dimension zooming.
+// "only_required" set true or any raw detection input. It does this by
+// computing the value of top/bottom borders to remove from the output, sending
+// these to the SceneCroppingCalculator using BORDERS output, or a full rect
+// crop using CROP_RECT output. When more than one detection is received the
+// zoom box is calculated as the union of the detections. Typical applications
+// include mobile makeover and autofliplive face reframing.
 class ContentZoomingCalculator : public CalculatorBase {
  public:
   ContentZoomingCalculator()
@@ -56,26 +60,32 @@ class ContentZoomingCalculator : public CalculatorBase {
   ::mediapipe::Status Process(mediapipe::CalculatorContext* cc) override;

  private:
-  // Converts bounds to tilt offset and height.
-  ::mediapipe::Status ConvertToTiltZoom(float xmin, float xmax, float ymin,
-                                        float ymax, int* tilt_offset,
-                                        int* height);
+  // Converts bounds to tilt offset, pan offset and height.
+  ::mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin,
+                                           float ymax, int* tilt_offset,
+                                           int* pan_offset, int* height);
   ContentZoomingCalculatorOptions options_;
   // Detection frame width/height.
   int frame_height_;
   int frame_width_;
   // Path solver used to smooth top/bottom border crop values.
   std::unique_ptr<KinematicPathSolver> path_solver_height_;
+  std::unique_ptr<KinematicPathSolver> path_solver_width_;
   std::unique_ptr<KinematicPathSolver> path_solver_offset_;
   // Are parameters initialized.
   bool initialized_;
   // Stores the time of the last "only_required" input.
   int64 last_only_required_detection_;
-  // Border values of last message with detection.
+  // Rect values of last message with detection(s).
   int last_measured_height_;
+  int last_measured_x_offset_;
   int last_measured_y_offset_;
-  // Min border values.
-  float min_height_value_;
+  // Target aspect ratio.
+  float target_aspect_;
+  // Max size of bounding box. If input/output aspect ratios are the same,
+  // will be 1.0. Else, will be less than 1.0 to prevent exceeding the size of
+  // the image in either dimension.
+  float max_frame_value_;
 };
 REGISTER_CALCULATOR(ContentZoomingCalculator);

@@ -92,8 +102,18 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
     return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
            << "Input VIDEO or VIDEO_SIZE must be provided.";
   }
-  cc->Inputs().Tag(kDetectionSet).Set<DetectionSet>();
-  cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
+  if (cc->Inputs().HasTag(kSalientRegions)) {
+    cc->Inputs().Tag(kSalientRegions).Set<DetectionSet>();
+  }
+  if (cc->Inputs().HasTag(kDetections)) {
+    cc->Inputs().Tag(kDetections).Set<std::vector<mediapipe::Detection>>();
+  }
+  if (cc->Outputs().HasTag(kDetectedBorders)) {
+    cc->Outputs().Tag(kDetectedBorders).Set<StaticFeatures>();
+  }
+  if (cc->Outputs().HasTag(kCropRect)) {
+    cc->Outputs().Tag(kCropRect).Set<mediapipe::Rect>();
+  }
   return ::mediapipe::OkStatus();
 }

@@ -108,29 +128,38 @@ REGISTER_CALCULATOR(ContentZoomingCalculator);
   if (options_.has_min_motion_to_reframe()) {
     return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
            << "Deprecated min_motion_to_reframe was set, please set "
-              "in kinematic_options_zoom and kinematic_options_tilt directly.";
+              "in kinematic_options_zoom and kinematic_options_tilt "
+              "directly.";
   }
   return ::mediapipe::OkStatus();
 }

-::mediapipe::Status ContentZoomingCalculator::ConvertToTiltZoom(
+::mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom(
     float xmin, float xmax, float ymin, float ymax, int* tilt_offset,
-    int* height) {
+    int* pan_offset, int* height) {
   // Find center of the y-axis offset (for tilt control).
   float y_center = ymin + (ymax - ymin) / 2;
+  // Find center of the x-axis offset (for pan control).
+  float x_center = xmin + (xmax - xmin) / 2;
   // Find size and apply scale factor to y-axis.
   float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin);
-  // Apply min zoom for cases where the target size is wider than input frame
-  // size.
-  fit_size = fmin(min_height_value_, fit_size);
+  // Apply max frame for cases where the target size is different than input
+  // frame size.
+  fit_size = fmin(max_frame_value_, fit_size);
   // Prevent box from extending beyond the image.
   if (y_center - fit_size / 2 < 0) {
     y_center = fit_size / 2;
   } else if (y_center + fit_size / 2 > 1) {
     y_center = 1 - fit_size / 2;
   }
+  if (x_center - fit_size / 2 < 0) {
+    x_center = fit_size / 2;
+  } else if (x_center + fit_size / 2 > 1) {
+    x_center = 1 - fit_size / 2;
+  }
   // Scale to pixel coordinates.
   *tilt_offset = frame_height_ * y_center;
+  *pan_offset = frame_width_ * x_center;
   *height = frame_height_ * fit_size;
   return ::mediapipe::OkStatus();
 }
@@ -151,6 +180,20 @@ namespace {
   return ::mediapipe::OkStatus();
 }

+::mediapipe::Status UpdateRanges(const mediapipe::Detection& detection,
+                                 float* xmin, float* xmax, float* ymin,
+                                 float* ymax) {
+  RET_CHECK(detection.location_data().format() ==
+            mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
+      << "Face detection input is lacking required relative_bounding_box()";
+  const auto& location = detection.location_data().relative_bounding_box();
+  *xmin = fmin(*xmin, location.xmin());
+  *xmax = fmax(*xmax, location.xmin() + location.width());
+  *ymin = fmin(*ymin, location.ymin());
+  *ymax = fmax(*ymax, location.ymin() + location.height());
+
+  return ::mediapipe::OkStatus();
+}
 void MakeStaticFeatures(const int top_border, const int bottom_border,
                         const int frame_width, const int frame_height,
                         StaticFeatures* static_feature) {
@@ -173,10 +216,8 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
 ::mediapipe::Status ContentZoomingCalculator::Process(
     mediapipe::CalculatorContext* cc) {
   if (cc->Inputs().HasTag(kVideoFrame)) {
-    cv::Mat frame = mediapipe::formats::MatView(
-        &cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>());
-    frame_width_ = frame.cols;
-    frame_height_ = frame.rows;
+    frame_width_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Width();
+    frame_height_ = cc->Inputs().Tag(kVideoFrame).Get<ImageFrame>().Height();
   } else if (cc->Inputs().HasTag(kVideoSize)) {
     frame_width_ =
         cc->Inputs().Tag(kVideoSize).Get<std::pair<int, int>>().first;
@@ -191,10 +232,14 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
     path_solver_height_ = std::make_unique<KinematicPathSolver>(
         options_.kinematic_options_zoom(), 0, frame_height_,
         static_cast<float>(frame_width_) / kWidthFieldOfView);
+    path_solver_width_ = std::make_unique<KinematicPathSolver>(
+        options_.kinematic_options_pan(), 0, frame_width_,
+        static_cast<float>(frame_width_) / kWidthFieldOfView);
     path_solver_offset_ = std::make_unique<KinematicPathSolver>(
         options_.kinematic_options_tilt(), 0, frame_height_,
         static_cast<float>(frame_width_) / kWidthFieldOfView);
-    min_height_value_ = 1.0;
+    max_frame_value_ = 1.0;
+    target_aspect_ = frame_width_ / static_cast<float>(frame_height_);
     // If target size is set and wider than input aspect, make sure to always
     // crop the min required amount.
     if (options_.has_target_size()) {
@@ -203,75 +248,107 @@ void MakeStaticFeatures(const int top_border, const int bottom_border,
       RET_CHECK_GT(options_.target_size().height(), 0)
           << "Provided target height not valid.";
       float input_aspect = frame_width_ / static_cast<float>(frame_height_);
-      float target_aspect = options_.target_size().width() /
-                            static_cast<float>(options_.target_size().height());
-      min_height_value_ =
-          (input_aspect < target_aspect) ? input_aspect / target_aspect : 1.0;
+      target_aspect_ = options_.target_size().width() /
+                       static_cast<float>(options_.target_size().height());
+      max_frame_value_ = std::min(input_aspect / target_aspect_,
+                                  target_aspect_ / input_aspect);
     }
-    last_measured_height_ = min_height_value_ * frame_height_;
+    last_measured_height_ = max_frame_value_ * frame_height_;
+    last_measured_x_offset_ = target_aspect_ * frame_width_;
     last_measured_y_offset_ = frame_width_ / 2;
     initialized_ = true;
   }

-  auto detection_set = cc->Inputs().Tag(kDetectionSet).Get<DetectionSet>();
   bool only_required_found = false;

   // Compute the box that contains all "is_required" detections.
   float xmin = 1, ymin = 1, xmax = 0, ymax = 0;
-  for (const auto& region : detection_set.detections()) {
-    if (!region.only_required()) {
-      continue;
+  if (cc->Inputs().HasTag(kSalientRegions)) {
+    auto detection_set = cc->Inputs().Tag(kSalientRegions).Get<DetectionSet>();
+    for (const auto& region : detection_set.detections()) {
+      if (!region.only_required()) {
+        continue;
+      }
+      only_required_found = true;
+      MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
+    }
+  }
+
+  if (cc->Inputs().HasTag(kDetections)) {
+    auto raw_detections =
+        cc->Inputs().Tag(kDetections).Get<std::vector<mediapipe::Detection>>();
+    for (const auto& detection : raw_detections) {
+      only_required_found = true;
+      MP_RETURN_IF_ERROR(UpdateRanges(detection, &xmin, &xmax, &ymin, &ymax));
     }
-    only_required_found = true;
-    MP_RETURN_IF_ERROR(UpdateRanges(region, &xmin, &xmax, &ymin, &ymax));
   }

   // Convert bounds to tilt/zoom and in pixel coordinates.
-  int offset, height;
-  MP_RETURN_IF_ERROR(
-      ConvertToTiltZoom(xmin, xmax, ymin, ymax, &offset, &height));
+  int offset_y, height, offset_x;
+  MP_RETURN_IF_ERROR(ConvertToPanTiltZoom(xmin, xmax, ymin, ymax, &offset_y,
+                                          &offset_x, &height));

   if (only_required_found) {
     // A only required detection was found.
     last_only_required_detection_ = cc->InputTimestamp().Microseconds();
     last_measured_height_ = height;
-    last_measured_y_offset_ = offset;
+    last_measured_x_offset_ = offset_x;
+    last_measured_y_offset_ = offset_y;
   } else if (cc->InputTimestamp().Microseconds() -
                  last_only_required_detection_ >=
              options_.us_before_zoomout()) {
-    // No only_require detections found within salient regions packets arriving
-    // since us_before_zoomout duration.
-    height = min_height_value_ * frame_height_;
-    offset = frame_height_ / 2;
+    // No only_required detections found within salient regions packets
+    // arriving since us_before_zoomout duration.
+    height = max_frame_value_ * frame_height_;
+    offset_x = (target_aspect_ * height) / 2;
+    offset_y = frame_height_ / 2;
   } else {
     // No only detection found but using last detection due to
     // duration_before_zoomout_us setting.
     height = last_measured_height_;
-    offset = last_measured_y_offset_;
+    offset_x = last_measured_x_offset_;
+    offset_y = last_measured_y_offset_;
   }

   // Compute smoothed camera paths.
   MP_RETURN_IF_ERROR(path_solver_height_->AddObservation(
       height, cc->InputTimestamp().Microseconds()));
+  MP_RETURN_IF_ERROR(path_solver_width_->AddObservation(
+      offset_x, cc->InputTimestamp().Microseconds()));
   MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation(
-      offset, cc->InputTimestamp().Microseconds()));
-  int path_size;
-  MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_size));
-  int path_offset;
-  MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset));
+      offset_y, cc->InputTimestamp().Microseconds()));
+  int path_height;
+  MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height));
+  int path_offset_x;
+  MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x));
+  int path_offset_y;
+  MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y));

   // Convert to top/bottom borders to remove.
-  int path_top = path_offset - path_size / 2;
-  int path_bottom = frame_height_ - (path_offset + path_size / 2);
+  int path_top = path_offset_y - path_height / 2;
+  int path_bottom = frame_height_ - (path_offset_y + path_height / 2);

-  // Transmit result downstream.
-  std::unique_ptr<StaticFeatures> features =
-      absl::make_unique<StaticFeatures>();
-  MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
-                     features.get());
-  cc->Outputs()
-      .Tag(kDetectedBorders)
-      .AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
+  // Transmit result downstream to SceneCroppingCalculator.
+  if (cc->Outputs().HasTag(kDetectedBorders)) {
+    std::unique_ptr<StaticFeatures> features =
+        absl::make_unique<StaticFeatures>();
+    MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_,
+                       features.get());
+    cc->Outputs()
+        .Tag(kDetectedBorders)
+        .AddPacket(Adopt(features.release()).At(cc->InputTimestamp()));
+  }
+
+  // Transmit downstream to glcroppingcalculator.
+  if (cc->Outputs().HasTag(kCropRect)) {
+    auto gpu_rect = absl::make_unique<mediapipe::Rect>();
+    gpu_rect->set_x_center(path_offset_x);
+    gpu_rect->set_width(path_height * target_aspect_);
+    gpu_rect->set_y_center(path_offset_y);
+    gpu_rect->set_height(path_height);
+    cc->Outputs().Tag(kCropRect).Add(gpu_rect.release(),
+                                     Timestamp(cc->InputTimestamp()));
+  }

   return ::mediapipe::OkStatus();
 }
diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto
index 78d7c9e93..bf0b8201b 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto
+++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto
@@ -32,6 +32,8 @@ message ContentZoomingCalculatorOptions {
   optional KinematicOptions kinematic_options_zoom = 6;
   // Kinematic options for tilt (y-axis reframing.)
   optional KinematicOptions kinematic_options_tilt = 7;
+  // Kinematic options for pan (x-axis reframing.)
+  optional KinematicOptions kinematic_options_pan = 10;
   // Duration (in MicroSeconds) before returning to fully zoomed out position
   // when no "only_required" frames are received.
   optional int64 us_before_zoomout = 9 [default = 1000000];
diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc
index e7398d255..a37e09c57 100644
--- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc
+++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc
@@ -16,10 +16,14 @@

 #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
 #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h"
+#include "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/formats/location_data.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/benchmark.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
@@ -36,14 +40,14 @@ namespace {
 const char kConfigA[] = R"(
     calculator: "ContentZoomingCalculator"
     input_stream: "VIDEO:camera_frames"
-    input_stream: "DETECTIONS:detection_set"
+    input_stream: "SALIENT_REGIONS:detection_set"
    output_stream: "BORDERS:borders"
    )";

 const char kConfigB[] = R"(
     calculator: "ContentZoomingCalculator"
     input_stream: "VIDEO:camera_frames"
-    input_stream: "DETECTIONS:detection_set"
+    input_stream: "SALIENT_REGIONS:detection_set"
     output_stream: "BORDERS:borders"
     options: {
       [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: {
@@ -58,10 +62,17 @@ const char kConfigB[] = R"(
 const char kConfigC[] = R"(
     calculator: "ContentZoomingCalculator"
     input_stream: "VIDEO_SIZE:size"
-    input_stream: "DETECTIONS:detection_set"
+    input_stream: "SALIENT_REGIONS:detection_set"
     output_stream: "BORDERS:borders"
     )";

+const char kConfigD[] = R"(
+    calculator: "ContentZoomingCalculator"
+    input_stream: "VIDEO_SIZE:size"
+    input_stream: "DETECTIONS:detections"
+    output_stream: "CROP_RECT:rect"
+    )";
+
 void CheckBorder(const StaticFeatures& static_features, int width, int height,
                  int top_border, int bottom_border) {
   ASSERT_EQ(2, static_features.border().size());
@@ -80,6 +91,43 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height,
   EXPECT_EQ(Border::BOTTOM, part.relative_position());
 }

+void AddDetection(const cv::Rect_<float>& position, const int64 time,
+                  CalculatorRunner* runner) {
+  auto detections = std::make_unique<std::vector<mediapipe::Detection>>();
+  mediapipe::Detection detection;
+  detection.mutable_location_data()->set_format(
+      mediapipe::LocationData::RELATIVE_BOUNDING_BOX);
+  detection.mutable_location_data()
+      ->mutable_relative_bounding_box()
+      ->set_height(position.height);
+  detection.mutable_location_data()->mutable_relative_bounding_box()->set_width(
+      position.width);
+  detection.mutable_location_data()->mutable_relative_bounding_box()->set_xmin(
+      position.x);
+  detection.mutable_location_data()->mutable_relative_bounding_box()->set_ymin(
+      position.y);
+  detections->push_back(detection);
+  runner->MutableInputs()
+      ->Tag("DETECTIONS")
+      .packets.push_back(Adopt(detections.release()).At(Timestamp(time)));
+
+  auto input_size = ::absl::make_unique<std::pair<int, int>>(1000, 1000);
+  runner->MutableInputs()
+      ->Tag("VIDEO_SIZE")
+      .packets.push_back(Adopt(input_size.release()).At(Timestamp(time)));
+}
+
+void CheckCropRect(const int x_center, const int y_center, const int width,
+                   const int height, const int frame_number,
+                   const std::vector<Packet>& output_packets) {
+  ASSERT_GT(output_packets.size(), frame_number);
+  const auto& rect = output_packets[frame_number].Get<mediapipe::Rect>();
+  EXPECT_EQ(rect.x_center(), x_center);
+  EXPECT_EQ(rect.y_center(), y_center);
+  EXPECT_EQ(rect.width(), width);
+  EXPECT_EQ(rect.height(), height);
+}
+
 TEST(ContentZoomingCalculatorTest, ZoomTest) {
   auto runner = ::absl::make_unique<CalculatorRunner>(
       ParseTextProtoOrDie<CalculatorGraphConfig>(kConfigA));
@@ -98,7 +146,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
       Adopt(input_frame.release()).At(Timestamp(0)));

   runner->MutableInputs()
-      ->Tag("DETECTIONS")
+      ->Tag("SALIENT_REGIONS")
       .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));

   // Run the calculator.
@@ -111,6 +159,66 @@ TEST(ContentZoomingCalculatorTest, ZoomTest) {
   CheckBorder(static_features, 1000, 1000, 495, 395);
 }

+TEST(ContentZoomingCalculatorTest, ZoomTestFullPTZ) {
+  auto runner = ::absl::make_unique<CalculatorRunner>(
+      ParseTextProtoOrDie<CalculatorGraphConfig>(kConfigD));
+  AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
+  MP_ASSERT_OK(runner->Run());
+  CheckCropRect(450, 550, 111, 111, 0,
+                runner->Outputs().Tag("CROP_RECT").packets);
+}
+
+TEST(ContentZoomingCalculatorTest, PanConfig) {
+  auto config = ParseTextProtoOrDie<CalculatorGraphConfig>(kConfigD);
+  auto* options = config.mutable_options()->MutableExtension(
+      ContentZoomingCalculatorOptions::ext);
+  options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(0.0);
+  options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
+  options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
+  auto runner = ::absl::make_unique<CalculatorRunner>(config);
+  AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
+  AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
+  MP_ASSERT_OK(runner->Run());
+  CheckCropRect(450, 550, 111, 111, 0,
+                runner->Outputs().Tag("CROP_RECT").packets);
+  CheckCropRect(488, 550, 111, 111, 1,
+                runner->Outputs().Tag("CROP_RECT").packets);
+}
+
+TEST(ContentZoomingCalculatorTest, TiltConfig) {
+  auto config = ParseTextProtoOrDie<CalculatorGraphConfig>(kConfigD);
+  auto* options = config.mutable_options()->MutableExtension(
+      ContentZoomingCalculatorOptions::ext);
+  options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
+  options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(0.0);
+  options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(5.0);
+  auto runner = ::absl::make_unique<CalculatorRunner>(config);
+  AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
+  AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
+  MP_ASSERT_OK(runner->Run());
+  CheckCropRect(450, 550, 111, 111, 0,
+                runner->Outputs().Tag("CROP_RECT").packets);
+  CheckCropRect(450, 588, 111, 111, 1,
+                runner->Outputs().Tag("CROP_RECT").packets);
+}
+
+TEST(ContentZoomingCalculatorTest, ZoomConfig) {
+  auto config = ParseTextProtoOrDie<CalculatorGraphConfig>(kConfigD);
+  auto* options = config.mutable_options()->MutableExtension(
+      ContentZoomingCalculatorOptions::ext);
+  options->mutable_kinematic_options_pan()->set_min_motion_to_reframe(5.0);
+  options->mutable_kinematic_options_tilt()->set_min_motion_to_reframe(5.0);
+  options->mutable_kinematic_options_zoom()->set_min_motion_to_reframe(0.0);
+  auto runner = ::absl::make_unique<CalculatorRunner>(config);
+  AddDetection(cv::Rect_<float>(.4, .5, .1, .1), 0, runner.get());
+  AddDetection(cv::Rect_<float>(.45, .55, .15, .15), 1000000, runner.get());
+  MP_ASSERT_OK(runner->Run());
+  CheckCropRect(450, 550, 111, 111, 0,
+                runner->Outputs().Tag("CROP_RECT").packets);
+  CheckCropRect(450, 550, 139, 139, 1,
+                runner->Outputs().Tag("CROP_RECT").packets);
+}
+
 TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
   auto runner = ::absl::make_unique<CalculatorRunner>(
       ParseTextProtoOrDie<CalculatorGraphConfig>(kConfigB));
@@ -129,7 +237,7 @@ TEST(ContentZoomingCalculatorTest, MinAspectBorderValues) {
       Adopt(input_frame.release()).At(Timestamp(0)));

   runner->MutableInputs()
-      ->Tag("DETECTIONS")
+      ->Tag("SALIENT_REGIONS")
       .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));

   // Run the calculator.
@@ -166,7 +274,7 @@ TEST(ContentZoomingCalculatorTest, TwoFacesWide) {
       Adopt(input_frame.release()).At(Timestamp(0)));

   runner->MutableInputs()
-      ->Tag("DETECTIONS")
+      ->Tag("SALIENT_REGIONS")
       .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));

   // Run the calculator.
@@ -191,7 +299,7 @@ TEST(ContentZoomingCalculatorTest, NoDetectionOnInit) {
       Adopt(input_frame.release()).At(Timestamp(0)));

   runner->MutableInputs()
-      ->Tag("DETECTIONS")
+      ->Tag("SALIENT_REGIONS")
       .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));

   // Run the calculator.
@@ -223,7 +331,7 @@ TEST(ContentZoomingCalculatorTest, ZoomTestPairSize) {
       .packets.push_back(Adopt(input_size.release()).At(Timestamp(0)));

   runner->MutableInputs()
-      ->Tag("DETECTIONS")
+      ->Tag("SALIENT_REGIONS")
       .packets.push_back(Adopt(detection_set.release()).At(Timestamp(0)));

   // Run the calculator.
diff --git a/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt b/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt
index 18d336c80..b88ea0c75 100644
--- a/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt
+++ b/mediapipe/examples/desktop/autoflip/subgraph/front_face_detection_subgraph.pbtxt
@@ -37,7 +37,7 @@ node {
   output_stream: "TENSORS:detection_tensors"
   options: {
     [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
-      model_path: "face_detection_front.tflite"
+      model_path: "mediapipe/models/face_detection_front.tflite"
     }
   }
 }
@@ -118,7 +118,7 @@ node {
   output_stream: "labeled_detections"
   options: {
     [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
-      label_map_path: "face_detection_front_labelmap.txt"
+      label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
     }
   }
 }
diff --git a/mediapipe/examples/ios/README.md b/mediapipe/examples/ios/README.md
index 0a3f9b4bf..82813b566 100644
--- a/mediapipe/examples/ios/README.md
+++ b/mediapipe/examples/ios/README.md
@@ -1,18 +1 @@
-This directory contains example MediaPipe applications on iOS.
-
-| Use Case                               | Directory                            |
-|----------------------------------------|:------------------------------------:|
-| Edge Detection on GPU                  | edgedetection                        |
-| Face Detection on CPU                  | facedetectioncpu                     |
-| Face Detection on GPU                  | facedetectiongpu                     |
-| Object Detection on CPU                | objectdetectioncpu                   |
-| Object Detection on GPU                | objectdetectiongpu                   |
-| Hand Detection on GPU                  | handdetectiongpu                     |
-| Hand Tracking on GPU                   | handtrackinggpu                      |
-
-For instance, to build an example app for face detection on CPU, run:
-
-```bash
-bazel build -c opt --config=ios_arm64 --xcode_version=$XCODE_VERSION --cxxopt='-std=c++14' mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp
-```
-(Note: with your own $XCODE_VERSION)
+This directory contains MediaPipe example applications for iOS. Please see [Solutions](https://solutions.mediapipe.dev) for details.
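Back in the `ContentZoomingCalculator` tests above, the `CheckCropRect` expectations follow directly from `ConvertToPanTiltZoom`. A standalone sketch reproducing the first expectation of `ZoomTestFullPTZ` (a .4/.5/.1/.1 detection on a 1000x1000 frame yielding a crop centered at (450, 550) with size 111); the 0.9 `scale_factor` used here is an assumed default, not stated in this patch:

```cpp
// Sketch of the pan/tilt/zoom arithmetic; scale_factor = 0.9 is an assumption.
#include <algorithm>
#include <cstdio>

int main() {
  const float xmin = 0.4f, ymin = 0.5f, w = 0.1f, h = 0.1f;
  const int frame_w = 1000, frame_h = 1000;
  const float scale_factor = 0.9f;                 // assumed default
  const float x_center = xmin + w / 2;             // 0.45 -> pan target
  const float y_center = ymin + h / 2;             // 0.55 -> tilt target
  float fit_size = std::max(h / scale_factor, w);  // ~0.1111
  fit_size = std::min(1.0f, fit_size);             // clamp to max_frame_value_
  std::printf("center=(%d, %d) size=%d\n",
              static_cast<int>(frame_w * x_center),   // 450
              static_cast<int>(frame_h * y_center),   // 550
              static_cast<int>(frame_h * fit_size));  // 111
  return 0;
}
```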
diff --git a/mediapipe/examples/ios/edgedetectiongpu/BUILD b/mediapipe/examples/ios/edgedetectiongpu/BUILD index aa5f721c1..66ea1b066 100644 --- a/mediapipe/examples/ios/edgedetectiongpu/BUILD +++ b/mediapipe/examples/ios/edgedetectiongpu/BUILD @@ -21,6 +21,11 @@ load( "ios_application", ) +alias( + name = "edgedetectiongpu", + actual = "EdgeDetectionGpuApp", +) + ios_application( name = "EdgeDetectionGpuApp", bundle_id = "com.google.mediapipe.EdgeDetectionGpu", diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD index cd97b42d8..1e8488b34 100644 --- a/mediapipe/examples/ios/facedetectioncpu/BUILD +++ b/mediapipe/examples/ios/facedetectioncpu/BUILD @@ -21,6 +21,11 @@ load( "ios_application", ) +alias( + name = "facedetectioncpu", + actual = "FaceDetectionCpuApp", +) + ios_application( name = "FaceDetectionCpuApp", bundle_id = "com.google.mediapipe.FaceDetectionCpu", diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD index 2e46f86b8..b6fce8791 100644 --- a/mediapipe/examples/ios/facedetectiongpu/BUILD +++ b/mediapipe/examples/ios/facedetectiongpu/BUILD @@ -21,6 +21,11 @@ load( "ios_application", ) +alias( + name = "facedetectiongpu", + actual = "FaceDetectionGpuApp", +) + ios_application( name = "FaceDetectionGpuApp", bundle_id = "com.google.mediapipe.FaceDetectionGpu", diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 5c9df3feb..a892510ff 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0 MIN_IOS_VERSION = "10.0" +alias( + name = "facemeshgpu", + actual = "FaceMeshGpuApp", +) + ios_application( name = "FaceMeshGpuApp", bundle_id = "com.google.mediapipe.FaceMeshGpu", diff --git a/mediapipe/examples/ios/handdetectiongpu/BUILD b/mediapipe/examples/ios/handdetectiongpu/BUILD index 47f1f0ed5..162166a42 100644 --- a/mediapipe/examples/ios/handdetectiongpu/BUILD +++ b/mediapipe/examples/ios/handdetectiongpu/BUILD @@ -21,6 +21,11 @@ load( "ios_application", ) +alias( + name = "handdetectiongpu", + actual = "HandDetectionGpuApp", +) + ios_application( name = "HandDetectionGpuApp", bundle_id = "com.google.mediapipe.HandDetectionGpu", diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index 6f3841eb1..72965cef3 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0 MIN_IOS_VERSION = "10.0" +alias( + name = "handtrackinggpu", + actual = "HandTrackingGpuApp", +) + ios_application( name = "HandTrackingGpuApp", bundle_id = "com.google.mediapipe.HandTrackingGpu", diff --git a/mediapipe/examples/ios/multihandtrackinggpu/BUILD b/mediapipe/examples/ios/multihandtrackinggpu/BUILD index f93854608..be718d3e9 100644 --- a/mediapipe/examples/ios/multihandtrackinggpu/BUILD +++ b/mediapipe/examples/ios/multihandtrackinggpu/BUILD @@ -21,6 +21,11 @@ licenses(["notice"]) # Apache 2.0 MIN_IOS_VERSION = "10.0" +alias( + name = "multihandtrackinggpu", + actual = "MultiHandTrackingGpuApp", +) + ios_application( name = "MultiHandTrackingGpuApp", bundle_id = "com.google.mediapipe.MultiHandTrackingGpu", diff --git a/mediapipe/examples/ios/objectdetectioncpu/BUILD b/mediapipe/examples/ios/objectdetectioncpu/BUILD index 37d316c99..0efb96316 100644 --- a/mediapipe/examples/ios/objectdetectioncpu/BUILD +++ 
b/mediapipe/examples/ios/objectdetectioncpu/BUILD
@@ -21,6 +21,11 @@ load(
     "ios_application",
 )

+alias(
+    name = "objectdetectioncpu",
+    actual = "ObjectDetectionCpuApp",
+)
+
 ios_application(
     name = "ObjectDetectionCpuApp",
     bundle_id = "com.google.mediapipe.ObjectDetectionCpu",
diff --git a/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m b/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m
index 9e1b7ff0e..cee668142 100644
--- a/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m
+++ b/mediapipe/examples/ios/objectdetectiongpu/AppDelegate.m
@@ -13,6 +13,7 @@
 // limitations under the License.

 #import "AppDelegate.h"
+#import "ViewController.h"

 @interface AppDelegate ()

@@ -22,7 +23,14 @@

 - (BOOL)application:(UIApplication *)application
     didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
-  // Override point for customization after application launch.
+  ViewController *viewController = (ViewController *)self.window.rootViewController;
+  NSURL *url = [launchOptions objectForKey:UIApplicationLaunchOptionsURLKey];
+  // Unattended testing on Firebase is enabled by custom URL schema.
+  if ([url.scheme isEqualToString:@"firebase-game-loop"]) {
+    [viewController setSourceMode:MediaPipeDemoSourceVideo];
+  } else {
+    [viewController setSourceMode:MediaPipeDemoSourceBackCamera];
+  }
   return YES;
 }

diff --git a/mediapipe/examples/ios/objectdetectiongpu/BUILD b/mediapipe/examples/ios/objectdetectiongpu/BUILD
index 307bc4a12..288273ac0 100644
--- a/mediapipe/examples/ios/objectdetectiongpu/BUILD
+++ b/mediapipe/examples/ios/objectdetectiongpu/BUILD
@@ -21,6 +21,11 @@ load(
     "ios_application",
 )

+alias(
+    name = "objectdetectiongpu",
+    actual = "ObjectDetectionGpuApp",
+)
+
 ios_application(
     name = "ObjectDetectionGpuApp",
     bundle_id = "com.google.mediapipe.ObjectDetectionGpu",
diff --git a/mediapipe/examples/ios/objectdetectiongpu/Info.plist b/mediapipe/examples/ios/objectdetectiongpu/Info.plist
index 30db14c62..3a193f784 100644
--- a/mediapipe/examples/ios/objectdetectiongpu/Info.plist
+++ b/mediapipe/examples/ios/objectdetectiongpu/Info.plist
@@ -38,5 +38,18 @@
 		<string>UIInterfaceOrientationPortrait</string>
 	</array>
+	<key>CFBundleURLTypes</key>
+	<array>
+		<dict>
+			<key>CFBundleURLName</key>
+			<string>com.google.firebase</string>
+			<key>CFBundleTypeRole</key>
+			<string>Editor</string>
+			<key>CFBundleURLSchemes</key>
+			<array>
+				<string>firebase-game-loop</string>
+			</array>
+		</dict>
+	</array>
 </dict>
 </plist>
diff --git a/mediapipe/examples/ios/objectdetectiongpu/ViewController.h b/mediapipe/examples/ios/objectdetectiongpu/ViewController.h
index e0a5a6367..c768fa0d9 100644
--- a/mediapipe/examples/ios/objectdetectiongpu/ViewController.h
+++ b/mediapipe/examples/ios/objectdetectiongpu/ViewController.h
@@ -14,6 +14,11 @@

 #import <UIKit/UIKit.h>

-@interface ViewController : UIViewController
+typedef NS_ENUM(NSInteger, MediaPipeDemoSourceMode) {
+  MediaPipeDemoSourceBackCamera,
+  MediaPipeDemoSourceVideo
+};

+@interface ViewController : UIViewController
+- (void)setSourceMode:(MediaPipeDemoSourceMode)mode;
 @end
diff --git a/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm b/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm
index 236a9a5a0..fc667d9d7 100644
--- a/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm
+++ b/mediapipe/examples/ios/objectdetectiongpu/ViewController.mm
@@ -17,6 +17,7 @@
 #import "mediapipe/objc/MPPGraph.h"
 #import "mediapipe/objc/MPPCameraInputSource.h"
 #import "mediapipe/objc/MPPLayerRenderer.h"
+#import "mediapipe/objc/MPPPlayerInputSource.h"

 static NSString* const kGraphName = @"mobile_gpu";

@@ -35,6 +36,8 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue";

 @implementation ViewController {
/// Handles camera access via AVCaptureSession library. MPPCameraInputSource* _cameraSource; + MPPPlayerInputSource* _videoSource; + MediaPipeDemoSourceMode _sourceMode; /// Inform the user when camera is unavailable. IBOutlet UILabel* _noCameraLabel; @@ -47,6 +50,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; dispatch_queue_t _videoQueue; } +- (void)setSourceMode:(MediaPipeDemoSourceMode)mode { + _sourceMode = mode; +} + #pragma mark - Cleanup methods - (void)dealloc { @@ -97,13 +104,6 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; DISPATCH_QUEUE_SERIAL, QOS_CLASS_USER_INTERACTIVE, /*relative_priority=*/0); _videoQueue = dispatch_queue_create(kVideoQueueLabel, qosAttribute); - _cameraSource = [[MPPCameraInputSource alloc] init]; - [_cameraSource setDelegate:self queue:_videoQueue]; - _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; - _cameraSource.cameraPosition = AVCaptureDevicePositionBack; - // The frame's native format is rotated with respect to the portrait orientation. - _cameraSource.orientation = AVCaptureVideoOrientationPortrait; - self.mediapipeGraph = [[self class] loadGraphFromResource:kGraphName]; self.mediapipeGraph.delegate = self; // Set maxFramesInFlight to a small value to avoid memory contention for real-time processing. @@ -119,27 +119,43 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; - (void)viewWillAppear:(BOOL)animated { [super viewWillAppear:animated]; - [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { - if (granted) { - [self startGraphAndCamera]; - dispatch_async(dispatch_get_main_queue(), ^{ - _noCameraLabel.hidden = YES; - }); - } - }]; -} - -- (void)startGraphAndCamera { // Start running self.mediapipeGraph. NSError* error; if (![self.mediapipeGraph startWithError:&error]) { NSLog(@"Failed to start graph: %@", error); } - // Start fetching frames from the camera. - dispatch_async(_videoQueue, ^{ - [_cameraSource start]; - }); + switch (_sourceMode) { + case MediaPipeDemoSourceVideo: { + AVAsset* video = + [AVAsset assetWithURL:[[NSBundle mainBundle] URLForResource:@"object_detection" + withExtension:@"mov"]]; + _videoSource = [[MPPPlayerInputSource alloc] initWithAVAsset:video]; + [_videoSource setDelegate:self queue:_videoQueue]; + dispatch_async(_videoQueue, ^{ + [_videoSource start]; + }); + break; + } + case MediaPipeDemoSourceBackCamera: + _cameraSource = [[MPPCameraInputSource alloc] init]; + [_cameraSource setDelegate:self queue:_videoQueue]; + _cameraSource.sessionPreset = AVCaptureSessionPresetHigh; + _cameraSource.cameraPosition = AVCaptureDevicePositionBack; + // The frame's native format is rotated with respect to the portrait orientation. 
+ _cameraSource.orientation = AVCaptureVideoOrientationPortrait; + [_cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { + if (granted) { + dispatch_async(_videoQueue, ^{ + [_cameraSource start]; + }); + dispatch_async(dispatch_get_main_queue(), ^{ + _noCameraLabel.hidden = YES; + }); + } + }]; + break; + } } #pragma mark - MPPGraphDelegate methods @@ -164,7 +180,7 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; - (void)processVideoFrame:(CVPixelBufferRef)imageBuffer timestamp:(CMTime)timestamp fromSource:(MPPInputSource*)source { - if (source != _cameraSource) { + if (source != _cameraSource && source != _videoSource) { NSLog(@"Unknown source: %@", source); return; } diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index df4e4c553..f5b170ab9 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -36,7 +36,7 @@ exports_files([ mediapipe_proto_library( name = "calculator_proto", srcs = ["calculator.proto"], - visibility = [":mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:mediapipe_options_proto", @@ -68,7 +68,7 @@ mediapipe_proto_library( mediapipe_proto_library( name = "calculator_profile_proto", srcs = ["calculator_profile.proto"], - visibility = [":mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -830,6 +830,8 @@ cc_library( ":port", ":timestamp", ":type_map", + "//mediapipe/framework/deps:no_destructor", + "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1524,6 +1526,21 @@ cc_test( ], ) +cc_test( + name = "packet_registration_test", + size = "small", + srcs = ["packet_registration_test.cc"], + deps = [ + ":calculator_framework", + ":packet", + ":packet_test_cc_proto", + ":type_map", + "//mediapipe/framework/port:core_proto", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) + cc_test( name = "packet_generator_test", size = "small", diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h index a47632fc9..ba6cdad85 100644 --- a/mediapipe/framework/calculator_contract.h +++ b/mediapipe/framework/calculator_contract.h @@ -113,8 +113,11 @@ class CalculatorContract { // calculations should use SetProcessTimestampBounds. // When true, Process is called for every new timestamp bound, with or without - // new packets. A call to Process with only an input timestamp bound is + // new packets. A call to Process with only an input timestamp bound is // normally used to compute a new output timestamp bound. + // NOTE: Also, when true, Process is called when input streams become done, + // which means, Process needs to handle input streams in "done" state. + // (Usually, by closing calculators' outputs where and when appropriate.) 
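+  //
+  // A minimal sketch of a Process() body that tolerates "done" inputs
+  // (illustrative only, assuming one input and one output stream; not part
+  // of this change):
+  //
+  //   ::mediapipe::Status Process(CalculatorContext* cc) override {
+  //     if (cc->Inputs().Index(0).IsDone()) {
+  //       cc->Outputs().Index(0).Close();
+  //       return ::mediapipe::OkStatus();
+  //     }
+  //     cc->Outputs().Index(0).SetNextTimestampBound(
+  //         cc->InputTimestamp().NextAllowedInStream());
+  //     return ::mediapipe::OkStatus();
+  //   }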
void SetProcessTimestampBounds(bool process_timestamps) { process_timestamps_ = process_timestamps; } diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 3826dbda5..7a1935a92 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -91,6 +91,9 @@ typedef ::mediapipe::StatusOr StatusOrPoller; // {{"video_id", mediapipe::MakePacket("Ex-uGhDzue4")}})); // // See mediapipe/framework/graph_runner.h for an interface // // to insert and extract packets from a graph as it runs. +// // Once it is done using the graph, close its streams and wait till done. +// MP_RETURN_IF_ERROR(graph->CloseAllInputStreams()); +// MP_RETURN_IF_ERROR(graph->WaitUntilDone()); class CalculatorGraph { public: // Defines possible modes for adding a packet to a graph input stream. @@ -157,8 +160,9 @@ class CalculatorGraph { std::function<::mediapipe::Status(const Packet&)> packet_callback); // Adds an OutputStreamPoller for a stream. This provides a synchronous, - // polling API for accessing a stream's output. For asynchronous output, use - // ObserveOutputStream. See also the helpers in tool/sink.h. + // polling API for accessing a stream's output. Should only be called before + // Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See + // also the helpers in tool/sink.h. StatusOrPoller AddOutputStreamPoller(const std::string& stream_name); // Gets output side packet by name after the graph is done. However, base @@ -300,6 +304,13 @@ class CalculatorGraph { void RecordError(const ::mediapipe::Status& error) ABSL_LOCKS_EXCLUDED(error_mutex_); + // Combines errors into a status. Returns true if the vector of errors is + // non-empty. + bool GetCombinedErrors(const std::string& error_prefix, + ::mediapipe::Status* error_status); + // Convenience overload which specifies a default error prefix. + bool GetCombinedErrors(::mediapipe::Status* error_status); + // Returns the maximum input stream queue size. int GetMaxInputStreamQueueSize(); @@ -501,13 +512,6 @@ class CalculatorGraph { void CleanupAfterRun(::mediapipe::Status* status) ABSL_LOCKS_EXCLUDED(error_mutex_); - // Combines errors into a status. Returns true if the vector of errors is - // non-empty. - bool GetCombinedErrors(const std::string& error_prefix, - ::mediapipe::Status* error_status); - // Convenience overload which specifies a default error prefix. - bool GetCombinedErrors(::mediapipe::Status* error_status); - // Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one // is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN). // current_run_side_packets_ must be set before this function is called. diff --git a/mediapipe/framework/deps/vector.h b/mediapipe/framework/deps/vector.h index 24f2480cd..0ecf3f2dd 100644 --- a/mediapipe/framework/deps/vector.h +++ b/mediapipe/framework/deps/vector.h @@ -458,8 +458,9 @@ class Vector3 // return the index of the largest component (fabs) int LargestAbsComponent() const { Vector3 temp = Abs(); - return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2 - : temp[1] > temp[2] ? 1 : 2; + return temp[0] > temp[1] ? temp[0] > temp[2] ? 0 : 2 + : temp[1] > temp[2] ? 
1
+                                : 2;
  }

  // return the index of the smallest, median, largest component of the vector
diff --git a/mediapipe/framework/input_stream_handler.h b/mediapipe/framework/input_stream_handler.h
index aa9f44285..db0a7fc5e 100644
--- a/mediapipe/framework/input_stream_handler.h
+++ b/mediapipe/framework/input_stream_handler.h
@@ -155,7 +155,7 @@ class InputStreamHandler {
  // max number of invocations that are allowed to be scheduled is reached.
  // Returns true if at least one invocation has been scheduled.
  // The latest minimum timestamp bound of the input streams is returned in
- // *input_bound iff the latest readiness of the node is kNotReady when the
+ // *input_bound if the latest readiness of the node is kNotReady when the
  // function returns. During batching, this value will be equal to the
  // timestamp of the first set of inputs in the batch. In other cases,
  // Timestamp::Unset() is returned.
diff --git a/mediapipe/framework/legacy_calculator_support.h b/mediapipe/framework/legacy_calculator_support.h
index 6a76101bb..4cd15ce22 100644
--- a/mediapipe/framework/legacy_calculator_support.h
+++ b/mediapipe/framework/legacy_calculator_support.h
@@ -61,11 +61,25 @@ class LegacyCalculatorSupport {
 // platforms.
 #ifndef __APPLE__
     ABSL_CONST_INIT
-#endif  // !__APPLE__
+#endif  // !__APPLE__
    static thread_local C* current_;  // NOLINT
  };
};

+// We only declare this variable for two specializations of the template because
+// it is only meant to be used for these two types.
+// Note that, since these variables are members of specific template
+// _specializations_, they are not themselves templates, and therefore their
+// definitions must be in the .cc file. However, a declaration still needs to be
+// included in the header, or some compilers will assume they have no
+// definition.
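+//
+// As a usage sketch (illustrative, not part of this change): framework code
+// wraps a call into calculator code in a scope such as
+//   LegacyCalculatorSupport::Scoped<CalculatorContext> s(cc);
+// so that legacy helpers deep inside that call can retrieve the current
+// CalculatorContext via Scoped<CalculatorContext>::current().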
+template <>
+thread_local CalculatorContext*
+    LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
+template <>
+thread_local CalculatorContract*
+    LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;
+
} // namespace mediapipe

#endif // MEDIAPIPE_FRAMEWORK_LEGACY_CALCULATOR_SUPPORT_H_
diff --git a/mediapipe/framework/packet.cc b/mediapipe/framework/packet.cc
index 1b23e521e..8d9914835 100644
--- a/mediapipe/framework/packet.cc
+++ b/mediapipe/framework/packet.cc
@@ -51,6 +51,18 @@ const HolderBase* GetHolder(const Packet& packet) {
  return packet.holder_.get();
}

+::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
+    const std::string& type_name, const std::string& serialized) {
+  ASSIGN_OR_RETURN(
+      auto message_holder,
+      packet_internal::MessageHolderRegistry::CreateByName(type_name));
+  auto* message =
+      const_cast<proto_ns::MessageLite*>(message_holder->GetProtoMessageLite());
+  RET_CHECK_NE(message, nullptr);
+  RET_CHECK(message->ParseFromString(serialized));
+  return packet_internal::Create(message_holder.release());
+}
+
} // namespace packet_internal

Packet Packet::At(class Timestamp timestamp) const& {
diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h
index bee14f702..f6b177454 100644
--- a/mediapipe/framework/packet.h
+++ b/mediapipe/framework/packet.h
@@ -27,6 +27,8 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/deps/no_destructor.h"
+#include "mediapipe/framework/deps/registration.h"
 #include "mediapipe/framework/port.h"
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/logging.h"
@@ -51,6 +53,8 @@ Packet Create(HolderBase* holder, Timestamp timestamp);
Packet Create(std::shared_ptr<HolderBase> holder, Timestamp timestamp);
const HolderBase* GetHolder(const Packet& packet);
const std::shared_ptr<HolderBase>& GetHolderShared(const Packet& packet);
+::mediapipe::StatusOr<Packet> PacketFromDynamicProto(
+    const std::string& type_name, const std::string& serialized);
} // namespace packet_internal

// A generic container class which can hold data of any type. The type of
@@ -355,112 +359,11 @@ class HolderBase {
  // Downcasts this to Holder<T>. Returns nullptr if deserialization
  // failed or if the requested type is not what is stored.
  template <typename T>
-  inline Holder<T>* As(
-      typename std::enable_if<
-          (!std::is_base_of<proto_ns::Message, T>::value &&
-           !std::is_base_of<proto_ns::MessageLite, T>::value) ||
-          (std::is_same<proto_ns::Message, T>::value ||
-           std::is_same<proto_ns::MessageLite, T>::value)>::type* = 0) {
-    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
-      return static_cast<Holder<T>*>(this);
-    }
-    // Does not hold a T.
-    return nullptr;
-  }
-
-  // For proto Message/MessageLite subclasses.
-  // When holder data is a concrete proto, the method downcasts this to
-  // Holder<T> if the requested type is what is stored.
-  // When holder data is a generic proto Message/MessageLite and a concrete
-  // proto type T is requested, the method will downcast the HolderBase to
-  // Holder<T> if the proto data is an instance of T.
-  template <typename T>
-  inline Holder<T>* As(
-      typename std::enable_if<
-          (std::is_base_of<proto_ns::Message, T>::value ||
-           std::is_base_of<proto_ns::MessageLite, T>::value) &&
-          (!std::is_same<proto_ns::Message, T>::value &&
-           !std::is_same<proto_ns::MessageLite, T>::value)>::type* = 0) {
-    // Holder data is an instance of subclass type T.
-    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
-      return static_cast<Holder<T>*>(this);
-    }
-
-    // Holder data is a generic proto Message/MessageLite and a subclass type T
-    // is requested.
-    if (HolderIsOfType<Holder<proto_ns::Message>>() ||
-        HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
-        HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
-        HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
-      // TODO: Holder<proto_ns::MessageLite> cannot be
-      // legally downcast to Holder<T>, even though that downcast works in
-      // practice. Need to propose a better way to do the downcast.
-      Holder<T>* holder = static_cast<Holder<T>*>(this);
-      T tmp;
-      VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
-              << " vs requested proto type: " << tmp.GetTypeName();
-      if (tmp.GetTypeName() == holder->data().GetTypeName()) {
-        return holder;
-      }
-    }
-
-    // Does not hold a T.
-    return nullptr;
-  }
+  Holder<T>* As();

  // Same as non-const As() function.
  template <typename T>
-  inline const Holder<T>* As(
-      typename std::enable_if<
-          (!std::is_base_of<proto_ns::Message, T>::value &&
-           !std::is_base_of<proto_ns::MessageLite, T>::value) ||
-          (std::is_same<proto_ns::Message, T>::value ||
-           std::is_same<proto_ns::MessageLite, T>::value)>::type* = 0) const {
-    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
-      return static_cast<const Holder<T>*>(this);
-    }
-    // Does not hold a T.
-    return nullptr;
-  }
-
-  // For proto Message/MessageLite subclasses.
-  // When holder data is a concrete proto, the method downcasts this to
-  // Holder<T> if the requested type is what is stored.
-  // When holder data is a generic proto Message/MessageLite and a concrete
-  // proto type T is requested, the method will downcast the HolderBase to
-  // Holder<T> if the proto data is an instance of T.
-  template <typename T>
-  inline const Holder<T>* As(
-      typename std::enable_if<
-          (std::is_base_of<proto_ns::Message, T>::value ||
-           std::is_base_of<proto_ns::MessageLite, T>::value) &&
-          (!std::is_same<proto_ns::Message, T>::value &&
-           !std::is_same<proto_ns::MessageLite, T>::value)>::type* = 0) const {
-    if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
-      return static_cast<const Holder<T>*>(this);
-    }
-
-    // Holder data is a generic proto Message/MessageLite and a subclass type T
-    // is requested.
-    if (HolderIsOfType<Holder<proto_ns::Message>>() ||
-        HolderIsOfType<ForeignHolder<proto_ns::Message>>() ||
-        HolderIsOfType<Holder<proto_ns::MessageLite>>() ||
-        HolderIsOfType<ForeignHolder<proto_ns::MessageLite>>()) {
-      // TODO: Holder<proto_ns::MessageLite> cannot be
-      // legally downcast to Holder<T>, even though that downcast works in
-      // practice. Need to propose a better way to do the downcast.
-      const Holder<T>* holder = static_cast<const Holder<T>*>(this);
-      T tmp;
-      VLOG(2) << "Holder proto data type: " << holder->data().GetTypeName()
-              << " vs requested proto type: " << tmp.GetTypeName();
-      if (tmp.GetTypeName() == holder->data().GetTypeName()) {
-        return holder;
-      }
-    }
-
-    // Does not hold a T.
-    return nullptr;
-  }
+  const Holder<T>* As() const;

  // Returns the pointer to MessageLite type for the data in holder, if
  // underlying object is protocol buffer type, otherwise, nullptr is returned.
@@ -520,12 +423,68 @@ ConvertToVectorOfProtoMessageLitePtrs(const T* data,
  return result;
}

+// This registry is used to create Holders of the right concrete C++ type given
+// a proto type std::string (which is used as the registration key).
+class MessageHolderRegistry
+    : public GlobalFactoryRegistry<std::unique_ptr<HolderBase>> {};
+
+template <typename T>
+struct is_concrete_proto_t
+    : public std::integral_constant<
+          bool, std::is_base_of<proto_ns::MessageLite, T>{} &&
+                    !std::is_same<proto_ns::MessageLite, T>{} &&
+                    !std::is_same<proto_ns::Message, T>{}> {};
+
+// Registers a message type. T must be a non-cv-qualified concrete proto type.
+template <typename T>
+struct MessageRegistrationImpl {
+  static NoDestructor<mediapipe::RegistrationToken> registration;
+};
+
+// Static members of template classes can be defined in the header.
+template <typename T>
+NoDestructor<mediapipe::RegistrationToken>
+    MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
+        T{}.GetTypeName(), [] { return absl::make_unique<Holder<T>>(new T); }));
+
+// For non-Message payloads, this does nothing.
+template <typename T, typename Enable = void>
+struct HolderSupport {
+  static void EnsureStaticInit() {}
+};
+
+// This template ensures that, for each concrete MessageLite subclass that is
+// stored in a Packet, we register a function that allows us to create a
+// Holder with the correct payload type from the proto's type name.
+template <typename T>
+struct HolderSupport<T,
+                     typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
+  // We must use std::remove_cv to ensure we don't try to register Foo twice if
+  // there are Holder<Foo> and Holder<const Foo>. TODO: lift this
+  // up to Holder?
+  using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
+  // For the registration static member to be instantiated, it needs to be
+  // referenced in a context that requires the definition to exist (see ISO/IEC
+  // C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
+  // We need two different call-sites to cover proto types for which packets
+  // are only ever created (i.e. the protos are only produced by calculators)
+  // and proto types for which packets are only ever consumed (i.e. the protos
+  // are only consumed by calculators).
+  static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
+};
+
template <typename T>
class Holder : public HolderBase {
 public:
-  explicit Holder(const T* ptr) : ptr_(ptr) { SetHolderTypeId<Holder>(); }
+  explicit Holder(const T* ptr) : ptr_(ptr) {
+    HolderSupport<T>::EnsureStaticInit();
+    SetHolderTypeId<Holder>();
+  }
  ~Holder() override { delete_helper(); }
-  const T& data() const { return *ptr_; }
+  const T& data() const {
+    HolderSupport<T>::EnsureStaticInit();
+    return *ptr_;
+  }
  size_t GetTypeId() const final { return tool::GetTypeHash<T>(); }
  // Releases the underlying data pointer and transfers the ownership to a
  // unique pointer.
@@ -622,6 +581,24 @@ class ForeignHolder : public Holder<T> {
  }
};

+template <typename T>
+Holder<T>* HolderBase::As() {
+  if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
+    return static_cast<Holder<T>*>(this);
+  }
+  // Does not hold a T.
+  return nullptr;
+}
+
+template <typename T>
+const Holder<T>* HolderBase::As() const {
+  if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) {
+    return static_cast<const Holder<T>*>(this);
+  }
+  // Does not hold a T.
+  return nullptr;
+}
+
} // namespace packet_internal

inline Packet::Packet(const Packet& packet)
diff --git a/mediapipe/framework/packet_registration_test.cc b/mediapipe/framework/packet_registration_test.cc
new file mode 100644
index 000000000..25acc264c
--- /dev/null
+++ b/mediapipe/framework/packet_registration_test.cc
@@ -0,0 +1,57 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
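+
+// This test exercises the proto-type registration path added to packet.h:
+// merely using a concrete proto type in a calculator contract should be
+// enough to register it in MessageHolderRegistry. A sketch of the round trip
+// that this registration enables (illustrative; PacketFromDynamicProto is
+// declared in packet.h):
+//
+//   mediapipe::SimpleProto original;
+//   std::string bytes = original.SerializeAsString();
+//   auto packet_or = packet_internal::PacketFromDynamicProto(
+//       "mediapipe.SimpleProto", bytes);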
+ +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_test.pb.h" +#include "mediapipe/framework/port/core_proto_inc.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +namespace test_ns { + +class TestSinkCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Tag("IN").Set(); + cc->Outputs().Tag("OUT").Set(); + return ::mediapipe::OkStatus(); + } + + ::mediapipe::Status Process(CalculatorContext* cc) override { + int x = cc->Inputs().Tag("IN").Get().x(); + cc->Outputs().Tag("OUT").AddPacket( + MakePacket(x).At(cc->InputTimestamp())); + return ::mediapipe::OkStatus(); + } +}; +REGISTER_CALCULATOR(::mediapipe::test_ns::TestSinkCalculator); + +} // namespace test_ns + +TEST(PacketTest, InputTypeRegistration) { + using testing::Contains; + ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(), + "mediapipe.InputOnlyProto"); + EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(), + Contains("mediapipe.InputOnlyProto")); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/packet_test.cc b/mediapipe/framework/packet_test.cc index 44a817b26..039ccedf7 100644 --- a/mediapipe/framework/packet_test.cc +++ b/mediapipe/framework/packet_test.cc @@ -174,54 +174,13 @@ TEST(PacketTest, ReturnGenericProtobufMessage) { .x(0)); } -TEST(PacketTest, ReturnProtobufMessageSubType) { - std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr( - new ::mediapipe::PacketTestProto); - proto_ptr->add_x(123); - Packet packet = Adopt(static_cast(proto_ptr.release())); - EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0)); - EXPECT_EQ(123, packet.Get().x(0)); -} - TEST(PacketTest, TryWrongProtobufMessageSubType) { - // Packet of PacketTestProto. std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr( new ::mediapipe::PacketTestProto); proto_ptr->add_x(123); Packet packet = Adopt(proto_ptr.release()); EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok()); EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok()); - - // Packet of proto_ns::Message. - proto_ptr.reset(new ::mediapipe::PacketTestProto); - proto_ptr->add_x(456); - Packet packet2 = Adopt(static_cast(proto_ptr.release())); - EXPECT_FALSE(packet2.ValidateAsType<::mediapipe::SimpleProto>().ok()); - EXPECT_TRUE(packet2.ValidateAsType<::mediapipe::PacketTestProto>().ok()); - EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0)); -} - -TEST(PacketTest, ReturnProtobufMessageLiteSubType) { - std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr( - new ::mediapipe::PacketTestProto); - proto_ptr->add_x(123); - Packet packet = - Adopt(static_cast(proto_ptr.release())); - EXPECT_EQ(123, packet.Get<::mediapipe::PacketTestProto>().x(0)); - EXPECT_EQ(123, packet.Get().x(0)); -} - -TEST(PacketTest, TryWrongProtobufMessageLiteSubType) { - // Packet of PacketTestProto. - std::unique_ptr<::mediapipe::PacketTestProto> proto_ptr( - new ::mediapipe::PacketTestProto); - // Packet of proto_ns::MessageLite. 
- proto_ptr->add_x(456); - Packet packet = - Adopt(static_cast(proto_ptr.release())); - EXPECT_FALSE(packet.ValidateAsType<::mediapipe::SimpleProto>().ok()); - EXPECT_TRUE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok()); - EXPECT_EQ(456, packet.Get<::mediapipe::PacketTestProto>().x(0)); } TEST(PacketTest, GetProtoBase) { @@ -505,5 +464,26 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) { EXPECT_TRUE(packet2.IsEmpty()); } +TEST(PacketTest, MessageHolderRegistration) { + using testing::Contains; + Packet packet = MakePacket(); + ASSERT_EQ(mediapipe::SimpleProto{}.GetTypeName(), "mediapipe.SimpleProto"); + EXPECT_THAT(packet_internal::MessageHolderRegistry::GetRegisteredNames(), + Contains("mediapipe.SimpleProto")); +} + +TEST(PacketTest, PacketFromSerializedProto) { + mediapipe::SimpleProto original; + original.add_value("foo"); + std::string serialized = original.SerializeAsString(); + + StatusOr maybe_packet = packet_internal::PacketFromDynamicProto( + "mediapipe.SimpleProto", serialized); + MP_ASSERT_OK(maybe_packet); + Packet packet = maybe_packet.ValueOrDie(); + MP_EXPECT_OK(packet.ValidateAsType<::mediapipe::SimpleProto>()); + EXPECT_FALSE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/packet_test.proto b/mediapipe/framework/packet_test.proto index 8749ffcc5..bccfd6b5f 100644 --- a/mediapipe/framework/packet_test.proto +++ b/mediapipe/framework/packet_test.proto @@ -39,3 +39,9 @@ message SerializationProxyProto { repeated float float_value = 2; repeated string string_value = 3; } + +// This proto should be used only as an input to a calculator, to verify that +// that case is covered. +message InputOnlyProto { + optional int32 x = 1; +} diff --git a/mediapipe/framework/port.h b/mediapipe/framework/port.h index fee918a23..bd5639599 100644 --- a/mediapipe/framework/port.h +++ b/mediapipe/framework/port.h @@ -46,7 +46,7 @@ // but may or may not still be able to run other OpenGL code. #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \ (defined(__APPLE__) || defined(__EMSCRIPTEN__) || \ - defined(MEDIAPIPE_DISABLE_GPU)) + defined(MEDIAPIPE_DISABLE_GPU) || MEDIAPIPE_USING_SWIFTSHADER) #define MEDIAPIPE_DISABLE_GL_COMPUTE #endif diff --git a/mediapipe/framework/profiler/graph_tracer_test.cc b/mediapipe/framework/profiler/graph_tracer_test.cc index 70ee39873..13b5522bb 100644 --- a/mediapipe/framework/profiler/graph_tracer_test.cc +++ b/mediapipe/framework/profiler/graph_tracer_test.cc @@ -143,29 +143,29 @@ TEST_F(GraphTracerTest, CalculatorTrace) { {{MakePacket("goodbye").At(start_timestamp_)}}); // Validate the GraphTrace data. 
- EXPECT_THAT(GetTrace(), - EqualsProto(::mediapipe::ParseTextProtoOrDie(R"( - base_time: 1608911100000000 - base_timestamp: 1608911100000000 - stream_name: "" - stream_name: "input_stream" - stream_name: "output_stream" - calculator_trace { - node_id: 0 - input_timestamp: 0 - event_type: PROCESS - start_time: 0 - finish_time: 10000 - thread_id: 0 - input_trace { - finish_time: 0 - packet_timestamp: 0 - stream_id: 1 - event_data: 1 - } - output_trace { packet_timestamp: 0 stream_id: 2 } - } - )"))); + EXPECT_THAT( + GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie(R"( + base_time: 1608911100000000 + base_timestamp: 1608911100000000 + stream_name: "" + stream_name: "input_stream" + stream_name: "output_stream" + calculator_trace { + node_id: 0 + input_timestamp: 0 + event_type: PROCESS + start_time: 0 + finish_time: 10000 + thread_id: 0 + input_trace { + finish_time: 0 + packet_timestamp: 0 + stream_id: 1 + event_data: 1 + } + output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 } + } + )"))); } TEST_F(GraphTracerTest, GraphTrace) { @@ -205,92 +205,101 @@ TEST_F(GraphTracerTest, GraphTrace) { LogOutputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time, {{MakePacket("out").At(start_timestamp_)}}); curr_time += absl::Microseconds(2000); - ClearCalculatorContext("PCalculator_3"); - LogInputPackets("PCalculator_3", GraphTrace::PROCESS, curr_time, + + // Note: the packet data ID is based on the packet's payload address, which + // means the same ID can be reused if data is allocated in the same location + // as a previously expired packet (b/160212191). This means the generated + // trace can change depending on the allocator. To keep results stable, we + // must keep the packets used in this test alive until the end. Each + // TestContextBuilder happens to keep a reference to all packets for the last + // context, so for now we just create a separate TestContextBuilder instead of + // clearing it. TODO: revise this test. + SetUpCalculatorContext("PCalculator_3a", /*node_id=*/2, {"up_2"}, {"down_2"}); + LogInputPackets("PCalculator_3a", GraphTrace::PROCESS, curr_time, {MakePacket("pup").At(start_timestamp_ + 5)}); curr_time += absl::Microseconds(20000); LogOutputPackets( - "PCalculator_3", GraphTrace::PROCESS, curr_time, + "PCalculator_3a", GraphTrace::PROCESS, curr_time, {{MakePacket("pout").At(start_timestamp_ + 5)}}); curr_time += absl::Microseconds(1000); // Validate the GraphTrace data. 
- EXPECT_THAT(GetTrace(), - EqualsProto(::mediapipe::ParseTextProtoOrDie(R"( - base_time: 1608911100000000 - base_timestamp: 1608911100000000 - stream_name: "" - stream_name: "input_stream" - stream_name: "up_1" - stream_name: "up_2" - stream_name: "down_1" - stream_name: "down_2" - calculator_trace { - node_id: 0 - input_timestamp: 0 - event_type: PROCESS - start_time: 0 - finish_time: 10000 - thread_id: 0 - input_trace { - finish_time: 0 - packet_timestamp: 0 - stream_id: 1 - event_data: 1 - } - output_trace { packet_timestamp: 0 stream_id: 2 } - output_trace { packet_timestamp: 0 stream_id: 3 } - output_trace { packet_timestamp: 5 stream_id: 3 } - } - calculator_trace { - node_id: 1 - input_timestamp: 0 - event_type: PROCESS - start_time: 11000 - finish_time: 21000 - thread_id: 0 - input_trace { - start_time: 10000 - finish_time: 11000 - packet_timestamp: 0 - stream_id: 2 - event_data: 2 - } - output_trace { packet_timestamp: 0 stream_id: 4 } - } - calculator_trace { - node_id: 2 - input_timestamp: 0 - event_type: PROCESS - start_time: 16000 - finish_time: 36000 - thread_id: 0 - input_trace { - start_time: 10000 - finish_time: 16000 - packet_timestamp: 0 - stream_id: 3 - event_data: 3 - } - output_trace { packet_timestamp: 0 stream_id: 5 } - } - calculator_trace { - node_id: 2 - input_timestamp: 5 - event_type: PROCESS - start_time: 38000 - finish_time: 58000 - thread_id: 0 - input_trace { - start_time: 10000 - finish_time: 38000 - packet_timestamp: 5 - stream_id: 3 - event_data: 4 - } - output_trace { packet_timestamp: 5 stream_id: 5 } - } - )"))); + EXPECT_THAT( + GetTrace(), EqualsProto(::mediapipe::ParseTextProtoOrDie(R"( + base_time: 1608911100000000 + base_timestamp: 1608911100000000 + stream_name: "" + stream_name: "input_stream" + stream_name: "up_1" + stream_name: "up_2" + stream_name: "down_1" + stream_name: "down_2" + calculator_trace { + node_id: 0 + input_timestamp: 0 + event_type: PROCESS + start_time: 0 + finish_time: 10000 + thread_id: 0 + input_trace { + finish_time: 0 + packet_timestamp: 0 + stream_id: 1 + event_data: 1 + } + output_trace { packet_timestamp: 0 stream_id: 2 event_data: 2 } + output_trace { packet_timestamp: 0 stream_id: 3 event_data: 3 } + output_trace { packet_timestamp: 5 stream_id: 3 event_data: 4 } + } + calculator_trace { + node_id: 1 + input_timestamp: 0 + event_type: PROCESS + start_time: 11000 + finish_time: 21000 + thread_id: 0 + input_trace { + start_time: 10000 + finish_time: 11000 + packet_timestamp: 0 + stream_id: 2 + event_data: 5 + } + output_trace { packet_timestamp: 0 stream_id: 4 event_data: 6 } + } + calculator_trace { + node_id: 2 + input_timestamp: 0 + event_type: PROCESS + start_time: 16000 + finish_time: 36000 + thread_id: 0 + input_trace { + start_time: 10000 + finish_time: 16000 + packet_timestamp: 0 + stream_id: 3 + event_data: 7 + } + output_trace { packet_timestamp: 0 stream_id: 5 event_data: 8 } + } + calculator_trace { + node_id: 2 + input_timestamp: 5 + event_type: PROCESS + start_time: 38000 + finish_time: 58000 + thread_id: 0 + input_trace { + start_time: 10000 + finish_time: 38000 + packet_timestamp: 5 + stream_id: 3 + event_data: 9 + } + output_trace { packet_timestamp: 5 stream_id: 5 event_data: 10 } + } + )"))); // No timestamps are completed before start_time_. // One timestamp is completed before start_time_ + 10ms. 
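The note in the GraphTrace test above explains that packet data IDs derive
from payload addresses (b/160212191), so a trace can vary with the allocator.
A standalone sketch of the underlying effect (plain C++, not a MediaPipe API):

    #include <memory>

    // An address-based ID is only unique while the payload is alive.
    auto a = std::make_unique<int>(1);
    const void* id_a = a.get();
    a.reset();                          // payload expires...
    auto b = std::make_unique<int>(2);  // ...and may reuse the same address,
    const void* id_b = b.get();         // so id_a == id_b is possible.

Keeping every packet alive for the duration of the test, as the new
TestContextBuilder instance does, rules this reuse out.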
@@ -1275,37 +1284,39 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) { GraphTrace trace_1; builder.CreateTrace(buffer, absl::InfinitePast(), absl::InfiniteFuture(), &trace_1); - EXPECT_THAT(trace_1, EqualsProto(::mediapipe::ParseTextProtoOrDie( - R"( - base_time: 1100 - base_timestamp: 1000 - stream_name: "" - stream_name: "stream_1" - stream_name: "stream_2" - calculator_trace { - node_id: 333 - input_timestamp: 0 - event_type: PROCESS - start_time: 0 - finish_time: 1000 - input_trace { - finish_time: 0 - packet_timestamp: 0 - stream_id: 1 - event_data: 0 - } - output_trace { packet_timestamp: 0 stream_id: 2 } - thread_id: 0 - } - calculator_trace { - node_id: 333 - input_timestamp: 0 - event_type: GPU_TASK - start_time: 100 - finish_time: 2100 - thread_id: 0 - } - )"))); + EXPECT_THAT( + trace_1, + EqualsProto(::mediapipe::ParseTextProtoOrDie( + R"( + base_time: 1100 + base_timestamp: 1000 + stream_name: "" + stream_name: "stream_1" + stream_name: "stream_2" + calculator_trace { + node_id: 333 + input_timestamp: 0 + event_type: PROCESS + start_time: 0 + finish_time: 1000 + input_trace { + finish_time: 0 + packet_timestamp: 0 + stream_id: 1 + event_data: 0 + } + output_trace { packet_timestamp: 0 stream_id: 2 event_data: 0 } + thread_id: 0 + } + calculator_trace { + node_id: 333 + input_timestamp: 0 + event_type: GPU_TASK + start_time: 100 + finish_time: 2100 + thread_id: 0 + } + )"))); GraphTrace trace_2; builder.CreateLog(buffer, absl::InfinitePast(), absl::InfiniteFuture(), diff --git a/mediapipe/framework/profiler/trace_builder.cc b/mediapipe/framework/profiler/trace_builder.cc index ff20f9c91..ce6c6c57c 100644 --- a/mediapipe/framework/profiler/trace_builder.cc +++ b/mediapipe/framework/profiler/trace_builder.cc @@ -330,13 +330,12 @@ class TraceBuilder::Impl { if (trace_event_registry_[event->event_type].is_stream_event()) { auto stream_trace = event->is_finish ? result->add_output_trace() : result->add_input_trace(); - if (event->is_finish) { - // Log only the packet id for each output event. - stream_trace->set_stream_id(stream_id_map_[event->stream_id]); - stream_trace->set_packet_timestamp(LogTimestamp(event->packet_ts)); - } else { - // Log the full stream trace for each input event. - BuildStreamTrace(*event, stream_trace); + BuildStreamTrace(*event, stream_trace); + if (!event->is_finish) { + // Note: is_finish is true for output events, false for input events. + // For input events, we log some additional timing information. The + // finish_time is the start_time of this Process call, the start_time + // is the finish_time of the Process call that output the packet. stream_trace->set_finish_time(LogTime(event->event_time)); const TraceEvent* output_event = FindOutputEvent(*event); if (output_event) { diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc index be21f4180..60d6ceb19 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc @@ -116,10 +116,19 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness( CHECK_EQ(stream_ts, Timestamp::Done()); if (ProcessTimestampBounds()) { // With kReadyForClose, the timestamp-bound Done is returned. - // This bound is processed using the preceding input-timestamp. // TODO: Make all InputStreamHandlers process Done() like this. 
- ready_timestamps_[i] = stream_ts.PreviousAllowedInStream(); - input_timestamp = std::min(input_timestamp, ready_timestamps_[i]); + static const Timestamp kDonePrecedingTimestamp = + Timestamp::Done().PreviousAllowedInStream(); + if (prev_ts < kDonePrecedingTimestamp) { + // When kReadyForClose is received for the first time for a sync set, + // it is processed using the timestamp preceding Done() to indicate + // input stream is done, but still needs to be processed. + min_bound = std::min(min_bound, kDonePrecedingTimestamp); + input_timestamp = std::min(input_timestamp, kDonePrecedingTimestamp); + ready_timestamps_[i] = kDonePrecedingTimestamp; + } else { + ready_timestamps_[i] = Timestamp::Done(); + } } else if (prev_ts < Timestamp::Done()) { stream_became_done = true; ready_timestamps_[i] = Timestamp::Done(); diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index 399344384..0c14fcdbe 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc @@ -133,6 +133,11 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test { } } + const InputStream& Input(const CollectionItemId& id) { + CHECK(cc_); + return cc_->Inputs().Get(id); + } + PacketType packet_type_; std::function headers_ready_callback_; std::function notification_callback_; @@ -262,6 +267,344 @@ TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) { EXPECT_TRUE(errors_.empty()); } +TEST_F(ImmediateInputStreamHandlerTest, ProcessTimestampBounds) { + input_stream_handler_->SetProcessTimestampBounds(true); + + Timestamp min_stream_timestamp; + ASSERT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::PreStream()); + + const auto& input_a_id = name_to_id_["input_a"]; + const auto& input_b_id = name_to_id_["input_b"]; + const auto& input_c_id = name_to_id_["input_c"]; + + std::list packets; + packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1))); + input_stream_handler_->AddPackets(input_b_id, packets); + input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done()); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1)); + ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}}); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1)); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted()); + + // FinalizeInputSet() is a no-op. 
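+  // (ImmediateInputStreamHandler does not batch inputs across invocations,
+  // so there is no per-invocation state to flush here; the call below is
+  // kept to mirror how the framework drives the handler.)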
+ input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted()); + + // FinalizeInputSet() is a no-op. + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + EXPECT_TRUE( + input_stream_handler_->GetInputStreamManager(input_b_id)->IsEmpty()); + + input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done()); + input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done()); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max()); + + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + EXPECT_TRUE(errors_.empty()); + + // Schedule invocation for Close. 
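+  // (Once every input stream has reached Timestamp::Done(), one final
+  // invocation is scheduled at Timestamp::Done() so the framework can run
+  // the calculator's Close(); all inputs are empty and Unset at that point,
+  // as asserted below.)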
+ ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset()); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset()); + + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + EXPECT_TRUE(errors_.empty()); +} + +TEST_F(ImmediateInputStreamHandlerTest, + ProcessTimestampBoundsNoOpScheduleInvocations) { + input_stream_handler_->SetProcessTimestampBounds(true); + + const auto& input_a_id = name_to_id_["input_a"]; + const auto& input_b_id = name_to_id_["input_b"]; + const auto& input_c_id = name_to_id_["input_c"]; + + Timestamp min_stream_timestamp; + std::list packets; + packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(1))); + input_stream_handler_->AddPackets(input_b_id, packets); + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1)); + ExpectPackets(cc_->Inputs(), {{"input_b", "packet 1"}}); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1)); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unstarted()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unstarted()); + + // FinalizeInputSet() is a no-op. + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done()); + input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done()); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1)); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max()); + + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + EXPECT_TRUE(errors_.empty()); + + // Try to schedule invocations several times again. Considering nothing + // changed since last invocation nothing should be scheduled. 
+  for (int i = 0; i < 3; ++i) {
+    ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
+        /*max_allowance=*/1, &min_stream_timestamp));
+    EXPECT_EQ(min_stream_timestamp, Timestamp(2));
+    EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
+    EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
+    EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
+    EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
+    EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
+    EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
+  }
+
+  input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done());
+
+  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
+      /*max_allowance=*/1, &min_stream_timestamp));
+  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
+  EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max());
+  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
+  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max());
+  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
+  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max());
+  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
+  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max());
+
+  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
+                                          &cc_->Inputs());
+  input_stream_handler_->ClearCurrentInputs(cc_);
+  EXPECT_TRUE(errors_.empty());
+
+  // Schedule invocation for Close.
+  ASSERT_TRUE(input_stream_handler_->ScheduleInvocations(
+      /*max_allowance=*/1, &min_stream_timestamp));
+  EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
+  EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done());
+  EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
+  EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
+  EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
+  EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
+  EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
+  EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
+
+  input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(),
+                                          &cc_->Inputs());
+  input_stream_handler_->ClearCurrentInputs(cc_);
+  EXPECT_TRUE(errors_.empty());
+
+  // Try to schedule invocations several times again. Considering nothing
+  // changed since last invocation nothing should be scheduled.
+  for (int i = 0; i < 3; ++i) {
+    ASSERT_FALSE(input_stream_handler_->ScheduleInvocations(
+        /*max_allowance=*/1, &min_stream_timestamp));
+    EXPECT_EQ(min_stream_timestamp, Timestamp::Unset());
+    EXPECT_TRUE(Input(input_b_id).Value().IsEmpty());
+    EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset());
+    EXPECT_TRUE(Input(input_a_id).Value().IsEmpty());
+    EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset());
+    EXPECT_TRUE(Input(input_c_id).Value().IsEmpty());
+    EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset());
+  }
+}
+
+// Due to some temporary changes in ImmediateInputStreamHandler, some packets
+// - were queued but never released, or
+// - were released in the wrong order,
+// while the other test cases kept passing. This test case is designed to
+// catch those regressions: it verifies that every queued packet is released,
+// and in the correct order.
+TEST_F(ImmediateInputStreamHandlerTest, VerifyPacketsReleaseOrder) { + input_stream_handler_->SetProcessTimestampBounds(true); + + const auto& input_a_id = name_to_id_["input_a"]; + const auto& input_b_id = name_to_id_["input_b"]; + const auto& input_c_id = name_to_id_["input_c"]; + + Packet packet_a = Adopt(new std::string("packet a")); + Packet packet_b = Adopt(new std::string("packet b")); + Packet packet_c = Adopt(new std::string("packet c")); + input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(1))}); + input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(2))}); + input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(3))}); + + Timestamp min_stream_timestamp; + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp(1)); + ASSERT_FALSE(Input(input_a_id).IsEmpty()); + EXPECT_EQ(Input(input_a_id).Get(), "packet a"); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(1)); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(1)); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2)); + + // FinalizeInputSet() is a no-op. + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + input_stream_handler_->AddPackets(input_a_id, {packet_a.At(Timestamp(5))}); + input_stream_handler_->AddPackets(input_b_id, {packet_b.At(Timestamp(5))}); + input_stream_handler_->AddPackets(input_c_id, {packet_c.At(Timestamp(5))}); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp(2)); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4)); + ASSERT_FALSE(Input(input_b_id).IsEmpty()); + EXPECT_EQ(Input(input_b_id).Get(), "packet b"); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(2)); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(2)); + + // FinalizeInputSet() is a no-op. + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp(3)); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(4)); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(4)); + ASSERT_FALSE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Get(), "packet c"); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(3)); + + // FinalizeInputSet() is a no-op. 
+ input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp(5)); + ASSERT_FALSE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Get(), "packet a"); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp(5)); + ASSERT_FALSE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Get(), "packet b"); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp(5)); + ASSERT_FALSE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Get(), "packet c"); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp(5)); + + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + input_stream_handler_->SetNextTimestampBound(input_a_id, Timestamp::Done()); + input_stream_handler_->SetNextTimestampBound(input_b_id, Timestamp::Done()); + input_stream_handler_->SetNextTimestampBound(input_c_id, Timestamp::Done()); + + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Max()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Max()); + + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + // Schedule invocation for Close. + ASSERT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_EQ(cc_->InputTimestamp(), Timestamp::Done()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset()); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset()); + + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + ASSERT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(min_stream_timestamp, Timestamp::Unset()); + EXPECT_TRUE(Input(input_b_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_b_id).Value().Timestamp(), Timestamp::Unset()); + EXPECT_TRUE(Input(input_a_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_a_id).Value().Timestamp(), Timestamp::Unset()); + EXPECT_TRUE(Input(input_c_id).Value().IsEmpty()); + EXPECT_EQ(Input(input_c_id).Value().Timestamp(), Timestamp::Unset()); +} + // This test simulates how CalculatorNode::ProcessNode() uses an input // stream handler and the associated input streams. 
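// The driving loop being simulated looks roughly like this (a sketch only;
// see CalculatorNode::ProcessNode() for the real logic):
//
//   Timestamp input_bound;
//   while (input_stream_handler_->ScheduleInvocations(/*max_allowance=*/1,
//                                                     &input_bound)) {
//     // ... run the calculator's Process() on the prepared input set ...
//     input_stream_handler_->FinalizeInputSet(cc->InputTimestamp(),
//                                             &cc->Inputs());
//     input_stream_handler_->ClearCurrentInputs(cc);
//   }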
TEST_F(ImmediateInputStreamHandlerTest, SimulateProcessNode) {
diff --git a/mediapipe/framework/test_calculators.cc b/mediapipe/framework/test_calculators.cc
index 8bcf59baf..000a6301f 100644
--- a/mediapipe/framework/test_calculators.cc
+++ b/mediapipe/framework/test_calculators.cc
@@ -641,4 +641,61 @@ class DummyTestCalculator : public CalculatorBase {
};
REGISTER_CALCULATOR(DummyTestCalculator);

+// A Calculator that passes the input value to the output after sleeping for
+// a set number of microseconds.
+class PassThroughWithSleepCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    cc->Inputs().Index(0).Set<int>();
+    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
+    cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
+    cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Open(CalculatorContext* cc) final {
+    cc->SetOffset(TimestampDiff(0));
+    sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
+    if (sleep_micros_ < 0) {
+      return ::mediapipe::InternalError("SLEEP_MICROS should be >= 0");
+    }
+    clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Process(CalculatorContext* cc) final {
+    clock_->Sleep(absl::Microseconds(sleep_micros_));
+    int value = cc->Inputs().Index(0).Value().Get<int>();
+    cc->Outputs().Index(0).Add(new int(value), cc->InputTimestamp());
+    return ::mediapipe::OkStatus();
+  }
+
+ private:
+  int sleep_micros_ = 0;
+  std::shared_ptr<Clock> clock_;
+};
+REGISTER_CALCULATOR(PassThroughWithSleepCalculator);
+
+// A Calculator that multiplies two input values.
+class MultiplyIntCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    cc->Inputs().Index(0).Set<int>();
+    cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
+    // cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
+    RET_CHECK(cc->Outputs().HasTag("OUT"));
+    cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Open(CalculatorContext* cc) final {
+    cc->SetOffset(TimestampDiff(0));
+    return ::mediapipe::OkStatus();
+  }
+  ::mediapipe::Status Process(CalculatorContext* cc) final {
+    int x = cc->Inputs().Index(0).Value().Get<int>();
+    int y = cc->Inputs().Index(1).Value().Get<int>();
+    cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
+    return ::mediapipe::OkStatus();
+  }
+};
+REGISTER_CALCULATOR(MultiplyIntCalculator);
+
} // namespace mediapipe
diff --git a/mediapipe/framework/tool/name_util.cc b/mediapipe/framework/tool/name_util.cc
index 0f7cceea8..7aa4965ee 100644
--- a/mediapipe/framework/tool/name_util.cc
+++ b/mediapipe/framework/tool/name_util.cc
@@ -101,6 +101,13 @@ std::string ParseNameFromStream(const std::string& stream) {
  return name;
}

+std::pair<std::string, int> ParseTagIndex(const std::string& tag_index) {
+  std::string tag;
+  int index;
+  MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(tag_index, &tag, &index));
+  return {tag, index};
+}
+
std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream) {
  std::string tag, name;
  int index;
diff --git a/mediapipe/framework/tool/name_util.h b/mediapipe/framework/tool/name_util.h
index dbab0ed29..69885ae47 100644
--- a/mediapipe/framework/tool/name_util.h
+++ b/mediapipe/framework/tool/name_util.h
@@ -76,6 +76,9 @@ std::string CanonicalNodeName(const CalculatorGraphConfig& graph_config,
// Parses the name from a "tag:index:name".
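// For example (values illustrative): ParseNameFromStream("VIDEO:0:video_in")
// returns "video_in", and ParseTagIndex("VIDEO:1") returns {"VIDEO", 1}.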
 std::string ParseNameFromStream(const std::string& stream);
 
+// Parses the TagIndex from a "tag:index".
+std::pair<std::string, int> ParseTagIndex(const std::string& tag_index);
+
 // Parses the TagIndex from a "tag:index:name".
 std::pair<std::string, int> ParseTagIndexFromStream(const std::string& stream);
diff --git a/mediapipe/framework/tool/testdata/BUILD b/mediapipe/framework/tool/testdata/BUILD
index fd2c8ae5c..4227b11f9 100644
--- a/mediapipe/framework/tool/testdata/BUILD
+++ b/mediapipe/framework/tool/testdata/BUILD
@@ -13,15 +13,15 @@
 # limitations under the License.
 #
 
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//mediapipe:__subpackages__"])
-
 load(
     "//mediapipe/framework/tool:mediapipe_graph.bzl",
     "mediapipe_simple_subgraph",
 )
 
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//mediapipe:__subpackages__"])
+
 filegroup(
     name = "test_graph",
     srcs = ["test.pbtxt"],
@@ -31,6 +31,8 @@ exports_files([
     "test.pbtxt",
     "dub_quad_test_subgraph.pbtxt",
     "nested_test_subgraph.pbtxt",
+    "single_flow_container_test.pbtxt",
+    "dual_flow_container_test.pbtxt",
 ])
 
 mediapipe_simple_subgraph(
diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD
index fcff29e08..b95c9e0c4 100644
--- a/mediapipe/gpu/BUILD
+++ b/mediapipe/gpu/BUILD
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//visibility:public"])
-
 load("//mediapipe/gpu:metal.bzl", "metal_library")
 load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
 
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
 # Disabling GPU support is sometimes useful on desktop Linux because SwiftShader can
 # interfere with desktop GL. b/73494271
 config_setting(
diff --git a/mediapipe/gpu/gl_scaler_calculator.cc b/mediapipe/gpu/gl_scaler_calculator.cc
index 8729f4d8e..d8d90a524 100644
--- a/mediapipe/gpu/gl_scaler_calculator.cc
+++ b/mediapipe/gpu/gl_scaler_calculator.cc
@@ -39,6 +39,7 @@ namespace mediapipe {
 //   ROTATION: the counterclockwise rotation angle in degrees. This allows
 //   user to specify different rotation angles for different frames. If this
 //   stream is provided, it will override the ROTATION input side packet.
+//   OUTPUT_DIMENSIONS: the output width and height in pixels.
 // Additional output streams:
 //   TOP_BOTTOM_PADDING: If use FIT scale mode, this stream outputs the padding
 //   size of the input image in normalized value [0, 1] for top and bottom
@@ -103,6 +104,9 @@ REGISTER_CALCULATOR(GlScalerCalculator);
   if (cc->Inputs().HasTag("ROTATION")) {
     cc->Inputs().Tag("ROTATION").Set<int>();
   }
+  if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
+    cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<DimensionsPacketType>();
+  }
 
   MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
   if (cc->InputSidePackets().HasTag("OPTIONS")) {
@@ -181,6 +185,18 @@ REGISTER_CALCULATOR(GlScalerCalculator);
 }
 
 ::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) {
+  if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
+    if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
+      // OUTPUT_DIMENSIONS input stream is specified, but value is missing.
+      return ::mediapipe::OkStatus();
+    }
+
+    const auto& dimensions =
+        cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<DimensionsPacketType>();
+    dst_width_ = dimensions[0];
+    dst_height_ = dimensions[1];
+  }
+
   return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
     const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get<GpuBuffer>();
     QuadRenderer* renderer = nullptr;
@@ -199,7 +215,7 @@ REGISTER_CALCULATOR(GlScalerCalculator);
       src1 = helper_.CreateSourceTexture(input, 0);
       src2 = helper_.CreateSourceTexture(input, 1);
     } else  // NOLINT(readability/braces)
-#endif // __APPLE__
+#endif  // __APPLE__
     {
       src1 = helper_.CreateSourceTexture(input);
 #ifdef __ANDROID__
@@ -211,7 +227,7 @@ REGISTER_CALCULATOR(GlScalerCalculator);
       }
       renderer = ext_rgb_renderer_.get();
     } else  // NOLINT(readability/braces)
-#endif // __ANDROID__
+#endif  // __ANDROID__
     {
       if (!rgb_renderer_) {
         rgb_renderer_ = absl::make_unique<QuadRenderer>();
diff --git a/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt
index 9fa417510..50881e8a7 100644
--- a/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt
+++ b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_cpu.pbtxt
@@ -140,6 +140,9 @@ node {
       num_landmarks: 21
       input_image_width: 256
       input_image_height: 256
+      # The additional scaling factor is used to account for the Z coordinate
+      # distribution in the training data.
+      normalize_z: 0.4
     }
   }
 }
diff --git a/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt
index 690f250f3..fa0a00f2c 100644
--- a/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt
+++ b/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt
@@ -144,6 +144,9 @@ node {
      num_landmarks: 21
      input_image_width: 256
      input_image_height: 256
+      # The additional scaling factor is used to account for the Z coordinate
+      # distribution in the training data.
+      normalize_z: 0.4
     }
   }
 }
diff --git a/mediapipe/java/com/google/mediapipe/components/BUILD b/mediapipe/java/com/google/mediapipe/components/BUILD
index a3eca9b3a..2e0c32df8 100644
--- a/mediapipe/java/com/google/mediapipe/components/BUILD
+++ b/mediapipe/java/com/google/mediapipe/components/BUILD
@@ -25,6 +25,7 @@ android_library(
     ),
     visibility = ["//visibility:public"],
     deps = [
+        "//mediapipe/framework:calculator_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/glutil",
        "//third_party:androidx_appcompat",
diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java
index ce155b754..38cc5e91a 100644
--- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java
+++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java
@@ -14,17 +14,21 @@
 
 package com.google.mediapipe.components;
 
+import static java.lang.Math.max;
+
 import android.graphics.SurfaceTexture;
 import android.opengl.GLES11Ext;
 import android.opengl.GLES20;
 import android.util.Log;
 import com.google.mediapipe.framework.AppTextureFrame;
+import com.google.mediapipe.framework.GlSyncToken;
 import com.google.mediapipe.glutil.ExternalTextureRenderer;
 import com.google.mediapipe.glutil.GlThread;
 import com.google.mediapipe.glutil.ShaderUtil;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
+import java.util.Queue;
 import javax.microedition.khronos.egl.EGLContext;
 
 /**
@@ -204,8 +208,11 @@ public class ExternalTextureConverter implements TextureFrameProducer {
     private static final long NANOS_PER_MICRO = 1000;  // Nanoseconds in one microsecond.
     private volatile SurfaceTexture surfaceTexture = null;
     private final List<TextureFrameConsumer> consumers;
-    private List<AppTextureFrame> outputFrames = null;
-    private int outputFrameIndex = -1;
+
+    private final Queue<PoolTextureFrame> framesAvailable = new ArrayDeque<>();
+    private int framesInUse = 0;
+    private final int framesToKeep;
+
     private ExternalTextureRenderer renderer = null;
     private long nextFrameTimestampOffset = 0;
     private long timestampOffsetNanos = 0;
@@ -215,10 +222,27 @@ public class ExternalTextureConverter implements TextureFrameProducer {
     protected int destinationWidth = 0;
     protected int destinationHeight = 0;
 
+    private class PoolTextureFrame extends AppTextureFrame {
+      public PoolTextureFrame(int textureName, int width, int height) {
+        super(textureName, width, height);
+      }
+
+      @Override
+      public void release(GlSyncToken syncToken) {
+        super.release(syncToken);
+        poolFrameReleased(this);
+      }
+
+      @Override
+      public void release() {
+        super.release();
+        poolFrameReleased(this);
+      }
+    }
+
     public RenderThread(EGLContext parentContext, int numBuffers) {
       super(parentContext);
-      outputFrames = new ArrayList<>();
-      outputFrames.addAll(Collections.nCopies(numBuffers, null));
+      framesToKeep = numBuffers;
       renderer = new ExternalTextureRenderer();
       consumers = new ArrayList<>();
     }
@@ -283,8 +307,8 @@ public class ExternalTextureConverter implements TextureFrameProducer {
     @Override
     public void releaseGl() {
       setSurfaceTexture(null, 0, 0);
-      for (int i = 0; i < outputFrames.size(); ++i) {
-        teardownDestination(i);
+      while (!framesAvailable.isEmpty()) {
+        teardownFrame(framesAvailable.remove());
       }
       renderer.release();
      super.releaseGl();  // This releases the EGL context, so must do it after any GL calls.
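
The pool above only works if consumers hand frames back: PoolTextureFrame.release() is what returns a frame to framesAvailable (see poolFrameReleased in the next hunk). A minimal consumer sketch showing that release contract; illustrative only, and the class name is hypothetical:

import com.google.mediapipe.components.TextureFrameConsumer;
import com.google.mediapipe.framework.TextureFrame;

/** Hypothetical consumer demonstrating the release contract the pool relies on. */
class LoggingTextureFrameConsumer implements TextureFrameConsumer {
  @Override
  public void onNewFrame(TextureFrame frame) {
    try {
      // Use the texture here (e.g. hand it to a renderer or a MediaPipe graph).
      android.util.Log.v("LoggingConsumer",
          "tex=" + frame.getTextureName() + " ts=" + frame.getTimestamp());
    } finally {
      // Returns the frame to the converter's pool; without this call the
      // converter would keep allocating new textures.
      frame.release();
    }
  }
}
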
@@ -337,16 +361,11 @@ public class ExternalTextureConverter implements TextureFrameProducer { } } - private void teardownDestination(int index) { - if (outputFrames.get(index) != null) { - waitUntilReleased(outputFrames.get(index)); - GLES20.glDeleteTextures(1, new int[] {outputFrames.get(index).getTextureName()}, 0); - outputFrames.set(index, null); - } + private static void teardownFrame(AppTextureFrame frame) { + GLES20.glDeleteTextures(1, new int[] {frame.getTextureName()}, 0); } - private void setupDestination(int index) { - teardownDestination(index); + private PoolTextureFrame createFrame() { int destinationTextureId = ShaderUtil.createRgbaTexture(destinationWidth, destinationHeight); Log.d( TAG, @@ -354,11 +373,9 @@ public class ExternalTextureConverter implements TextureFrameProducer { "Created output texture: %d width: %d height: %d", destinationTextureId, destinationWidth, destinationHeight)); bindFramebuffer(destinationTextureId, destinationWidth, destinationHeight); - outputFrames.set( - index, new AppTextureFrame(destinationTextureId, destinationWidth, destinationHeight)); + return new PoolTextureFrame(destinationTextureId, destinationWidth, destinationHeight); } - /** * Gets next available frame or creates new one if next frame is not initialized * or cannot be used with current surface texture. @@ -371,20 +388,38 @@ public class ExternalTextureConverter implements TextureFrameProducer { * NOTE: must be invoked on GL thread */ private AppTextureFrame nextOutputFrame() { - outputFrameIndex = (outputFrameIndex + 1) % outputFrames.size(); - AppTextureFrame outputFrame = outputFrames.get(outputFrameIndex); - // Check if the size has changed. - if (outputFrame == null - || outputFrame.getWidth() != destinationWidth - || outputFrame.getHeight() != destinationHeight) { - // setupDestination will wait for the frame to be released before reallocating it. - setupDestination(outputFrameIndex); - outputFrame = outputFrames.get(outputFrameIndex); + PoolTextureFrame outputFrame; + synchronized (this) { + outputFrame = framesAvailable.poll(); + framesInUse++; + } + if (outputFrame == null) { + outputFrame = createFrame(); + } else if (outputFrame.getWidth() != destinationWidth + || outputFrame.getHeight() != destinationHeight) { + // Create anew if size has changed. + // TODO: waiting for the consumer sync here may not be necessary. + waitUntilReleased(outputFrame); + teardownFrame(outputFrame); + outputFrame = createFrame(); + } else { + // Note: waitUntilReleased does two things: waits for the frame to be released by the CPU, + // and syncs with the GPU sync token provided by the consumer. The first part is redundant + // here (and completes immediately), but the second part is still needed. + waitUntilReleased(outputFrame); } - waitUntilReleased(outputFrame); return outputFrame; } + protected synchronized void poolFrameReleased(PoolTextureFrame frame) { + framesAvailable.offer(frame); + framesInUse--; + int keep = max(framesToKeep - framesInUse, 0); + while (framesAvailable.size() > keep) { + teardownFrame(framesAvailable.remove()); + } + } + /** * Updates output frame with current pixels of surface texture and corresponding timestamp. 
* @@ -417,16 +452,22 @@ public class ExternalTextureConverter implements TextureFrameProducer { Log.v( TAG, String.format( - "Waiting for tex: %d width: %d height: %d", - frame.getTextureName(), frame.getWidth(), frame.getHeight())); + "Waiting for tex: %d width: %d height: %d timestamp: %d", + frame.getTextureName(), + frame.getWidth(), + frame.getHeight(), + frame.getTimestamp())); } frame.waitUntilReleased(); if (Log.isLoggable(TAG, Log.VERBOSE)) { Log.v( TAG, String.format( - "Finished waiting for tex: %d width: %d height: %d", - frame.getTextureName(), frame.getWidth(), frame.getHeight())); + "Finished waiting for tex: %d width: %d height: %d timestamp: %d", + frame.getTextureName(), + frame.getWidth(), + frame.getHeight(), + frame.getTimestamp())); } } catch (InterruptedException ie) { // Someone interrupted our thread. This is not supposed to happen: we own diff --git a/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java b/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java index d9245e768..2728eb1c7 100644 --- a/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java +++ b/mediapipe/java/com/google/mediapipe/components/FrameProcessor.java @@ -20,6 +20,7 @@ import android.media.AudioFormat; import android.os.Handler; import android.util.Log; import com.google.common.base.Preconditions; +import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; import com.google.mediapipe.framework.AndroidAssetUtil; import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.Graph; @@ -32,10 +33,12 @@ import com.google.mediapipe.framework.SurfaceOutput; import com.google.mediapipe.framework.TextureFrame; import java.io.File; import java.nio.ByteBuffer; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nullable; @@ -106,6 +109,15 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor initializeGraphAndPacketCreator(context, graphName); } + /** + * Constructor. + * + * @param graphConfig the proto object representation of the graph. + */ + public FrameProcessor(CalculatorGraphConfig graphConfig) { + initializeGraphAndPacketCreator(graphConfig); + } + /** * Initializes a graph for processing data in real time. * @@ -123,6 +135,17 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor packetCreator = new AndroidPacketCreator(mediapipeGraph); } + /** + * Initializes a graph for processing data in real time. + * + * @param graphConfig the proto object representation of the graph. + */ + private void initializeGraphAndPacketCreator(CalculatorGraphConfig graphConfig) { + mediapipeGraph = new Graph(); + mediapipeGraph.loadBinaryGraph(graphConfig); + packetCreator = new AndroidPacketCreator(mediapipeGraph); + } + /** Callback for errors occurring during processing in the graph. */ public interface ErrorListener { void onError(RuntimeException error); @@ -186,6 +209,8 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor currentConsumers = videoConsumers; } for (TextureFrameConsumer consumer : currentConsumers) { + // Note: each consumer will release its TextureFrame, so each gets a separate object + // (though they all reference the same data). 
TextureFrame frame = PacketGetter.getTextureFrame(packet); if (Log.isLoggable(TAG, Log.VERBOSE)) { Log.v( @@ -373,9 +398,10 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor /** * Returns true if the MediaPipe graph can accept one more input frame. + * * @throws MediaPipeException for any error status. */ - private boolean maybeAcceptNewFrame() { + private boolean maybeAcceptNewFrame(long timestamp) { if (!started.getAndSet(true)) { startGraph(); } @@ -395,7 +421,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor frame.getTextureName(), frame.getWidth(), frame.getHeight())); } - if (!maybeAcceptNewFrame()) { + if (!maybeAcceptNewFrame(frame.getTimestamp())) { return; } @@ -451,7 +477,7 @@ public class FrameProcessor implements TextureFrameProcessor, AudioDataProcessor public void onNewFrame(final Bitmap bitmap, long timestamp) { Packet packet = null; try { - if (!maybeAcceptNewFrame()) { + if (!maybeAcceptNewFrame(timestamp)) { return; } diff --git a/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java b/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java index 2b81029b1..2884a25b7 100644 --- a/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java +++ b/mediapipe/java/com/google/mediapipe/components/PermissionHelper.java @@ -17,8 +17,8 @@ package com.google.mediapipe.components; import android.Manifest; import android.app.Activity; import android.content.pm.PackageManager; -import androidx.core.app.ActivityCompat; import android.util.Log; +import androidx.core.app.ActivityCompat; import androidx.core.content.ContextCompat; /** Manages camera permission request and handling. */ diff --git a/mediapipe/java/com/google/mediapipe/components/TextureFrameConsumer.java b/mediapipe/java/com/google/mediapipe/components/TextureFrameConsumer.java index 4c62ebbcb..498ca6076 100644 --- a/mediapipe/java/com/google/mediapipe/components/TextureFrameConsumer.java +++ b/mediapipe/java/com/google/mediapipe/components/TextureFrameConsumer.java @@ -18,6 +18,10 @@ import com.google.mediapipe.framework.TextureFrame; /** Lightweight abstraction for an object that can receive video frames. */ public interface TextureFrameConsumer { - /** Called when a new {@link TextureFrame} is available. */ + /** + * Called when a new {@link TextureFrame} is available. + * + * Important: implementations of this method should call frame.release(). + **/ public abstract void onNewFrame(TextureFrame frame); } diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index d87bc8945..aae0adc6d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -272,6 +272,10 @@ public final class PacketGetter { *

<p>Note: in order for the application to be able to use the texture, its GL context must be
    * linked with MediaPipe's. This is ensured by calling {@link Graph#createGlRunner(String,long)}
    * with the native handle to the application's GL context as the second argument.
+   *
+   * <p>
The returned GraphTextureFrame must be released by the caller. If this method is called + * multiple times, each returned GraphTextureFrame is an independent reference to the underlying + * texture data, and must be released individually. */ public static GraphTextureFrame getTextureFrame(final Packet packet) { return new GraphTextureFrame( diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/METADATA b/mediapipe/java/com/google/mediapipe/framework/jni/METADATA deleted file mode 100644 index cbc57f510..000000000 --- a/mediapipe/java/com/google/mediapipe/framework/jni/METADATA +++ /dev/null @@ -1,7 +0,0 @@ -tricorder: { - options: { - builder: { - config: "android_arm" - } - } -} diff --git a/mediapipe/objc/MPPGraph.h b/mediapipe/objc/MPPGraph.h index 6823aad18..06bf1552b 100644 --- a/mediapipe/objc/MPPGraph.h +++ b/mediapipe/objc/MPPGraph.h @@ -33,22 +33,22 @@ struct GpuSharedData; /// Provides the delegate with a new video frame. @optional -- (void)mediapipeGraph:(MPPGraph*)graph +- (void)mediapipeGraph:(MPPGraph *)graph didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer - fromStream:(const std::string&)streamName; + fromStream:(const std::string &)streamName; /// Provides the delegate with a new video frame and time stamp. @optional -- (void)mediapipeGraph:(MPPGraph*)graph +- (void)mediapipeGraph:(MPPGraph *)graph didOutputPixelBuffer:(CVPixelBufferRef)pixelBuffer - fromStream:(const std::string&)streamName - timestamp:(const mediapipe::Timestamp&)timestamp; + fromStream:(const std::string &)streamName + timestamp:(const mediapipe::Timestamp &)timestamp; /// Provides the delegate with a raw packet. @optional -- (void)mediapipeGraph:(MPPGraph*)graph - didOutputPacket:(const mediapipe::Packet&)packet - fromStream:(const std::string&)streamName; +- (void)mediapipeGraph:(MPPGraph *)graph + didOutputPacket:(const mediapipe::Packet &)packet + fromStream:(const std::string &)streamName; @end @@ -100,34 +100,34 @@ typedef NS_ENUM(int, MPPPacketType) { /// Copies the config and initializes the graph. /// @param config The configuration describing the graph. -- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig&)config +- (instancetype)initWithGraphConfig:(const mediapipe::CalculatorGraphConfig &)config NS_DESIGNATED_INITIALIZER; -- (mediapipe::ProfilingContext*)getProfiler; +- (mediapipe::ProfilingContext *)getProfiler; /// Sets a stream header. If the header was already set, it is overwritten. /// @param packet The header. /// @param streamName The name of the stream. -- (void)setHeaderPacket:(const mediapipe::Packet&)packet forStream:(const std::string&)streamName; +- (void)setHeaderPacket:(const mediapipe::Packet &)packet forStream:(const std::string &)streamName; /// Sets a side packet. If it was already set, it is overwritten. /// Must be called before the graph is started. /// @param packet The packet to be associated with the input side packet. /// @param name The name of the input side packet. -- (void)setSidePacket:(const mediapipe::Packet&)packet named:(const std::string&)name; +- (void)setSidePacket:(const mediapipe::Packet &)packet named:(const std::string &)name; /// Sets a service packet. If it was already set, it is overwritten. /// Must be called before the graph is started. /// @param packet The packet to be associated with the service. /// @param service. 
-- (void)setServicePacket:(mediapipe::Packet&)packet - forService:(const mediapipe::GraphServiceBase&)service; +- (void)setServicePacket:(mediapipe::Packet &)packet + forService:(const mediapipe::GraphServiceBase &)service; /// Adds input side packets from a map. Any inputs that were already set are /// left unchanged. /// Must be called before the graph is started. /// @param extraInputSidePackets The input side packets to be added. -- (void)addSidePackets:(const std::map&)extraSidePackets; +- (void)addSidePackets:(const std::map &)extraSidePackets; // TODO: rename to addDelegateOutputStream:packetType: /// Add an output stream in the graph from which the delegate wants to receive @@ -135,30 +135,30 @@ typedef NS_ENUM(int, MPPPacketType) { /// @param outputStreamName The name of the output stream from which /// the delegate will receive frames. /// @param packetType The type of packet provided by the output streams. -- (void)addFrameOutputStream:(const std::string&)outputStreamName +- (void)addFrameOutputStream:(const std::string &)outputStreamName outputPacketType:(MPPPacketType)packetType; /// Starts running the graph. /// @return YES if successful. -- (BOOL)startWithError:(NSError**)error; +- (BOOL)startWithError:(NSError **)error; /// Sends a generic packet into a graph input stream. /// The graph must have been started before calling this. /// Returns YES if the packet was successfully sent. -- (BOOL)sendPacket:(const mediapipe::Packet&)packet - intoStream:(const std::string&)streamName - error:(NSError**)error; +- (BOOL)sendPacket:(const mediapipe::Packet &)packet + intoStream:(const std::string &)streamName + error:(NSError **)error; -- (BOOL)movePacket:(mediapipe::Packet&&)packet - intoStream:(const std::string&)streamName - error:(NSError**)error; +- (BOOL)movePacket:(mediapipe::Packet &&)packet + intoStream:(const std::string &)streamName + error:(NSError **)error; /// Sets the maximum queue size for a stream. Experimental feature, currently /// only supported for graph input streams. Should be called before starting the /// graph. - (BOOL)setMaxQueueSize:(int)maxQueueSize - forStream:(const std::string&)streamName - error:(NSError**)error; + forStream:(const std::string &)streamName + error:(NSError **)error; /// Creates a MediaPipe packet wrapping the given pixelBuffer; - (mediapipe::Packet)packetWithPixelBuffer:(CVPixelBufferRef)pixelBuffer @@ -170,9 +170,9 @@ typedef NS_ENUM(int, MPPPacketType) { /// allows MediaPipe to overwrite the packet contents on successful sending for /// possibly increased efficiency. Returns YES if the packet was successfully sent. - (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer - intoStream:(const std::string&)inputName + intoStream:(const std::string &)inputName packetType:(MPPPacketType)packetType - timestamp:(const mediapipe::Timestamp&)timestamp + timestamp:(const mediapipe::Timestamp &)timestamp allowOverwrite:(BOOL)allowOverwrite; /// Sends a pixel buffer into a graph input stream, using the specified packet @@ -180,9 +180,23 @@ typedef NS_ENUM(int, MPPPacketType) { /// returns NO if maxFramesInFlight is exceeded. Returns YES if the packet was /// successfully sent. - (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer - intoStream:(const std::string&)inputName + intoStream:(const std::string &)inputName packetType:(MPPPacketType)packetType - timestamp:(const mediapipe::Timestamp&)timestamp; + timestamp:(const mediapipe::Timestamp &)timestamp; + +/// Sends a pixel buffer into a graph input stream, using the specified packet +/// type. 
The graph must have been started before calling this. Drops frames and +/// returns NO if maxFramesInFlight is exceeded. If allowOverwrite is set to YES, +/// allows MediaPipe to overwrite the packet contents on successful sending for +/// possibly increased efficiency. Returns YES if the packet was successfully sent. +/// Sets error to a non-nil value if an error occurs in the graph when sending the +/// packet. +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer + intoStream:(const std::string &)inputName + packetType:(MPPPacketType)packetType + timestamp:(const mediapipe::Timestamp &)timestamp + allowOverwrite:(BOOL)allowOverwrite + error:(NSError **)error; /// Sends a pixel buffer into a graph input stream, using the specified packet /// type. The graph must have been started before calling this. The timestamp is @@ -190,32 +204,32 @@ typedef NS_ENUM(int, MPPPacketType) { /// frames and returns NO if maxFramesInFlight is exceeded. Returns YES if the /// packet was successfully sent. - (BOOL)sendPixelBuffer:(CVPixelBufferRef)pixelBuffer - intoStream:(const std::string&)inputName + intoStream:(const std::string &)inputName packetType:(MPPPacketType)packetType; /// Cancels a graph run. You must still call waitUntilDoneWithError: after this. - (void)cancel; /// Check if the graph contains this input stream -- (BOOL)hasInputStream:(const std::string&)inputName; +- (BOOL)hasInputStream:(const std::string &)inputName; /// Closes an input stream. /// You must close all graph input streams before stopping the graph. /// @return YES if successful. -- (BOOL)closeInputStream:(const std::string&)inputName error:(NSError**)error; +- (BOOL)closeInputStream:(const std::string &)inputName error:(NSError **)error; /// Closes all graph input streams. /// @return YES if successful. -- (BOOL)closeAllInputStreamsWithError:(NSError**)error; +- (BOOL)closeAllInputStreamsWithError:(NSError **)error; /// Stops running the graph. /// Call this before releasing this object. All input streams must have been /// closed. This call does not time out, so you should not call it from the main /// thread. /// @return YES if successful. -- (BOOL)waitUntilDoneWithError:(NSError**)error; +- (BOOL)waitUntilDoneWithError:(NSError **)error; /// Waits for the graph to become idle. 
-- (BOOL)waitUntilIdleWithError:(NSError**)error; +- (BOOL)waitUntilIdleWithError:(NSError **)error; @end diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index cccdb945a..dec76047e 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -327,22 +327,35 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, packetType:(MPPPacketType)packetType timestamp:(const mediapipe::Timestamp&)timestamp allowOverwrite:(BOOL)allowOverwrite { + NSError* error; + bool success = [self sendPixelBuffer:imageBuffer + intoStream:inputName + packetType:packetType + timestamp:timestamp + allowOverwrite:allowOverwrite + error:&error]; + if (error) { + _GTMDevLog(@"failed to send packet: %@", error); + } + return success; +} + +- (BOOL)sendPixelBuffer:(CVPixelBufferRef)imageBuffer + intoStream:(const std::string&)inputName + packetType:(MPPPacketType)packetType + timestamp:(const mediapipe::Timestamp&)timestamp + allowOverwrite:(BOOL)allowOverwrite + error:(NSError**)error { if (_maxFramesInFlight && _framesInFlight >= _maxFramesInFlight) return NO; mediapipe::Packet packet = [self packetWithPixelBuffer:imageBuffer packetType:packetType]; - NSError* error; BOOL success; if (allowOverwrite) { packet = std::move(packet).At(timestamp); - success = [self movePacket:std::move(packet) - intoStream:inputName - error:&error]; + success = [self movePacket:std::move(packet) intoStream:inputName error:error]; } else { - success = [self sendPacket:packet.At(timestamp) - intoStream:inputName - error:&error]; + success = [self sendPacket:packet.At(timestamp) intoStream:inputName error:error]; } if (success) _framesInFlight++; - else _GTMDevLog(@"failed to send packet: %@", error); return success; } diff --git a/mediapipe/util/sequence/README.md b/mediapipe/util/sequence/README.md index eda4c702a..4af003092 100644 --- a/mediapipe/util/sequence/README.md +++ b/mediapipe/util/sequence/README.md @@ -423,6 +423,10 @@ tasks and tracking (or class) fields for tracking information. |`region/point/y`|feature list float list|`add_bbox_point_y` / `AddBBoxPointY`|A list of normalized y values for points in a frame.| |`region/point/\*`| *special* |`add_bbox_point` / `AddBBoxPoint`|Operates on point/x,point/y with a single call.| |`region/point/radius`|feature list float list|`add_bbox_point_radius` / `AddBBoxPointRadius`|A list of radii for points in a frame.| +|`region/3d_point/x`|feature list float list|`add_bbox_3d_point_x` / `AddBBox3dPointX`|A list of normalized x values for points in a frame.| +|`region/3d_point/y`|feature list float list|`add_bbox_3d_point_y` / `AddBBox3dPointY`|A list of normalized y values for points in a frame.| +|`region/3d_point/z`|feature list float list|`add_bbox_3d_point_z` / `AddBBox3dPointZ`|A list of normalized z values for points in a frame.| +|`region/3d_point/\*`| *special* |`add_bbox_3d_point` / `AddBBox3dPoint`|Operates on 3d_point/{x,y,z} with a single call.| |`region/timestamp`|feature list int|`add_bbox_timestamp` / `AddBBoxTimestamp`|The timestamp in microseconds for the region annotations.| |`region/num_regions`|feature list int|`add_bbox_num_regions` / `AddBBoxNumRegions`|The number of boxes or other regions in a frame. Should be 0 for unannotated frames.| |`region/is_annotated`|feature list int|`add_bbox_is_annotated` / `AddBBoxIsAnnotated`|1 if this timestep is annotated. 0 otherwise. 
Distinguishes empty from unannotated frames.|
diff --git a/mediapipe/util/sequence/media_sequence.cc b/mediapipe/util/sequence/media_sequence.cc
index 9fa78723c..88f73771e 100644
--- a/mediapipe/util/sequence/media_sequence.cc
+++ b/mediapipe/util/sequence/media_sequence.cc
@@ -229,6 +229,18 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) {
           sequence);
     }
   }
+  if (Get3dPointSize(prefix, *sequence) > 0) {
+    std::string x_key = merge_prefix(prefix, kRegion3dPointXKey);
+    auto* region_feature_list = MutableFeatureList(x_key, sequence);
+    RET_CHECK_EQ(num_bboxes, region_feature_list->feature_size())
+        << "Expected number of BBox timestamps and boxes to match.";
+    ClearBBoxNumRegions(prefix, sequence);
+    for (int i = 0; i < num_bboxes; ++i) {
+      AddBBoxNumRegions(
+          prefix, region_feature_list->feature(i).float_list().value_size(),
+          sequence);
+    }
+  }
   // Collect which timestamps currently match to which indices in timestamps.
   // skip empty timestamps.
   // Requires sorted indices.
@@ -453,6 +465,47 @@ void ClearPoint(const std::string& prefix,
   ClearBBoxPointX(prefix, sequence);
 }
 
+int Get3dPointSize(const std::string& prefix,
+                   const tensorflow::SequenceExample& sequence) {
+  return GetBBox3dPointXSize(prefix, sequence);
+}
+
+std::vector<::std::tuple<float, float, float>> Get3dPointAt(
+    const std::string& prefix, const tensorflow::SequenceExample& sequence,
+    int index) {
+  const auto& xs = GetBBox3dPointXAt(prefix, sequence, index);
+  const auto& ys = GetBBox3dPointYAt(prefix, sequence, index);
+  const auto& zs = GetBBox3dPointZAt(prefix, sequence, index);
+  std::vector<::std::tuple<float, float, float>> points(ys.size());
+  for (int i = 0; i < xs.size(); ++i) {
+    points[i] = std::make_tuple(xs[i], ys[i], zs[i]);
+  }
+  return points;
+}
+
+void Add3dPoint(const std::string& prefix,
+                const std::vector<::std::tuple<float, float, float>>& points,
+                tensorflow::SequenceExample* sequence) {
+  ::std::vector<float> xs;
+  ::std::vector<float> ys;
+  ::std::vector<float> zs;
+  for (auto& point : points) {
+    xs.push_back(std::get<0>(point));
+    ys.push_back(std::get<1>(point));
+    zs.push_back(std::get<2>(point));
+  }
+  AddBBox3dPointX(prefix, xs, sequence);
+  AddBBox3dPointY(prefix, ys, sequence);
+  AddBBox3dPointZ(prefix, zs, sequence);
+}
+
+void Clear3dPoint(const std::string& prefix,
+                  tensorflow::SequenceExample* sequence) {
+  ClearBBox3dPointX(prefix, sequence);
+  ClearBBox3dPointY(prefix, sequence);
+  ClearBBox3dPointZ(prefix, sequence);
+}
+
 std::unique_ptr<mediapipe::Matrix> GetAudioFromFeatureAt(
     const std::string& prefix, const tensorflow::SequenceExample& sequence,
     int index) {
diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h
index 51934a509..81e18656e 100644
--- a/mediapipe/util/sequence/media_sequence.h
+++ b/mediapipe/util/sequence/media_sequence.h
@@ -268,6 +268,10 @@ const char kRegionBBoxXMaxKey[] = "region/bbox/xmax";
 const char kRegionPointXKey[] = "region/point/x";
 const char kRegionPointYKey[] = "region/point/y";
 const char kRegionRadiusKey[] = "region/radius";
+// The 3d point can denote keypoints.
+const char kRegion3dPointXKey[] = "region/3d_point/x";
+const char kRegion3dPointYKey[] = "region/3d_point/y";
+const char kRegion3dPointZKey[] = "region/3d_point/z";
 // The number of regions at that timestep.
 const char kRegionNumRegionsKey[] = "region/num_regions";
 // Whether that timestep is annotated for bounding regions.
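
The accessors built on these keys store one list of (x, y, z) tuples per timestep. A minimal sketch of the intended round trip, modeled on the unit test added to media_sequence_test.cc further below; the namespace and includes are assumed to match the existing media_sequence headers:

#include <cassert>
#include <tuple>
#include <vector>

#include "mediapipe/util/sequence/media_sequence.h"  // assumed include path

void Store3dPointsExample() {
  tensorflow::SequenceExample sequence;
  // Two 3-D keypoints for one frame; each Add call appends one timestep.
  std::vector<std::tuple<float, float, float>> points = {
      std::make_tuple(0.3f, 0.5f, 0.1f), std::make_tuple(0.4f, 0.7f, 0.2f)};
  mediapipe::mediasequence::AddBBox3dPoint(points, &sequence);

  // One timestep is now stored; read the tuples back.
  assert(mediapipe::mediasequence::GetBBox3dPointSize(sequence) == 1);
  auto restored = mediapipe::mediasequence::GetBBox3dPointAt(sequence, 0);
  assert(restored.size() == 2);
}
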
@@ -333,61 +337,111 @@ void AddPoint(const std::string& prefix, void ClearPoint(const std::string& prefix, tensorflow::SequenceExample* sequence); -#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \ - inline int CONCAT_STR3(Get, identifier, \ - Size)(const tensorflow::SequenceExample& sequence) { \ - return GetBBoxSize(prefix, sequence); \ - } \ - inline std::vector<::mediapipe::Location> CONCAT_STR3(Get, identifier, At)( \ - const tensorflow::SequenceExample& sequence, int index) { \ - return GetBBoxAt(prefix, sequence, index); \ - } \ - inline void CONCAT_STR2(Add, identifier)( \ - const std::vector<::mediapipe::Location>& bboxes, \ - tensorflow::SequenceExample* sequence) { \ - return AddBBox(prefix, bboxes, sequence); \ - } \ - inline void CONCAT_STR2( \ - Clear, identifier)(tensorflow::SequenceExample * sequence) { \ - return ClearBBox(prefix, sequence); \ - } \ - inline int CONCAT_STR3(Get, identifier, PointSize)( \ - const tensorflow::SequenceExample& sequence) { \ - return GetPointSize(prefix, sequence); \ - } \ - inline int CONCAT_STR3(Get, identifier, PointSize)( \ - const std::string& name, const tensorflow::SequenceExample& sequence) { \ - return GetPointSize(name, sequence); \ - } \ - inline std::vector> CONCAT_STR3( \ - Get, identifier, PointAt)(const tensorflow::SequenceExample& sequence, \ - int index) { \ - return GetPointAt(prefix, sequence, index); \ - } \ - inline std::vector> CONCAT_STR3( \ - Get, identifier, PointAt)(const std::string& name, \ - const tensorflow::SequenceExample& sequence, \ - int index) { \ - return GetPointAt(name, sequence, index); \ - } \ - inline void CONCAT_STR3(Add, identifier, Point)( \ - const std::vector>& points, \ - tensorflow::SequenceExample* sequence) { \ - return AddPoint(prefix, points, sequence); \ - } \ - inline void CONCAT_STR3(Add, identifier, Point)( \ - const std::string& name, \ - const std::vector>& points, \ - tensorflow::SequenceExample* sequence) { \ - return AddPoint(name, points, sequence); \ - } \ - inline void CONCAT_STR3(Clear, identifier, \ - Point)(tensorflow::SequenceExample * sequence) { \ - return ClearPoint(prefix, sequence); \ - } \ - inline void CONCAT_STR3(Clear, identifier, Point)( \ - std::string name, tensorflow::SequenceExample * sequence) { \ - return ClearPoint(name, sequence); \ +// The input and output format is a pair of coordinates to match the +// order of bounding box coordinates. 
+int Get3dPointSize(const std::string& prefix, + const tensorflow::SequenceExample& sequence); +std::vector> Get3dPointAt( + const std::string& prefix, const tensorflow::SequenceExample& sequence, + int index); +void Add3dPoint(const std::string& prefix, + const std::vector>& points, + tensorflow::SequenceExample* sequence); +void Clear3dPoint(const std::string& prefix, + tensorflow::SequenceExample* sequence); +#define FIXED_PREFIX_BBOX_ACCESSORS(identifier, prefix) \ + inline int CONCAT_STR3(Get, identifier, \ + Size)(const tensorflow::SequenceExample& sequence) { \ + return GetBBoxSize(prefix, sequence); \ + } \ + inline std::vector<::mediapipe::Location> CONCAT_STR3(Get, identifier, At)( \ + const tensorflow::SequenceExample& sequence, int index) { \ + return GetBBoxAt(prefix, sequence, index); \ + } \ + inline void CONCAT_STR2(Add, identifier)( \ + const std::vector<::mediapipe::Location>& bboxes, \ + tensorflow::SequenceExample* sequence) { \ + return AddBBox(prefix, bboxes, sequence); \ + } \ + inline void CONCAT_STR2( \ + Clear, identifier)(tensorflow::SequenceExample * sequence) { \ + return ClearBBox(prefix, sequence); \ + } \ + inline int CONCAT_STR3(Get, identifier, PointSize)( \ + const tensorflow::SequenceExample& sequence) { \ + return GetPointSize(prefix, sequence); \ + } \ + inline int CONCAT_STR3(Get, identifier, PointSize)( \ + const std::string& name, const tensorflow::SequenceExample& sequence) { \ + return GetPointSize(name, sequence); \ + } \ + inline std::vector> CONCAT_STR3( \ + Get, identifier, PointAt)(const tensorflow::SequenceExample& sequence, \ + int index) { \ + return GetPointAt(prefix, sequence, index); \ + } \ + inline std::vector> CONCAT_STR3( \ + Get, identifier, PointAt)(const std::string& name, \ + const tensorflow::SequenceExample& sequence, \ + int index) { \ + return GetPointAt(name, sequence, index); \ + } \ + inline void CONCAT_STR3(Add, identifier, Point)( \ + const std::vector>& points, \ + tensorflow::SequenceExample* sequence) { \ + return AddPoint(prefix, points, sequence); \ + } \ + inline void CONCAT_STR3(Add, identifier, Point)( \ + const std::string& name, \ + const std::vector>& points, \ + tensorflow::SequenceExample* sequence) { \ + return AddPoint(name, points, sequence); \ + } \ + inline void CONCAT_STR3(Clear, identifier, \ + Point)(tensorflow::SequenceExample * sequence) { \ + return ClearPoint(prefix, sequence); \ + } \ + inline void CONCAT_STR3(Clear, identifier, Point)( \ + std::string name, tensorflow::SequenceExample * sequence) { \ + return ClearPoint(name, sequence); \ + } \ + inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \ + const tensorflow::SequenceExample& sequence) { \ + return Get3dPointSize(prefix, sequence); \ + } \ + inline int CONCAT_STR3(Get, identifier, 3dPointSize)( \ + const std::string& name, const tensorflow::SequenceExample& sequence) { \ + return Get3dPointSize(name, sequence); \ + } \ + inline std::vector> CONCAT_STR3( \ + Get, identifier, 3dPointAt)(const tensorflow::SequenceExample& sequence, \ + int index) { \ + return Get3dPointAt(prefix, sequence, index); \ + } \ + inline std::vector> CONCAT_STR3( \ + Get, identifier, 3dPointAt)(const std::string& name, \ + const tensorflow::SequenceExample& sequence, \ + int index) { \ + return Get3dPointAt(name, sequence, index); \ + } \ + inline void CONCAT_STR3(Add, identifier, 3dPoint)( \ + const std::vector>& points, \ + tensorflow::SequenceExample* sequence) { \ + return Add3dPoint(prefix, points, sequence); \ + } \ + inline void CONCAT_STR3(Add, 
identifier, 3dPoint)( \ + const std::string& name, \ + const std::vector>& points, \ + tensorflow::SequenceExample* sequence) { \ + return Add3dPoint(name, points, sequence); \ + } \ + inline void CONCAT_STR3(Clear, identifier, \ + 3dPoint)(tensorflow::SequenceExample * sequence) { \ + return Clear3dPoint(prefix, sequence); \ + } \ + inline void CONCAT_STR3(Clear, identifier, 3dPoint)( \ + std::string name, tensorflow::SequenceExample * sequence) { \ + return Clear3dPoint(name, sequence); \ } #define PREFIXED_BBOX(identifier, prefix) \ @@ -435,6 +489,12 @@ void ClearPoint(const std::string& prefix, kRegionPointYKey, prefix) \ FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, Radius), \ kRegionRadiusKey, prefix) \ + FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointX), \ + kRegion3dPointXKey, prefix) \ + FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointY), \ + kRegion3dPointYKey, prefix) \ + FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST(CONCAT_STR2(identifier, 3dPointZ), \ + kRegion3dPointZKey, prefix) \ FIXED_PREFIX_VECTOR_FLOAT_FEATURE_LIST( \ CONCAT_STR2(identifier, EmbeddingFloats), kRegionEmbeddingFloatKey, \ prefix) \ diff --git a/mediapipe/util/sequence/media_sequence.py b/mediapipe/util/sequence/media_sequence.py index 75a0dcff7..07289576f 100644 --- a/mediapipe/util/sequence/media_sequence.py +++ b/mediapipe/util/sequence/media_sequence.py @@ -262,6 +262,10 @@ REGION_BBOX_XMAX_KEY = "region/bbox/xmax" REGION_POINT_X_KEY = "region/point/x" REGION_POINT_Y_KEY = "region/point/y" REGION_RADIUS_KEY = "region/radius" +# The 3D point can denote keypoints. +REGION_3D_POINT_X_KEY = "region/3d_point/x" +REGION_3D_POINT_Y_KEY = "region/3d_point/y" +REGION_3D_POINT_Z_KEY = "region/3d_point/z" # The number of regions at that timestep. REGION_NUM_REGIONS_KEY = "region/num_regions" # Whether that timestep is annotated for regions. 
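
The Python module mirrors these keys with generated add/get/has/clear helpers. A short sketch of the round trip, modeled on the test_3d_point_round_trip test added to media_sequence_test.py further below (import path assumed):

import numpy as np
import tensorflow as tf

from mediapipe.util.sequence import media_sequence as ms

example = tf.train.SequenceExample()
# One timestep holding two 3-D points, one (x, y, z) row per point.
points = np.array([[0.1, 0.2, 0.3],
                   [0.5, 0.6, 0.7]])
ms.add_bbox_3d_point(points, example)
assert ms.get_bbox_3d_point_size(example) == 1
restored = ms.get_bbox_3d_point_at(0, example)  # array of shape (2, 3)
ms.clear_bbox_3d_point(example)
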
@@ -365,6 +369,15 @@ def _create_region_with_prefix(name, prefix): prefix=prefix, module_dict=globals()) msu.create_float_list_feature_list(name + "_point_y", REGION_POINT_Y_KEY, prefix=prefix, module_dict=globals()) + msu.create_float_list_feature_list( + name + "_3d_point_x", REGION_3D_POINT_X_KEY, + prefix=prefix, module_dict=globals()) + msu.create_float_list_feature_list( + name + "_3d_point_y", REGION_3D_POINT_Y_KEY, + prefix=prefix, module_dict=globals()) + msu.create_float_list_feature_list( + name + "_3d_point_z", REGION_3D_POINT_Z_KEY, + prefix=prefix, module_dict=globals()) msu.create_bytes_list_context_feature(name + "_parts", REGION_PARTS_KEY, prefix=prefix, module_dict=globals()) @@ -406,6 +419,39 @@ def _create_region_with_prefix(name, prefix): clear_bbox_xmin(sequence_example, prefix=prefix) clear_bbox_ymax(sequence_example, prefix=prefix) clear_bbox_xmax(sequence_example, prefix=prefix) + def get_prefixed_point_at(index, sequence_example, prefix): + return np.stack(( + get_bbox_point_y_at(index, sequence_example, prefix=prefix), + get_bbox_point_x_at(index, sequence_example, prefix=prefix)), + 1) + def add_prefixed_point(values, sequence_example, prefix): + add_bbox_point_y(values[:, 0], sequence_example, prefix=prefix) + add_bbox_point_x(values[:, 1], sequence_example, prefix=prefix) + def get_prefixed_point_size(sequence_example, prefix): + return get_bbox_point_y_size(sequence_example, prefix=prefix) + def has_prefixed_point(sequence_example, prefix): + return has_bbox_point_y(sequence_example, prefix=prefix) + def clear_prefixed_point(sequence_example, prefix): + clear_bbox_point_y(sequence_example, prefix=prefix) + clear_bbox_point_x(sequence_example, prefix=prefix) + def get_prefixed_3d_point_at(index, sequence_example, prefix): + return np.stack(( + get_bbox_3d_point_x_at(index, sequence_example, prefix=prefix), + get_bbox_3d_point_y_at(index, sequence_example, prefix=prefix), + get_bbox_3d_point_z_at(index, sequence_example, prefix=prefix)), + 1) + def add_prefixed_3d_point(values, sequence_example, prefix): + add_bbox_3d_point_x(values[:, 0], sequence_example, prefix=prefix) + add_bbox_3d_point_y(values[:, 1], sequence_example, prefix=prefix) + add_bbox_3d_point_z(values[:, 2], sequence_example, prefix=prefix) + def get_prefixed_3d_point_size(sequence_example, prefix): + return get_bbox_3d_point_x_size(sequence_example, prefix=prefix) + def has_prefixed_3d_point(sequence_example, prefix): + return has_bbox_3d_point_x(sequence_example, prefix=prefix) + def clear_prefixed_3d_point(sequence_example, prefix): + clear_bbox_3d_point_x(sequence_example, prefix=prefix) + clear_bbox_3d_point_y(sequence_example, prefix=prefix) + clear_bbox_3d_point_z(sequence_example, prefix=prefix) # pylint: enable=undefined-variable msu.add_functions_to_module({ "get_" + name + "_at": @@ -419,6 +465,30 @@ def _create_region_with_prefix(name, prefix): "clear_" + name: functools.partial(clear_prefixed_bbox, prefix=prefix), }, module_dict=globals()) + msu.add_functions_to_module({ + "get_" + name + "_point_at": + functools.partial(get_prefixed_point_at, prefix=prefix), + "add_" + name + "_point": + functools.partial(add_prefixed_point, prefix=prefix), + "get_" + name + "_point_size": + functools.partial(get_prefixed_point_size, prefix=prefix), + "has_" + name + "_point": + functools.partial(has_prefixed_point, prefix=prefix), + "clear_" + name + "_point": + functools.partial(clear_prefixed_point, prefix=prefix), + }, module_dict=globals()) + msu.add_functions_to_module({ + "get_" + name + 
"_3d_point_at": + functools.partial(get_prefixed_3d_point_at, prefix=prefix), + "add_" + name + "_3d_point": + functools.partial(add_prefixed_3d_point, prefix=prefix), + "get_" + name + "_3d_point_size": + functools.partial(get_prefixed_3d_point_size, prefix=prefix), + "has_" + name + "_3d_point": + functools.partial(has_prefixed_3d_point, prefix=prefix), + "clear_" + name + "_3d_point": + functools.partial(clear_prefixed_3d_point, prefix=prefix), + }, module_dict=globals()) PREDICTED_PREFIX = "PREDICTED" diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index 0b860d8ea..d22128747 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -436,6 +436,21 @@ TEST(MediaSequenceTest, RoundTripBBoxPointPrefixed) { } } +TEST(MediaSequenceTest, RoundTripBBox3dPoint) { + tensorflow::SequenceExample sequence; + std::vector>> points = { + {std::make_tuple(0.3, 0.5, 0.1), std::make_tuple(0.4, 0.7, 0.2)}, + {std::make_tuple(0.7, 0.5, 0.3), std::make_tuple(0.3, 0.4, 0.4)}}; + for (int i = 0; i < points.size(); ++i) { + AddBBox3dPoint(points[i], &sequence); + ASSERT_EQ(GetBBox3dPointSize(sequence), i + 1); + const auto& sequence_points = GetBBox3dPointAt(sequence, i); + for (int j = 0; j < sequence_points.size(); ++j) { + EXPECT_EQ(sequence_points[j], points[i][j]); + } + } +} + TEST(MediaSequenceTest, RoundTripRegionParts) { tensorflow::SequenceExample sequence; std::vector parts = {"HEAD", "FEET"}; diff --git a/mediapipe/util/sequence/media_sequence_test.py b/mediapipe/util/sequence/media_sequence_test.py index bae56b600..2f30c554e 100644 --- a/mediapipe/util/sequence/media_sequence_test.py +++ b/mediapipe/util/sequence/media_sequence_test.py @@ -89,6 +89,9 @@ class MediaSequenceTest(tf.test.TestCase): ms.add_bbox_xmax((0.47, 0.49), example) ms.add_bbox_point_x((0.47, 0.49), example) ms.add_bbox_point_y((0.47, 0.49), example) + ms.add_bbox_3d_point_x((0.47, 0.49), example) + ms.add_bbox_3d_point_y((0.47, 0.49), example) + ms.add_bbox_3d_point_z((0.47, 0.49), example) ms.add_predicted_bbox_ymin((0.47, 0.49), example) ms.add_predicted_bbox_xmin((0.47, 0.49), example) ms.add_predicted_bbox_ymax((0.47, 0.49), example) @@ -133,6 +136,30 @@ class MediaSequenceTest(tf.test.TestCase): ms.clear_bbox(example) self.assertEqual(0, ms.get_bbox_size(example)) + def test_point_round_trip(self): + example = tf.train.SequenceExample() + points = np.array([[0.1, 0.2], + [0.5, 0.6]]) + ms.add_bbox_point(points, example) + ms.add_bbox_point(points, example) + self.assertEqual(2, ms.get_bbox_point_size(example)) + self.assertAllClose(points, ms.get_bbox_point_at(0, example)) + self.assertTrue(ms.has_bbox_point(example)) + ms.clear_bbox_point(example) + self.assertEqual(0, ms.get_bbox_point_size(example)) + + def test_3d_point_round_trip(self): + example = tf.train.SequenceExample() + points = np.array([[0.1, 0.2, 0.3], + [0.5, 0.6, 0.7]]) + ms.add_bbox_3d_point(points, example) + ms.add_bbox_3d_point(points, example) + self.assertEqual(2, ms.get_bbox_3d_point_size(example)) + self.assertAllClose(points, ms.get_bbox_3d_point_at(0, example)) + self.assertTrue(ms.has_bbox_3d_point(example)) + ms.clear_bbox_3d_point(example) + self.assertEqual(0, ms.get_bbox_3d_point_size(example)) + def test_predicted_bbox_round_trip(self): example = tf.train.SequenceExample() boxes = np.array([[0.1, 0.2, 0.3, 0.4], diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index f5ab6c1f1..9d8a6f3db 100644 --- 
a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -19,6 +19,14 @@ package(default_visibility = [ "//mediapipe:__subpackages__", ]) +cc_library( + name = "config", + hdrs = ["config.h"], + deps = [ + "//mediapipe/framework:calculator_framework", + ], +) + cc_library( name = "cpu_op_resolver", srcs = ["cpu_op_resolver.cc"], @@ -69,6 +77,7 @@ cc_test( srcs = ["tensor_buffer_test.cc"], deps = [ ":tensor_buffer", + ":config", "//mediapipe/framework/port:gtest_main", ] + select({ "//mediapipe/gpu:disable_gpu": [], @@ -99,6 +108,7 @@ cc_library( "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/delegates/gpu:api", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", ], "//mediapipe:android": [ @@ -108,7 +118,9 @@ cc_library( "//mediapipe/framework/port:statusor", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/delegates/gpu:api", + "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", ], }) + ["@org_tensorflow//tensorflow/lite/core/api"], diff --git a/mediapipe/util/tflite/config.h b/mediapipe/util/tflite/config.h new file mode 100644 index 000000000..dbc499691 --- /dev/null +++ b/mediapipe/util/tflite/config.h @@ -0,0 +1,59 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_TFLITE_CONFIG_H_ +#define MEDIAPIPE_UTIL_TFLITE_CONFIG_H_ + +#include "mediapipe/framework/calculator_framework.h" + +// MediaPipe code should use the following defines to determine whether TFLite +// GPU support is available, and whether GL or Metal inference is available. 
+
+#ifdef MEDIAPIPE_DISABLE_GL_COMPUTE
+#define MEDIAPIPE_TFLITE_GL_INFERENCE 0
+#else
+#define MEDIAPIPE_TFLITE_GL_INFERENCE 1
+#endif  // MEDIAPIPE_DISABLE_GL_COMPUTE
+
+#ifdef MEDIAPIPE_IOS
+#define MEDIAPIPE_TFLITE_METAL_INFERENCE 1
+#else
+#define MEDIAPIPE_TFLITE_METAL_INFERENCE 0
+#endif  // MEDIAPIPE_IOS
+
+#define MEDIAPIPE_TFLITE_GPU_SUPPORTED \
+  ((MEDIAPIPE_TFLITE_GL_INFERENCE) || (MEDIAPIPE_TFLITE_METAL_INFERENCE))
+
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
+
+#if MEDIAPIPE_TFLITE_METAL_INFERENCE
+#import <Metal/Metal.h>
+#endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE
+
+namespace mediapipe {
+
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
+typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
+#elif MEDIAPIPE_TFLITE_METAL_INFERENCE
+typedef id<MTLBuffer> GpuTensor;
+#else
+struct DummyGpuTensor {};
+typedef DummyGpuTensor GpuTensor;  // Dummy define for less #ifdefs
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_UTIL_TFLITE_CONFIG_H_
diff --git a/mediapipe/util/tflite/operations/max_pool_argmax.cc b/mediapipe/util/tflite/operations/max_pool_argmax.cc
index 478322ca5..e87c8dd96 100644
--- a/mediapipe/util/tflite/operations/max_pool_argmax.cc
+++ b/mediapipe/util/tflite/operations/max_pool_argmax.cc
@@ -130,11 +130,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto padding = params->padding;
   auto compute_out_size = [padding](int image_size, int filter_size,
                                     int stride) -> int {
-    return padding == kTfLitePaddingSame
-               ? (image_size + stride - 1) / stride
-               : padding == kTfLitePaddingValid
-                     ? (image_size - filter_size + stride) / stride
-                     : 0;
+    return padding == kTfLitePaddingSame ? (image_size + stride - 1) / stride
+           : padding == kTfLitePaddingValid
+               ? (image_size - filter_size + stride) / stride
+               : 0;
   };
 
   int out_width =
diff --git a/mediapipe/util/tflite/tensor_buffer_test.cc b/mediapipe/util/tflite/tensor_buffer_test.cc
index 197a60f79..42ba583d0 100644
--- a/mediapipe/util/tflite/tensor_buffer_test.cc
+++ b/mediapipe/util/tflite/tensor_buffer_test.cc
@@ -2,6 +2,7 @@
 
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/util/tflite/config.h"
 
 namespace mediapipe {
 
@@ -12,7 +13,7 @@ TEST(Cpu, BasicTest) {
   EXPECT_FALSE(tb.UsesGpu());
 }
 
-#if !defined(MEDIAPIPE_DISABLE_GPU)
+#if MEDIAPIPE_TFLITE_GL_INFERENCE
 TEST(Gpu, BasicTest) {
   TensorBuffer tb;
   std::shared_ptr<tflite::gpu::gl::GlBuffer> tfg_tb =
@@ -20,7 +21,7 @@ TEST(Gpu, BasicTest) {
   tb = TensorBuffer(tfg_tb);
   EXPECT_TRUE(tb.UsesGpu());
 }
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE
 
 }  // namespace mediapipe
diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc
index 510b291e2..f0624c76a 100644
--- a/mediapipe/util/tflite/tflite_gpu_runner.cc
+++ b/mediapipe/util/tflite/tflite_gpu_runner.cc
@@ -30,6 +30,13 @@
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/model.h"
 
+// This code should be enabled as soon as the TensorFlow version used by
+// MediaPipe includes this module.
+#ifdef __ANDROID__
+#include "tensorflow/lite/delegates/gpu/cl/api.h"
+#endif
+#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
+
 namespace tflite {
 namespace gpu {
 namespace {
@@ -51,6 +58,19 @@ ObjectDef GetSSBOObjectDef(int channels) {
 
 mediapipe::Status TFLiteGPURunner::InitializeWithModel(
     const tflite::FlatBufferModel& flatbuffer,
     const tflite::OpResolver& op_resolver) {
+  // GraphFloat32 is created twice because, when the OpenCL and OpenGL backends
+  // are initialized, different backend-specific graph transformations happen
+  // in-place. As GraphFloat32 is not copyable by design, we keep two copies of
+  // the graph until inference is built. This decision doesn't affect the
+  // amount of runtime memory used, because both graph_gl_ and graph_cl_ are
+  // deleted at the end of the initialization stage.
+  graph_gl_ = std::make_unique<GraphFloat32>();
+  graph_cl_ = std::make_unique<GraphFloat32>();
+  MP_RETURN_IF_ERROR(
+      BuildFromFlatBuffer(flatbuffer, op_resolver, graph_gl_.get()));
+  MP_RETURN_IF_ERROR(
+      BuildFromFlatBuffer(flatbuffer, op_resolver, graph_cl_.get()));
+
   for (const auto& input : graph_gl_->inputs()) {
     input_shapes_.push_back(input->tensor.shape);
   }
@@ -140,6 +160,19 @@ mediapipe::Status TFLiteGPURunner::InitializeOpenGL(
 
 absl::Status TFLiteGPURunner::InitializeOpenCL(
     std::unique_ptr<InferenceBuilder>* builder) {
+#ifdef __ANDROID__
+  cl::InferenceEnvironmentOptions env_options;
+  cl::InferenceEnvironmentProperties properties;
+  cl::InferenceOptions cl_options;
+  cl_options.priority1 = options_.priority1;
+  cl_options.priority2 = options_.priority2;
+  cl_options.priority3 = options_.priority3;
+  cl_options.usage = options_.usage;
+  MP_RETURN_IF_ERROR(
+      cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
+  MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
+      cl_options, std::move(*graph_cl_), builder));
+#endif
   return absl::OkStatus();
 }
diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h
index fef2a818f..c842c9dd6 100644
--- a/mediapipe/util/tflite/tflite_gpu_runner.h
+++ b/mediapipe/util/tflite/tflite_gpu_runner.h
@@ -27,6 +27,10 @@
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/model.h"
 
+#ifdef __ANDROID__
+#include "tensorflow/lite/delegates/gpu/cl/api.h"
+#endif
+
 namespace tflite {
 namespace gpu {
 
@@ -64,6 +68,9 @@ class TFLiteGPURunner {
   mediapipe::Status Build();
   mediapipe::Status Invoke();
 
+  std::vector<BHWC> GetInputShapes() { return input_shapes_; }
+  std::vector<BHWC> GetOutputShapes() { return output_shapes_; }
+
  private:
   mediapipe::Status InitializeOpenGL(
       std::unique_ptr<InferenceBuilder>* builder);
@@ -73,6 +80,10 @@ class TFLiteGPURunner {
   InferenceOptions options_;
   std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
 
+#ifdef __ANDROID__
+  std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
+#endif
+
   // graph_ is maintained temporarily and becomes invalid after runner_ is ready
   std::unique_ptr<GraphFloat32> graph_gl_;
   std::unique_ptr<GraphFloat32> graph_cl_;
diff --git a/mediapipe/util/tracking/parallel_invoker.h b/mediapipe/util/tracking/parallel_invoker.h
index 823522310..cc2f6600c 100644
--- a/mediapipe/util/tracking/parallel_invoker.h
+++ b/mediapipe/util/tracking/parallel_invoker.h
@@ -236,7 +236,7 @@ inline void CheckAndSetInvokerOptions() {
     LOG(WARNING) << "Unsupported invoker mode selected on Android. "
                  << "OpenMP linkage detected, so falling back to OpenMP";
     flags_parallel_invoker_mode = PARALLEL_INVOKER_OPENMP;
-#else // _OPENMP
+#else  // _OPENMP
     // Fallback mode for active parallel invoker without OpenMP is ThreadPool.
     LOG(WARNING) << "Unsupported invoker mode selected on Android. "
                  << "Falling back to ThreadPool";
@@ -273,7 +273,7 @@ inline void CheckAndSetInvokerOptions() {
 #endif  // _OPENMP
   }
 
-#else // PARALLEL_INVOKER_ACTIVE
+#else  // PARALLEL_INVOKER_ACTIVE
   if (flags_parallel_invoker_mode != PARALLEL_INVOKER_NONE) {
     LOG(ERROR) << "Parallel execution requested but PARALLEL_INVOKER_ACTIVE "
                << "compile flag is not set. Falling back to single threaded "
diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc
index 403440f94..9fc7ae81d 100644
--- a/mediapipe/util/tracking/region_flow_computation.cc
+++ b/mediapipe/util/tracking/region_flow_computation.cc
@@ -2082,8 +2082,8 @@ void RegionFlowComputation::WideBaselineMatchFeatures(
     !defined(CV_WRAPPER_3X)
   LOG(FATAL) << "Supported on only with OpenCV 3.0. "
              << "Use bazel build flag : --define CV_WRAPPER=3X";
-#else // (defined(__ANDROID__) || defined(__APPLE__) ||
-      // defined(__EMSCRIPTEN__)) && !defined(CV_WRAPPER_3X)
+#else  // (defined(__ANDROID__) || defined(__APPLE__) ||
+       // defined(__EMSCRIPTEN__)) && !defined(CV_WRAPPER_3X)
 
   results->clear();
   const auto& frame1 = from_data_ptr->frame;
diff --git a/third_party/com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff b/third_party/com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff
new file mode 100644
index 000000000..496126e9c
--- /dev/null
+++ b/third_party/com_google_googletest_9d580ea80592189e6d44fa35bcf9cdea8bf620d6.diff
@@ -0,0 +1,50 @@
+diff --git a/googletest/include/gtest/internal/gtest-internal.h b/googletest/include/gtest/internal/gtest-internal.h
+index 7f1a5b00e..c36029ee1 100644
+--- a/googletest/include/gtest/internal/gtest-internal.h
++++ b/googletest/include/gtest/internal/gtest-internal.h
+@@ -94,6 +94,12 @@ namespace proto2 {
+ class MessageLite;
+ }
+ 
++namespace google {
++namespace protobuf {
++class MessageLite;
++}
++}
++
+ namespace testing {
+ 
+ // Forward declarations.
+@@ -881,10 +887,15 @@ class GTEST_API_ Random {
+   typename std::remove_const<typename std::remove_reference<T>::type>::type
+ 
+ // IsAProtocolMessage<T>::value is a compile-time bool constant that's
+-// true if and only if T is type proto2::MessageLite or a subclass of it.
++// true if and only if T is type proto2::MessageLite or
++// google::protobuf::MessageLite or a subclass of one of them.
+ template <typename T>
+ struct IsAProtocolMessage
+-    : public std::is_convertible<const T*, const ::proto2::MessageLite*> {};
++    : public std::integral_constant<
++          bool,
++          std::is_convertible<const T*, const ::proto2::MessageLite*>::value ||
++              std::is_convertible<
++                  const T*, const ::google::protobuf::MessageLite*>::value> {};
+ 
+ // When the compiler sees expression IsContainerTest<C>(0), if C is an
+ // STL-style container class, the first overload of IsContainerTest
+diff --git a/googletest/test/gtest_unittest.cc b/googletest/test/gtest_unittest.cc
+index 005a2d40d..631180e3d 100644
+--- a/googletest/test/gtest_unittest.cc
++++ b/googletest/test/gtest_unittest.cc
+@@ -7115,6 +7115,10 @@ TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAProtocolMessage) {
+   EXPECT_TRUE(IsAProtocolMessage<::proto2::MessageLite>::value);
+ }
+ 
++TEST(IsAProtocolMessageTest, ValueIsTrueWhenTypeIsAnOpenSourceProtocolMessage) {
++  EXPECT_TRUE(IsAProtocolMessage<::google::protobuf::MessageLite>::value);
++}
++
+ // Tests that IsAProtocolMessage<T>::value is false when T is neither
+ // ::proto2::Message nor a sub-class of it.
+ TEST(IsAProtocolMessageTest, ValueIsFalseWhenTypeIsNotAProtocolMessage) {
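
For reference, the trait change at the heart of this googletest patch reads cleanly in isolation. A self-contained sketch of the same pattern (simplified, not the full gtest header):

#include <type_traits>

// Forward declarations standing in for the two protobuf base classes, as in
// the patch above.
namespace proto2 { class MessageLite; }
namespace google { namespace protobuf { class MessageLite; } }

// True iff T is (or derives from) either MessageLite type. Deriving from
// std::integral_constant keeps the trait usable everywhere the old
// std::is_convertible-based definition was.
template <typename T>
struct IsAProtocolMessage
    : std::integral_constant<
          bool,
          std::is_convertible<const T*, const ::proto2::MessageLite*>::value ||
              std::is_convertible<
                  const T*, const ::google::protobuf::MessageLite*>::value> {};

static_assert(!IsAProtocolMessage<int>::value, "int is not a proto message");
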