diff --git a/.bazelversion b/.bazelversion index fae6e3d04..0062ac971 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -4.2.1 +5.0.0 diff --git a/.github/ISSUE_TEMPLATE/50-other-issues.md b/.github/ISSUE_TEMPLATE/50-other-issues.md index e51add916..c590f3f47 100644 --- a/.github/ISSUE_TEMPLATE/50-other-issues.md +++ b/.github/ISSUE_TEMPLATE/50-other-issues.md @@ -10,5 +10,3 @@ For questions on how to work with MediaPipe, or support for problems that are no If you are reporting a vulnerability, please use the [dedicated reporting process](https://github.com/google/mediapipe/security). -For high-level discussions about MediaPipe, please post to discuss@mediapipe.org, for questions about the development or internal workings of MediaPipe, or if you would like to know how to contribute to MediaPipe, please post to developers@mediapipe.org. - diff --git a/Dockerfile b/Dockerfile index 79f08cc92..9da695ef9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -56,7 +56,7 @@ RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python # Install bazel -ARG BAZEL_VERSION=4.2.1 +ARG BAZEL_VERSION=5.0.0 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ diff --git a/README.md b/README.md index 2d9550d37..9c81095c5 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,8 @@ run code search using ## Community -* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome - MediaPipe related frameworks, libraries and software +* [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe) - A + curated list of awesome MediaPipe related frameworks, libraries and software * [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe diff --git a/WORKSPACE b/WORKSPACE index 
633169032..7bdc114ff 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -61,11 +61,12 @@ http_archive( sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049", ) -# Google Benchmark library. +# Google Benchmark library v1.6.1 released on 2022-01-10. http_archive( name = "com_google_benchmark", - urls = ["https://github.com/google/benchmark/archive/main.zip"], - strip_prefix = "benchmark-main", + urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.6.1.tar.gz"], + strip_prefix = "benchmark-1.6.1", + sha256 = "6132883bc8c9b0df5375b16ab520fac1a85dc9e4cf5be59480448ece74b278d4", build_file = "@//third_party:benchmark.BUILD", ) @@ -373,9 +374,9 @@ http_archive( ) # Tensorflow repo should always go after the other external dependencies. -# 2021-12-02 -_TENSORFLOW_GIT_COMMIT = "18a1dc0ba806dc023808531f0373d9ec068e64bf" -_TENSORFLOW_SHA256 = "85b90416f7a11339327777bccd634de00ca0de2cf334f5f0727edcb11ff9289a" +# 2022-02-15 +_TENSORFLOW_GIT_COMMIT = "a3419acc751dfc19caf4d34a1594e1f76810ec58" +_TENSORFLOW_SHA256 = "b95b2a83632d4055742ae1a2dcc96b45da6c12a339462dbc76c8bca505308e3a" http_archive( name = "org_tensorflow", urls = [ @@ -383,7 +384,6 @@ http_archive( ], patches = [ "@//third_party:org_tensorflow_compatibility_fixes.diff", - "@//third_party:org_tensorflow_objc_cxx17.diff", # Diff is generated with a script, don't update it manually. 
"@//third_party:org_tensorflow_custom_ops.diff", ], diff --git a/build_android_examples.sh b/build_android_examples.sh index 75ec54199..6dbdd6671 100644 --- a/build_android_examples.sh +++ b/build_android_examples.sh @@ -109,7 +109,7 @@ for app in ${apps}; do if [[ ${category} != "shoe" ]]; then bazel_flags_extended+=(--define ${category}=true) fi - bazel "${bazel_flags_extended[@]}" + bazelisk "${bazel_flags_extended[@]}" cp -f "${bin}" "${apk}" fi apks+=(${apk}) @@ -120,7 +120,7 @@ for app in ${apps}; do if [[ ${app_name} == "templatematchingcpu" ]]; then switch_to_opencv_4 fi - bazel "${bazel_flags[@]}" + bazelisk "${bazel_flags[@]}" cp -f "${bin}" "${apk}" if [[ ${app_name} == "templatematchingcpu" ]]; then switch_to_opencv_3 diff --git a/build_desktop_examples.sh b/build_desktop_examples.sh index 7ff8db29c..5bc687fc8 100644 --- a/build_desktop_examples.sh +++ b/build_desktop_examples.sh @@ -83,7 +83,7 @@ for app in ${apps}; do bazel_flags=("${default_bazel_flags[@]}") bazel_flags+=(${target}) - bazel "${bazel_flags[@]}" + bazelisk "${bazel_flags[@]}" cp -f "${bin_dir}/${app}/"*"_cpu" "${out_dir}" fi if [[ $build_only == false ]]; then diff --git a/build_ios_examples.sh b/build_ios_examples.sh index 93b97fc4e..e6a2271b2 100644 --- a/build_ios_examples.sh +++ b/build_ios_examples.sh @@ -71,7 +71,7 @@ for app in ${apps}; do bazel_flags+=(--linkopt=-s) fi - bazel "${bazel_flags[@]}" + bazelisk "${bazel_flags[@]}" cp -f "${bin_dir}/${app}/"*".ipa" "${out_dir}" fi done diff --git a/docs/framework_concepts/synchronization.md b/docs/framework_concepts/synchronization.md index 5482aeb76..be92130b3 100644 --- a/docs/framework_concepts/synchronization.md +++ b/docs/framework_concepts/synchronization.md @@ -169,7 +169,7 @@ behavior depending on resource constraints. 
[`CalculatorBase`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator_base.h [`DefaultInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/default_input_stream_handler.h -[`SyncSetInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/sync_set_input_stream_handler.h -[`ImmediateInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/immediate_input_stream_handler.h +[`SyncSetInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc +[`ImmediateInputStreamHandler`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc [`CalculatorGraphConfig::max_queue_size`]: https://github.com/google/mediapipe/tree/master/mediapipe/framework/calculator.proto [`FlowLimiterCalculator`]: https://github.com/google/mediapipe/tree/master/mediapipe/calculators/core/flow_limiter_calculator.cc diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md index 73e730679..b3f6c5df4 100644 --- a/docs/getting_started/android.md +++ b/docs/getting_started/android.md @@ -30,7 +30,7 @@ APIs (currently in alpha) that are now available in * Install MediaPipe following these [instructions](./install.md). * Setup Java Runtime. * Setup Android SDK release 30.0.0 and above. -* Setup Android NDK version 18 and above. +* Setup Android NDK version between 18 and 21. MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see below for Android Studio setup). 
However, if you prefer using MediaPipe without diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index 4c526527c..a5752c6d5 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -48,6 +48,16 @@ each project. bazel build -c opt --strip=ALWAYS \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + --legacy_whole_archive=0 \ + --features=-legacy_whole_archive \ + --copt=-fvisibility=hidden \ + --copt=-ffunction-sections \ + --copt=-fdata-sections \ + --copt=-fstack-protector \ + --copt=-Oz \ + --copt=-fomit-frame-pointer \ + --copt=-DABSL_MIN_LOG_LEVEL=2 \ + --linkopt=-Wl,--gc-sections,--strip-all \ //path/to/the/aar/build/file:aar_name.aar ``` @@ -57,6 +67,16 @@ each project. bazel build -c opt --strip=ALWAYS \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + --legacy_whole_archive=0 \ + --features=-legacy_whole_archive \ + --copt=-fvisibility=hidden \ + --copt=-ffunction-sections \ + --copt=-fdata-sections \ + --copt=-fstack-protector \ + --copt=-Oz \ + --copt=-fomit-frame-pointer \ + --copt=-DABSL_MIN_LOG_LEVEL=2 \ + --linkopt=-Wl,--gc-sections,--strip-all \ //mediapipe/examples/android/src/java/com/google/mediapipe/apps/aar_example:mediapipe_face_detection.aar # It should print: diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index 430287aeb..0b23787f3 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -569,7 +569,7 @@ next section. Option 1. Follow [the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) - to install Bazel 4.2.1 or higher. + to install Bazel 5.0.0 or higher. Option 2. 
Follow the official [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) diff --git a/docs/getting_started/python_framework.md b/docs/getting_started/python_framework.md index d80db5ab9..33b71be54 100644 --- a/docs/getting_started/python_framework.md +++ b/docs/getting_started/python_framework.md @@ -126,6 +126,7 @@ following steps: } return packet.Get(); }); + } } // namespace mediapipe ``` diff --git a/docs/index.md b/docs/index.md index 86d6ddc5e..1532e10cc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -136,8 +136,8 @@ run code search using ## Community -* [Awesome MediaPipe](https://mediapipe.org) - A curated list of awesome - MediaPipe related frameworks, libraries and software +* [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe) - A + curated list of awesome MediaPipe related frameworks, libraries and software * [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users * [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General community discussion around MediaPipe diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 04d429987..4eccf17f5 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -26,7 +26,7 @@ MediaPipe Face Detection is an ultrafast face detection solution that comes with face detector tailored for mobile GPU inference. The detector's super-realtime performance enables it to be applied to any live viewfinder experience that requires an accurate facial region of interest as an input for other -task-specific models, such as 3D facial keypoint or geometry estimation (e.g., +task-specific models, such as 3D facial keypoint estimation (e.g., [MediaPipe Face Mesh](./face_mesh.md)), facial features or expression classification, and face region segmentation. 
BlazeFace uses a lightweight feature extraction network inspired by, but distinct from diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index 57bf4de5b..ec43fb4ef 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -20,34 +20,34 @@ nav_order: 2 ## Overview -MediaPipe Face Mesh is a face geometry solution that estimates 468 3D face -landmarks in real-time even on mobile devices. It employs machine learning (ML) -to infer the 3D surface geometry, requiring only a single camera input without -the need for a dedicated depth sensor. Utilizing lightweight model architectures -together with GPU acceleration throughout the pipeline, the solution delivers -real-time performance critical for live experiences. +MediaPipe Face Mesh is a solution that estimates 468 3D face landmarks in +real-time even on mobile devices. It employs machine learning (ML) to infer the +3D facial surface, requiring only a single camera input without the need for a +dedicated depth sensor. Utilizing lightweight model architectures together with +GPU acceleration throughout the pipeline, the solution delivers real-time +performance critical for live experiences. -Additionally, the solution is bundled with the Face Geometry module that bridges -the gap between the face landmark estimation and useful real-time augmented -reality (AR) applications. It establishes a metric 3D space and uses the face -landmark screen positions to estimate face geometry within that space. The face -geometry data consists of common 3D geometry primitives, including a face pose -transformation matrix and a triangular face mesh. Under the hood, a lightweight -statistical analysis method called +Additionally, the solution is bundled with the Face Transform module that +bridges the gap between the face landmark estimation and useful real-time +augmented reality (AR) applications. 
It establishes a metric 3D space and uses +the face landmark screen positions to estimate a face transform within that +space. The face transform data consists of common 3D primitives, including a +face pose transformation matrix and a triangular face mesh. Under the hood, a +lightweight statistical analysis method called [Procrustes Analysis](https://en.wikipedia.org/wiki/Procrustes_analysis) is employed to drive a robust, performant and portable logic. The analysis runs on CPU and has a minimal speed/memory footprint on top of the ML model inference. ![face_mesh_ar_effects.gif](../images/face_mesh_ar_effects.gif) | :-------------------------------------------------------------: | -*Fig 1. AR effects utilizing facial surface geometry.* | +*Fig 1. AR effects utilizing the 3D facial surface.* | ## ML Pipeline Our ML pipeline consists of two real-time deep neural network models that work together: A detector that operates on the full image and computes face locations and a 3D face landmark model that operates on those locations and predicts the -approximate surface geometry via regression. Having the face accurately cropped +approximate 3D surface via regression. Having the face accurately cropped drastically reduces the need for common data augmentations like affine transformations consisting of rotations, translation and scale changes. Instead it allows the network to dedicate most of its capacity towards coordinate @@ -55,8 +55,8 @@ prediction accuracy. In addition, in our pipeline the crops can also be generated based on the face landmarks identified in the previous frame, and only when the landmark model could no longer identify face presence is the face detector invoked to relocalize the face. This strategy is similar to that -employed in our [MediaPipe Hands](./hands.md) solution, which uses a palm detector -together with a hand landmark model. 
+employed in our [MediaPipe Hands](./hands.md) solution, which uses a palm +detector together with a hand landmark model. The pipeline is implemented as a MediaPipe [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt) @@ -128,7 +128,7 @@ about the model in this [paper](https://arxiv.org/abs/2006.10962). :---------------------------------------------------------------------------: | *Fig 3. Attention Mesh: Overview of model architecture.* | -## Face Geometry Module +## Face Transform Module The [Face Landmark Model](#face-landmark-model) performs a single-camera face landmark detection in the screen coordinate space: the X- and Y- coordinates are @@ -140,7 +140,7 @@ enable the full spectrum of augmented reality (AR) features like aligning a virtual 3D object with a detected face. The -[Face Geometry module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry) +[Face Transform module](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry) moves away from the screen coordinate space towards a metric 3D space and provides necessary primitives to handle a detected face as a regular 3D object. By design, you'll be able to use a perspective camera to project the final 3D @@ -151,7 +151,7 @@ landmark positions are not changed. #### Metric 3D Space -The **Metric 3D space** established within the Face Geometry module is a +The **Metric 3D space** established within the Face Transform module is a right-handed orthonormal metric 3D coordinate space. Within the space, there is a **virtual perspective camera** located at the space origin and pointed in the negative direction of the Z-axis. In the current pipeline, it is assumed that @@ -184,11 +184,11 @@ functions: ### Components -#### Geometry Pipeline +#### Transform Pipeline -The **Geometry Pipeline** is a key component, which is responsible for -estimating face geometry objects within the Metric 3D space. 
On each frame, the -following steps are executed in the given order: +The **Transform Pipeline** is a key component, which is responsible for +estimating the face transform objects within the Metric 3D space. On each frame, +the following steps are executed in the given order: - Face landmark screen coordinates are converted into the Metric 3D space coordinates; @@ -199,12 +199,12 @@ following steps are executed in the given order: positions (XYZ), while both the vertex texture coordinates (UV) and the triangular topology are inherited from the canonical face model. -The geometry pipeline is implemented as a MediaPipe +The transform pipeline is implemented as a MediaPipe [calculator](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc). -For your convenience, the face geometry pipeline calculator is bundled together -with corresponding metadata into a unified MediaPipe +For your convenience, this calculator is bundled together with corresponding +metadata into a unified MediaPipe [subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt). -The face geometry format is defined as a Protocol Buffer +The face transform format is defined as a Protocol Buffer [message](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/face_geometry.proto). #### Effect Renderer @@ -227,7 +227,7 @@ The effect renderer is implemented as a MediaPipe | ![face_geometry_renderer.gif](../images/face_geometry_renderer.gif) | | :---------------------------------------------------------------------: | -| *Fig 5. An example of face effects rendered by the Face Geometry Effect Renderer.* | +| *Fig 5. 
An example of face effects rendered by the Face Transform Effect Renderer.* | ## Solution APIs diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index d7cc2cec1..c60e44921 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -116,7 +116,7 @@ on how to build MediaPipe examples. Note: The following runs TensorFlow inference on CPU. If you would like to run inference on GPU (Linux only), please follow - [TensorFlow CUDA Support and Setup on Linux Desktop](gpu.md#tensorflow-cuda-support-and-setup-on-linux-desktop) + [TensorFlow CUDA Support and Setup on Linux Desktop](../getting_started/gpu_support.md#tensorflow-cuda-support-and-setup-on-linux-desktop) instead. To build the TensorFlow CPU inference example on desktop, run: diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 25259d678..23cf7c179 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -384,7 +384,7 @@ Supported configuration options: - + diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 2ec1b4f4e..1c9d6a669 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -359,7 +359,7 @@ Supported configuration options: - + diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 61d907dab..3cb0dd018 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -117,6 +117,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:classification_proto", + "//mediapipe/framework/formats:landmark_proto", ], ) @@ -309,8 +310,8 @@ cc_library( ) cc_library( - name = "concatenate_normalized_landmark_list_calculator", - srcs = ["concatenate_normalized_landmark_list_calculator.cc"], + name = "concatenate_proto_list_calculator", + srcs = ["concatenate_proto_list_calculator.cc"], visibility = ["//visibility:public"], deps = [ 
":concatenate_vector_calculator_cc_proto", @@ -324,10 +325,10 @@ cc_library( ) cc_test( - name = "concatenate_normalized_landmark_list_calculator_test", - srcs = ["concatenate_normalized_landmark_list_calculator_test.cc"], + name = "concatenate_proto_list_calculator_test", + srcs = ["concatenate_proto_list_calculator_test.cc"], deps = [ - ":concatenate_normalized_landmark_list_calculator", + ":concatenate_proto_list_calculator", ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -964,8 +965,8 @@ cc_test( ) cc_library( - name = "split_landmarks_calculator", - srcs = ["split_landmarks_calculator.cc"], + name = "split_proto_list_calculator", + srcs = ["split_proto_list_calculator.cc"], visibility = ["//visibility:public"], deps = [ ":split_vector_calculator_cc_proto", @@ -979,10 +980,10 @@ cc_library( ) cc_test( - name = "split_landmarks_calculator_test", - srcs = ["split_landmarks_calculator_test.cc"], + name = "split_proto_list_calculator_test", + srcs = ["split_proto_list_calculator_test.cc"], deps = [ - ":split_landmarks_calculator", + ":split_proto_list_calculator", ":split_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -1195,6 +1196,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc deleted file mode 100644 index f0a4043a7..000000000 --- a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc +++ /dev/null @@ -1,79 +0,0 @@ -// 
Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT -#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ // NOLINT - -#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" -#include "mediapipe/framework/api2/node.h" -#include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/formats/landmark.pb.h" -#include "mediapipe/framework/port/canonical_errors.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status.h" - -namespace mediapipe { -namespace api2 { - -// Concatenates several NormalizedLandmarkList protos following stream index -// order. This class assumes that every input stream contains a -// NormalizedLandmarkList proto object. 
-class ConcatenateNormalizedLandmarkListCalculator : public Node { - public: - static constexpr Input::Multiple kIn{""}; - static constexpr Output kOut{""}; - - MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - - static absl::Status UpdateContract(CalculatorContract* cc) { - RET_CHECK_GE(kIn(cc).Count(), 1); - return absl::OkStatus(); - } - - absl::Status Open(CalculatorContext* cc) override { - only_emit_if_all_present_ = - cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() - .only_emit_if_all_present(); - return absl::OkStatus(); - } - - absl::Status Process(CalculatorContext* cc) override { - if (only_emit_if_all_present_) { - for (const auto& input : kIn(cc)) { - if (input.IsEmpty()) return absl::OkStatus(); - } - } - - NormalizedLandmarkList output; - for (const auto& input : kIn(cc)) { - if (input.IsEmpty()) continue; - const NormalizedLandmarkList& list = *input; - for (int j = 0; j < list.landmark_size(); ++j) { - *output.add_landmark() = list.landmark(j); - } - } - kOut(cc).Send(std::move(output)); - return absl::OkStatus(); - } - - private: - bool only_emit_if_all_present_; -}; -MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator); - -} // namespace api2 -} // namespace mediapipe - -// NOLINTNEXTLINE -#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_NORMALIZED_LIST_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc new file mode 100644 index 000000000..9dd0dfd99 --- /dev/null +++ b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc @@ -0,0 +1,118 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_PROTO_LIST_CALCULATOR_H_ // NOLINT +#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_PROTO_LIST_CALCULATOR_H_ // NOLINT + +#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace api2 { + +// Concatenate several input packets of ListType with a repeated field of +// ItemType into a single output packet of ListType following stream index +// order. 
+template +class ConcatenateListsCalculator : public Node { + public: + static constexpr typename Input::Multiple kIn{""}; + static constexpr Output kOut{""}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GE(kIn(cc).Count(), 1); + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + only_emit_if_all_present_ = + cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() + .only_emit_if_all_present(); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + if (only_emit_if_all_present_) { + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) return absl::OkStatus(); + } + } + + ListType output; + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) continue; + const ListType& list = *input; + for (int j = 0; j < ListSize(list); ++j) { + *AddItem(output) = GetItem(list, j); + } + } + kOut(cc).Send(std::move(output)); + return absl::OkStatus(); + } + + protected: + virtual int ListSize(const ListType& list) const = 0; + virtual const ItemType GetItem(const ListType& list, int idx) const = 0; + virtual ItemType* AddItem(ListType& list) const = 0; + + private: + bool only_emit_if_all_present_; +}; + +// TODO: Move calculators to separate *.cc files + +class ConcatenateNormalizedLandmarkListCalculator + : public ConcatenateListsCalculator { + protected: + int ListSize(const NormalizedLandmarkList& list) const override { + return list.landmark_size(); + } + const NormalizedLandmark GetItem(const NormalizedLandmarkList& list, + int idx) const override { + return list.landmark(idx); + } + NormalizedLandmark* AddItem(NormalizedLandmarkList& list) const override { + return list.add_landmark(); + } +}; +MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListCalculator); + +class ConcatenateLandmarkListCalculator + : public ConcatenateListsCalculator { + protected: + int ListSize(const LandmarkList& list) const override 
{ + return list.landmark_size(); + } + const Landmark GetItem(const LandmarkList& list, int idx) const override { + return list.landmark(idx); + } + Landmark* AddItem(LandmarkList& list) const override { + return list.add_landmark(); + } +}; +MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator); + +} // namespace api2 +} // namespace mediapipe + +// NOLINTNEXTLINE +#endif // MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_PROTO_LIST_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator_test.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc similarity index 100% rename from mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator_test.cc rename to mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.cc b/mediapipe/calculators/core/constant_side_packet_calculator.cc index ff328377e..b2eb3ef34 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator.cc +++ b/mediapipe/calculators/core/constant_side_packet_calculator.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/ret_check.h" @@ -79,6 +80,8 @@ class ConstantSidePacketCalculator : public CalculatorBase { packet.Set(); } else if (packet_options.has_classification_list_value()) { packet.Set(); + } else if (packet_options.has_landmark_list_value()) { + packet.Set(); } else { return absl::InvalidArgumentError( "None of supported values were specified in options."); @@ -108,6 +111,9 @@ class ConstantSidePacketCalculator : public CalculatorBase { } else if (packet_options.has_classification_list_value()) { 
packet.Set(MakePacket( packet_options.classification_list_value())); + } else if (packet_options.has_landmark_list_value()) { + packet.Set( + MakePacket(packet_options.landmark_list_value())); } else { return absl::InvalidArgumentError( "None of supported values were specified in options."); diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.proto b/mediapipe/calculators/core/constant_side_packet_calculator.proto index 57f5dc545..bc192ffb4 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator.proto +++ b/mediapipe/calculators/core/constant_side_packet_calculator.proto @@ -18,6 +18,7 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/formats/classification.proto"; +import "mediapipe/framework/formats/landmark.proto"; option objc_class_prefix = "MediaPipe"; @@ -34,6 +35,7 @@ message ConstantSidePacketCalculatorOptions { string string_value = 4; uint64 uint64_value = 5; ClassificationList classification_list_value = 6; + LandmarkList landmark_list_value = 7; } } diff --git a/mediapipe/calculators/core/graph_profile_calculator.cc b/mediapipe/calculators/core/graph_profile_calculator.cc index 9b9aa3bb7..f973efbb7 100644 --- a/mediapipe/calculators/core/graph_profile_calculator.cc +++ b/mediapipe/calculators/core/graph_profile_calculator.cc @@ -29,6 +29,11 @@ namespace api2 { // This calculator periodically copies the GraphProfile from // mediapipe::GraphProfiler::CaptureProfile to the "PROFILE" output stream. // +// Similarly to the log files saved by GraphProfiler::WriteProfile when trace +// logging is enabled, the first captured profile contains the full +// canonicalized graph config and, if tracing is enabled, calculator names in +// graph traces. Subsequent profiles omit this information. 
+// // Example config: // node { // calculator: "GraphProfileCalculator" @@ -50,11 +55,14 @@ class GraphProfileCalculator : public Node { absl::Status Process(CalculatorContext* cc) final { auto options = cc->Options<::mediapipe::GraphProfileCalculatorOptions>(); - if (prev_profile_ts_ == Timestamp::Unset() || + bool first_profile = prev_profile_ts_ == Timestamp::Unset(); + if (first_profile || cc->InputTimestamp() - prev_profile_ts_ >= options.profile_interval()) { prev_profile_ts_ = cc->InputTimestamp(); GraphProfile result; - MP_RETURN_IF_ERROR(cc->GetProfilingContext()->CaptureProfile(&result)); + MP_RETURN_IF_ERROR(cc->GetProfilingContext()->CaptureProfile( + &result, first_profile ? PopulateGraphConfig::kFull + : PopulateGraphConfig::kNo)); kProfileOut(cc).Send(result); } return absl::OkStatus(); diff --git a/mediapipe/calculators/core/graph_profile_calculator_test.cc b/mediapipe/calculators/core/graph_profile_calculator_test.cc index 5d8f17404..6696b802f 100644 --- a/mediapipe/calculators/core/graph_profile_calculator_test.cc +++ b/mediapipe/calculators/core/graph_profile_calculator_test.cc @@ -202,6 +202,8 @@ TEST_F(GraphProfileCalculatorTest, GraphProfile) { } })pb"); + ASSERT_EQ(output_packets.size(), 2); + EXPECT_TRUE(output_packets[0].Get().has_config()); EXPECT_THAT(output_packets[1].Get(), mediapipe::EqualsProto(expected_profile)); } diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.cc b/mediapipe/calculators/core/quantize_float_vector_calculator.cc index e1df66c1a..16a3171dc 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.cc @@ -23,8 +23,8 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/status.h" -// Quantizes a vector of floats to a std::string so that each float becomes a -// byte in the [0, 255] range. 
Any value above max_quantized_value or below +// Quantizes a vector of floats to a string so that each float becomes a byte +// in the [0, 255] range. Any value above max_quantized_value or below // min_quantized_value will be saturated to '/xFF' or '/0'. // // Example config: diff --git a/mediapipe/calculators/core/split_landmarks_calculator.cc b/mediapipe/calculators/core/split_proto_list_calculator.cc similarity index 64% rename from mediapipe/calculators/core/split_landmarks_calculator.cc rename to mediapipe/calculators/core/split_proto_list_calculator.cc index 5bc876bf6..df6156e7f 100644 --- a/mediapipe/calculators/core/split_landmarks_calculator.cc +++ b/mediapipe/calculators/core/split_proto_list_calculator.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_LANDMARKS_CALCULATOR_H_ // NOLINT -#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_LANDMARKS_CALCULATOR_H_ // NOLINT +#ifndef MEDIAPIPE_CALCULATORS_CORE_SPLIT_PROTO_LIST_CALCULATOR_H_ // NOLINT +#define MEDIAPIPE_CALCULATORS_CORE_SPLIT_PROTO_LIST_CALCULATOR_H_ // NOLINT #include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,30 +24,30 @@ namespace mediapipe { -// Splits an input packet with LandmarkListType into -// multiple LandmarkListType output packets using the [begin, end) ranges +// Splits an input packet of ListType with a repeated field of ItemType +// into multiple ListType output packets using the [begin, end) ranges // specified in SplitVectorCalculatorOptions. If the option "element_only" is // set to true, all ranges should be of size 1 and all outputs will be elements -// of type LandmarkType. If "element_only" is false, ranges can be -// non-zero in size and all outputs will be of type LandmarkListType. +// of type ItemType. 
If "element_only" is false, ranges can be +// non-zero in size and all outputs will be of type ListType. // If the option "combine_outputs" is set to true, only one output stream can be // specified and all ranges of elements will be combined into one -// LandmarkListType. -template -class SplitLandmarksCalculator : public CalculatorBase { +// ListType. +template +class SplitListsCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() != 0); - cc->Inputs().Index(0).Set(); + cc->Inputs().Index(0).Set(); const auto& options = cc->Options<::mediapipe::SplitVectorCalculatorOptions>(); if (options.combine_outputs()) { RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); - cc->Outputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); for (int i = 0; i < options.ranges_size() - 1; ++i) { for (int j = i + 1; j < options.ranges_size(); ++j) { const auto& range_0 = options.ranges(i); @@ -82,9 +82,9 @@ class SplitLandmarksCalculator : public CalculatorBase { return absl::InvalidArgumentError( "Since element_only is true, all ranges should be of size 1."); } - cc->Outputs().Index(i).Set(); + cc->Outputs().Index(i).Set(); } else { - cc->Outputs().Index(i).Set(); + cc->Outputs().Index(i).Set(); } } } @@ -111,39 +111,38 @@ class SplitLandmarksCalculator : public CalculatorBase { } absl::Status Process(CalculatorContext* cc) override { - const LandmarkListType& input = - cc->Inputs().Index(0).Get(); - RET_CHECK_GE(input.landmark_size(), max_range_end_) - << "Max range end " << max_range_end_ << " exceeds landmarks size " - << input.landmark_size(); + const ListType& input = cc->Inputs().Index(0).Get(); + RET_CHECK_GE(ListSize(input), max_range_end_) + << "Max range end " << max_range_end_ << " exceeds list size " + << ListSize(input); if (combine_outputs_) { - LandmarkListType output; + ListType output; for (int i = 0; i < ranges_.size(); ++i) { for (int j = 
ranges_[i].first; j < ranges_[i].second; ++j) { - const LandmarkType& input_landmark = input.landmark(j); - *output.add_landmark() = input_landmark; + const ItemType& input_item = GetItem(input, j); + *AddItem(output) = input_item; } } - RET_CHECK_EQ(output.landmark_size(), total_elements_); + RET_CHECK_EQ(ListSize(output), total_elements_); cc->Outputs().Index(0).AddPacket( - MakePacket(output).At(cc->InputTimestamp())); + MakePacket(output).At(cc->InputTimestamp())); } else { if (element_only_) { for (int i = 0; i < ranges_.size(); ++i) { cc->Outputs().Index(i).AddPacket( - MakePacket(input.landmark(ranges_[i].first)) + MakePacket(GetItem(input, ranges_[i].first)) .At(cc->InputTimestamp())); } } else { for (int i = 0; i < ranges_.size(); ++i) { - LandmarkListType output; + ListType output; for (int j = ranges_[i].first; j < ranges_[i].second; ++j) { - const LandmarkType& input_landmark = input.landmark(j); - *output.add_landmark() = input_landmark; + const ItemType& input_item = GetItem(input, j); + *AddItem(output) = input_item; } cc->Outputs().Index(i).AddPacket( - MakePacket(output).At(cc->InputTimestamp())); + MakePacket(output).At(cc->InputTimestamp())); } } } @@ -151,6 +150,11 @@ class SplitLandmarksCalculator : public CalculatorBase { return absl::OkStatus(); } + protected: + virtual int ListSize(const ListType& list) const = 0; + virtual const ItemType GetItem(const ListType& list, int idx) const = 0; + virtual ItemType* AddItem(ListType& list) const = 0; + private: std::vector> ranges_; int32 max_range_end_ = -1; @@ -159,15 +163,40 @@ class SplitLandmarksCalculator : public CalculatorBase { bool combine_outputs_ = false; }; -typedef SplitLandmarksCalculator - SplitNormalizedLandmarkListCalculator; +// TODO: Move calculators to separate *.cc files + +class SplitNormalizedLandmarkListCalculator + : public SplitListsCalculator { + protected: + int ListSize(const NormalizedLandmarkList& list) const override { + return list.landmark_size(); + } + const 
NormalizedLandmark GetItem(const NormalizedLandmarkList& list, + int idx) const override { + return list.landmark(idx); + } + NormalizedLandmark* AddItem(NormalizedLandmarkList& list) const override { + return list.add_landmark(); + } +}; REGISTER_CALCULATOR(SplitNormalizedLandmarkListCalculator); -typedef SplitLandmarksCalculator - SplitLandmarkListCalculator; +class SplitLandmarkListCalculator + : public SplitListsCalculator { + protected: + int ListSize(const LandmarkList& list) const override { + return list.landmark_size(); + } + const Landmark GetItem(const LandmarkList& list, int idx) const override { + return list.landmark(idx); + } + Landmark* AddItem(LandmarkList& list) const override { + return list.add_landmark(); + } +}; REGISTER_CALCULATOR(SplitLandmarkListCalculator); } // namespace mediapipe // NOLINTNEXTLINE -#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_LANDMARKS_CALCULATOR_H_ +#endif // MEDIAPIPE_CALCULATORS_CORE_SPLIT_PROTO_LIST_CALCULATOR_H_ diff --git a/mediapipe/calculators/core/split_landmarks_calculator_test.cc b/mediapipe/calculators/core/split_proto_list_calculator_test.cc similarity index 100% rename from mediapipe/calculators/core/split_landmarks_calculator_test.cc rename to mediapipe/calculators/core/split_proto_list_calculator_test.cc diff --git a/mediapipe/calculators/core/string_to_int_calculator.cc b/mediapipe/calculators/core/string_to_int_calculator.cc index 13a9a29e0..ecd55afb6 100644 --- a/mediapipe/calculators/core/string_to_int_calculator.cc +++ b/mediapipe/calculators/core/string_to_int_calculator.cc @@ -24,7 +24,7 @@ namespace mediapipe { -// Calculator that converts a std::string into an integer type, or fails if the +// Calculator that converts a string into an integer type, or fails if the // conversion is not possible. 
// // Example config: @@ -47,7 +47,7 @@ class StringToIntCalculatorTemplate : public CalculatorBase { if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get(), &number)) { return absl::InvalidArgumentError( - "The std::string could not be parsed as an integer."); + "The string could not be parsed as an integer."); } cc->OutputSidePackets().Index(0).Set(MakePacket(number)); return absl::OkStatus(); diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 0bbfadd05..5428f98fd 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -239,10 +239,13 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_transformation_calculator_cc_proto", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", "//mediapipe/gpu:scale_mode_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/image/image_file_properties_calculator.cc b/mediapipe/calculators/image/image_file_properties_calculator.cc index 9c6d8caca..97478f9f8 100644 --- a/mediapipe/calculators/image/image_file_properties_calculator.cc +++ b/mediapipe/calculators/image/image_file_properties_calculator.cc @@ -105,7 +105,7 @@ absl::StatusOr GetImageFileProperites( } // namespace // Calculator to extract EXIF information from an image file. The input is -// a std::string containing raw byte data from a file, and the output is an +// a string containing raw byte data from a file, and the output is an // ImageFileProperties proto object with the relevant fields filled in. // The calculator accepts the input as a stream or a side packet, and can output // the result as a stream or a side packet. 
The calculator checks that if an diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index f017eba79..bc7fd8df7 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -16,10 +16,13 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/timestamp.h" #include "mediapipe/gpu/scale_mode.pb.h" #if !MEDIAPIPE_DISABLE_GPU @@ -52,6 +55,7 @@ namespace mediapipe { namespace { constexpr char kImageFrameTag[] = "IMAGE"; constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { switch (rotation) { @@ -122,6 +126,12 @@ mediapipe::ScaleMode_Mode ParseScaleMode( // provided, it overrides the FLIP_VERTICALLY input side packet and/or // corresponding field in the calculator options. // +// VIDEO_PRESTREAM (optional): VideoHeader for the input ImageFrames, if +// rotating or scaling the frames, the header width and height will be updated +// appropriately. Note the header is updated only based on dimensions and +// rotations specified as side packets or options, input_stream +// transformations will not update the header. +// // Output: // One of the following tags: // IMAGE - ImageFrame representing the output image. 
@@ -242,6 +252,21 @@ absl::Status ImageTransformationCalculator::GetContract( cc->Inputs().Tag("FLIP_VERTICALLY").Set(); } + RET_CHECK(cc->Inputs().HasTag(kVideoPrestreamTag) == + cc->Outputs().HasTag(kVideoPrestreamTag)) + << "If VIDEO_PRESTREAM is provided, it must be provided both as an " + "inputs and output stream."; + if (cc->Inputs().HasTag(kVideoPrestreamTag)) { + RET_CHECK(!(cc->Inputs().HasTag("OUTPUT_DIMENSIONS") || + cc->Inputs().HasTag("ROTATION_DEGREES"))) + << "If specifying VIDEO_PRESTREAM, the transformations that affect the " + "dimensions of the frames (OUTPUT_DIMENSIONS and ROTATION_DEGREES) " + "need to be constant for every frame, meaning they can only be " + "provided in the calculator options or side packets."; + cc->Inputs().Tag(kVideoPrestreamTag).Set(); + cc->Outputs().Tag(kVideoPrestreamTag).Set(); + } + if (cc->InputSidePackets().HasTag("OUTPUT_DIMENSIONS")) { cc->InputSidePackets().Tag("OUTPUT_DIMENSIONS").Set(); } @@ -326,6 +351,24 @@ absl::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { } absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) { + // First update the video header if it is given, based on the rotation and + // dimensions specified as side packets or options. This will only be done + // once, so streaming transformation changes will not be reflected in + // the header. + if (cc->Inputs().HasTag(kVideoPrestreamTag) && + !cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty() && + cc->Outputs().HasTag(kVideoPrestreamTag)) { + mediapipe::VideoHeader header = + cc->Inputs().Tag(kVideoPrestreamTag).Get(); + // Update the header's width and height if needed. + ComputeOutputDimensions(header.width, header.height, &header.width, + &header.height); + cc->Outputs() + .Tag(kVideoPrestreamTag) + .AddPacket(mediapipe::MakePacket(header).At( + mediapipe::Timestamp::PreStream())); + } + // Override values if specified so. 
if (cc->Inputs().HasTag("ROTATION_DEGREES") && !cc->Inputs().Tag("ROTATION_DEGREES").IsEmpty()) { diff --git a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc index 21bc587f3..8c909f613 100644 --- a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc +++ b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc @@ -22,9 +22,9 @@ namespace mediapipe { -// Takes in an encoded image std::string, decodes it by OpenCV, and converts to -// an ImageFrame. Note that this calculator only supports grayscale and RGB -// images for now. +// Takes in an encoded image string, decodes it by OpenCV, and converts to an +// ImageFrame. Note that this calculator only supports grayscale and RGB images +// for now. // // Example config: // node { diff --git a/mediapipe/calculators/image/opencv_put_text_calculator.cc b/mediapipe/calculators/image/opencv_put_text_calculator.cc index 82a4b3a53..241af58b4 100644 --- a/mediapipe/calculators/image/opencv_put_text_calculator.cc +++ b/mediapipe/calculators/image/opencv_put_text_calculator.cc @@ -20,8 +20,8 @@ namespace mediapipe { -// Takes in a std::string, draws the text std::string by cv::putText(), and -// outputs an ImageFrame. +// Takes in a string, draws the text string by cv::putText(), and outputs an +// ImageFrame. 
// // Example config: // node { diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc index 0669f5322..f6596b3fd 100644 --- a/mediapipe/calculators/image/scale_image_calculator.cc +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -553,7 +553,6 @@ absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) { } } - cc->GetCounter("Inputs")->Increment(); const ImageFrame* image_frame; ImageFrame converted_image_frame; if (input_format_ == ImageFormat::YCBCR420P) { diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc index db339b754..62d3b0d28 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc @@ -183,22 +183,22 @@ absl::Status SegmentationSmoothingCalculator::Close(CalculatorContext* cc) { absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) { // Setup source images. 
const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get(); - const cv::Mat current_mat = mediapipe::formats::MatView(¤t_frame); - RET_CHECK_EQ(current_mat.type(), CV_32FC1) + auto current_mat = mediapipe::formats::MatView(¤t_frame); + RET_CHECK_EQ(current_mat->type(), CV_32FC1) << "Only 1-channel float input image is supported."; const auto& previous_frame = cc->Inputs().Tag(kPreviousMaskTag).Get(); - const cv::Mat previous_mat = mediapipe::formats::MatView(&previous_frame); - RET_CHECK_EQ(previous_mat.type(), current_mat.type()) - << "Warning: mixing input format types: " << previous_mat.type() - << " != " << previous_mat.type(); + auto previous_mat = mediapipe::formats::MatView(&previous_frame); + RET_CHECK_EQ(previous_mat->type(), current_mat->type()) + << "Warning: mixing input format types: " << previous_mat->type() + << " != " << previous_mat->type(); - RET_CHECK_EQ(current_mat.rows, previous_mat.rows); - RET_CHECK_EQ(current_mat.cols, previous_mat.cols); + RET_CHECK_EQ(current_mat->rows, previous_mat->rows); + RET_CHECK_EQ(current_mat->cols, previous_mat->cols); // Setup destination image. auto output_frame = std::make_shared( - current_frame.image_format(), current_mat.cols, current_mat.rows); + current_frame.image_format(), current_mat->cols, current_mat->rows); cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get()); output_mat.setTo(cv::Scalar(0)); @@ -233,8 +233,8 @@ absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) { // Write directly to the first channel of output. 
for (int i = 0; i < output_mat.rows; ++i) { float* out_ptr = output_mat.ptr(i); - const float* curr_ptr = current_mat.ptr(i); - const float* prev_ptr = previous_mat.ptr(i); + const float* curr_ptr = current_mat->ptr(i); + const float* prev_ptr = previous_mat->ptr(i); for (int j = 0; j < output_mat.cols; ++j) { const float new_mask_value = curr_ptr[j]; const float prev_mask_value = prev_ptr[j]; diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc index 100d7de8a..eeb812cb7 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc @@ -116,8 +116,8 @@ void RunGraph(Packet curr_packet, Packet prev_packet, bool use_gpu, float ratio, ASSERT_EQ(1, output_packets.size()); Image result_image = output_packets[0].Get(); - cv::Mat result_mat = formats::MatView(&result_image); - result_mat.copyTo(*result); + auto result_mat = formats::MatView(&result_image); + result_mat->copyTo(*result); // Fully close graph at end, otherwise calculator+Images are destroyed // after calling WaitUntilDone(). 
@@ -135,10 +135,10 @@ void RunTest(bool use_gpu, float mix_ratio, cv::Mat& test_result) { Packet curr_packet = MakePacket(std::make_unique( ImageFormat::VEC32F1, curr_mat.size().width, curr_mat.size().height)); - curr_mat.copyTo(formats::MatView(&(curr_packet.Get()))); + curr_mat.copyTo(*formats::MatView(&(curr_packet.Get()))); Packet prev_packet = MakePacket(std::make_unique( ImageFormat::VEC32F1, prev_mat.size().width, prev_mat.size().height)); - prev_mat.copyTo(formats::MatView(&(prev_packet.Get()))); + prev_mat.copyTo(*formats::MatView(&(prev_packet.Get()))); cv::Mat result; RunGraph(curr_packet, prev_packet, use_gpu, mix_ratio, &result); diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 72c2f5181..d41fa2c63 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -84,14 +84,15 @@ cc_library( tags = ["nomac"], # config problem with cpuinfo via TF deps = [ "inference_calculator_interface", + "//mediapipe/framework/deps:file_path", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/util/tflite:config", "//mediapipe/util/tflite:tflite_gpu_runner", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", ], alwayslink = 1, ) @@ -154,7 +155,7 @@ cc_library( cc_library( name = "inference_calculator_gl_if_compute_shader_available", - deps = select({ + deps = selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [":inference_calculator_gl"], }), @@ -303,7 +304,7 @@ cc_library( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats/object_detection:anchor_cc_proto", 
"//mediapipe/framework/port:ret_check", - ] + select({ + ] + selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [":tensors_to_detections_calculator_gpu_deps"], }), @@ -560,7 +561,7 @@ cc_library( cc_library( name = "image_to_tensor_calculator_gpu_deps", - deps = select({ + deps = selects.with_or({ "//mediapipe:android": [ ":image_to_tensor_converter_gl_buffer", "//mediapipe/gpu:gl_calculator_helper", @@ -684,7 +685,7 @@ cc_library( name = "image_to_tensor_converter_gl_buffer", srcs = ["image_to_tensor_converter_gl_buffer.cc"], hdrs = ["image_to_tensor_converter_gl_buffer.h"], - deps = ["//mediapipe/framework:port"] + select({ + deps = ["//mediapipe/framework:port"] + selects.with_or({ "//mediapipe:apple": [], "//conditions:default": [ ":image_to_tensor_converter", diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index b579f0474..9900610e5 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -49,7 +49,6 @@ #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" #include "mediapipe/gpu/gl_calculator_helper.h" #endif // MEDIAPIPE_METAL_ENABLED - #endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -142,11 +141,24 @@ class ImageToTensorCalculator : public Node { const auto& options = cc->Options(); - RET_CHECK(options.has_output_tensor_float_range()) + RET_CHECK(options.has_output_tensor_float_range() || + options.has_output_tensor_int_range()) << "Output tensor range is required."; - RET_CHECK_LT(options.output_tensor_float_range().min(), - options.output_tensor_float_range().max()) - << "Valid output tensor range is required."; + if (options.has_output_tensor_float_range()) { + RET_CHECK_LT(options.output_tensor_float_range().min(), + options.output_tensor_float_range().max()) + << "Valid output float tensor range is required."; + } + if 
(options.has_output_tensor_int_range()) { + RET_CHECK_LT(options.output_tensor_int_range().min(), + options.output_tensor_int_range().max()) + << "Valid output int tensor range is required."; + RET_CHECK_GE(options.output_tensor_int_range().min(), 0) + << "The minimum of the output int tensor range must be non-negative."; + RET_CHECK_LE(options.output_tensor_int_range().max(), 255) + << "The maximum of the output int tensor range must be less than or " + "equal to 255."; + } RET_CHECK_GT(options.output_tensor_width(), 0) << "Valid output tensor width is required."; RET_CHECK_GT(options.output_tensor_height(), 0) @@ -175,9 +187,15 @@ class ImageToTensorCalculator : public Node { options_ = cc->Options(); output_width_ = options_.output_tensor_width(); output_height_ = options_.output_tensor_height(); - range_min_ = options_.output_tensor_float_range().min(); - range_max_ = options_.output_tensor_float_range().max(); - + is_int_output_ = options_.has_output_tensor_int_range(); + range_min_ = + is_int_output_ + ? static_cast(options_.output_tensor_int_range().min()) + : options_.output_tensor_float_range().min(); + range_max_ = + is_int_output_ + ? static_cast(options_.output_tensor_int_range().max()) + : options_.output_tensor_float_range().max(); return absl::OkStatus(); } @@ -225,7 +243,7 @@ class ImageToTensorCalculator : public Node { } // Lazy initialization of the GPU or CPU converter. - MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, image->UsesGpu())); + MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get())); ASSIGN_OR_RETURN(Tensor tensor, (image->UsesGpu() ? gpu_converter_ : cpu_converter_) @@ -283,9 +301,15 @@ class ImageToTensorCalculator : public Node { } } - absl::Status InitConverterIfNecessary(CalculatorContext* cc, bool use_gpu) { + absl::Status InitConverterIfNecessary(CalculatorContext* cc, + const Image& image) { // Lazy initialization of the GPU or CPU converter. 
- if (use_gpu) { + if (image.UsesGpu()) { + if (is_int_output_) { + return absl::UnimplementedError( + "ImageToTensorConverter for the input GPU image currently doesn't " + "support quantization."); + } if (!gpu_converter_) { #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_METAL_ENABLED @@ -296,9 +320,17 @@ class ImageToTensorCalculator : public Node { CreateImageToGlBufferTensorConverter( cc, DoesGpuInputStartAtBottom(), GetBorderMode())); #else - ASSIGN_OR_RETURN(gpu_converter_, - CreateImageToGlTextureTensorConverter( - cc, DoesGpuInputStartAtBottom(), GetBorderMode())); + // Check whether the underlying storage object is a GL texture. + if (image.GetGpuBuffer() + .internal_storage()) { + ASSIGN_OR_RETURN( + gpu_converter_, + CreateImageToGlTextureTensorConverter( + cc, DoesGpuInputStartAtBottom(), GetBorderMode())); + } else { + return absl::UnimplementedError( + "ImageToTensorConverter for the input GPU image is unavailable."); + } #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU } @@ -306,7 +338,10 @@ class ImageToTensorCalculator : public Node { if (!cpu_converter_) { #if !MEDIAPIPE_DISABLE_OPENCV ASSIGN_OR_RETURN(cpu_converter_, - CreateOpenCvConverter(cc, GetBorderMode())); + CreateOpenCvConverter( + cc, GetBorderMode(), + is_int_output_ ? 
Tensor::ElementType::kUInt8 + : Tensor::ElementType::kFloat32)); #else LOG(FATAL) << "Cannot create image to tensor opencv converter since " "MEDIAPIPE_DISABLE_OPENCV is defined."; @@ -321,6 +356,7 @@ class ImageToTensorCalculator : public Node { mediapipe::ImageToTensorCalculatorOptions options_; int output_width_ = 0; int output_height_ = 0; + bool is_int_output_ = false; float range_min_ = 0.0f; float range_max_ = 1.0f; }; diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto index 0451dc51f..bf8ba160d 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.proto @@ -31,6 +31,14 @@ message ImageToTensorCalculatorOptions { optional float max = 2; } + // Range of int values [min, max]. + // min, must be strictly less than max. + // Please note that IntRange is supported for CPU tensors only. + message IntRange { + optional int64 min = 1; + optional int64 max = 2; + } + // Pixel extrapolation methods. See @border_mode. enum BorderMode { BORDER_UNSPECIFIED = 0; @@ -49,6 +57,7 @@ message ImageToTensorCalculatorOptions { // Output tensor element range/type image pixels are converted to. 
oneof range { FloatRange output_tensor_float_range = 4; + IntRange output_tensor_int_range = 7; } // For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 275c33559..4e35e3be6 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -61,7 +61,8 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, float range_max, int tensor_width, int tensor_height, bool keep_aspect, absl::optional border_mode, - const mediapipe::NormalizedRect& roi) { + const mediapipe::NormalizedRect& roi, + bool output_int_tensor) { std::string border_mode_str; if (border_mode) { switch (*border_mode) { @@ -73,6 +74,21 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, break; } } + std::string output_tensor_range; + if (output_int_tensor) { + output_tensor_range = absl::Substitute(R"(output_tensor_int_range { + min: $0 + max: $1 + })", + static_cast(range_min), + static_cast(range_max)); + } else { + output_tensor_range = absl::Substitute(R"(output_tensor_float_range { + min: $0 + max: $1 + })", + range_min, range_max); + } auto graph_config = mediapipe::ParseTextProtoOrDie( absl::Substitute(R"( input_stream: "input_image" @@ -86,22 +102,18 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, [mediapipe.ImageToTensorCalculatorOptions.ext] { output_tensor_width: $0 output_tensor_height: $1 - keep_aspect_ratio: $4 - output_tensor_float_range { - min: $2 - max: $3 - } - $5 # border mode + keep_aspect_ratio: $2 + $3 # output range + $4 # border mode } } } )", /*$0=*/tensor_width, /*$1=*/tensor_height, - /*$2=*/range_min, - /*$3=*/range_max, - /*$4=*/keep_aspect ? "true" : "false", - /*$5=*/border_mode_str)); + /*$2=*/keep_aspect ? 
"true" : "false", + /*$3=*/output_tensor_range, + /*$4=*/border_mode_str)); std::vector output_packets; tool::AddVectorSink("tensor", &graph_config, &output_packets); @@ -126,11 +138,18 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, ASSERT_THAT(tensor_vec, testing::SizeIs(1)); const Tensor& tensor = tensor_vec[0]; - EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); - auto view = tensor.GetCpuReadView(); - cv::Mat tensor_mat(tensor_height, tensor_width, CV_32FC3, - const_cast(view.buffer())); + cv::Mat tensor_mat; + if (output_int_tensor) { + EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); + tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, + const_cast(view.buffer())); + } else { + EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); + tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3, + const_cast(view.buffer())); + } + cv::Mat result_rgb; auto transformation = GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value(); @@ -170,16 +189,26 @@ enum class InputType { kImageFrame, kImage }; const std::vector kInputTypesToTest = {InputType::kImageFrame, InputType::kImage}; -void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, - float range_max, int tensor_width, int tensor_height, - bool keep_aspect, absl::optional border_mode, +void RunTest(cv::Mat input, cv::Mat expected_result, + std::vector float_range, std::vector int_range, + int tensor_width, int tensor_height, bool keep_aspect, + absl::optional border_mode, const mediapipe::NormalizedRect& roi) { + ASSERT_EQ(2, float_range.size()); + ASSERT_EQ(2, int_range.size()); for (auto input_type : kInputTypesToTest) { RunTestWithInputImagePacket( input_type == InputType::kImageFrame ? 
MakeImageFramePacket(input) : MakeImagePacket(input), - expected_result, range_min, range_max, tensor_width, tensor_height, - keep_aspect, border_mode, roi); + expected_result, float_range[0], float_range[1], tensor_width, + tensor_height, keep_aspect, border_mode, roi, + /*output_int_tensor=*/false); + RunTestWithInputImagePacket( + input_type == InputType::kImageFrame ? MakeImageFramePacket(input) + : MakeImagePacket(input), + expected_result, int_range[0], int_range[1], tensor_width, + tensor_height, keep_aspect, border_mode, roi, + /*output_int_tensor=*/true); } } @@ -195,8 +224,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, /*border mode*/ {}, roi); } @@ -213,8 +242,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_keep_aspect_border_zero.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, BorderMode::kZero, roi); } @@ -231,7 +260,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_keep_aspect_with_rotation.png"), - /*range_min=*/0.0f, /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, BorderMode::kReplicate, roi); } @@ -249,7 +279,8 @@ TEST(ImageToTensorCalculatorTest, GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), - 
/*range_min=*/0.0f, /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, BorderMode::kZero, roi); } @@ -267,8 +298,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { GetRgb( "/mediapipe/calculators/" "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*range_min=*/-1.0f, - /*range_max=*/1.0f, + /*float_range=*/{-1.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, BorderMode::kReplicate, roi); } @@ -285,8 +316,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "medium_sub_rect_with_rotation_border_zero.png"), - /*range_min=*/-1.0f, - /*range_max=*/1.0f, + /*float_range=*/{-1.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, BorderMode::kZero, roi); } @@ -302,8 +333,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/large_sub_rect.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, BorderMode::kReplicate, roi); } @@ -320,8 +351,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, BorderMode::kZero, roi); } @@ -338,8 +369,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" 
"tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kReplicate, roi); } @@ -356,8 +387,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "large_sub_rect_keep_aspect_border_zero.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kZero, roi); } @@ -374,8 +405,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "large_sub_rect_keep_aspect_with_rotation.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*border_mode=*/{}, roi); } @@ -393,8 +424,8 @@ TEST(ImageToTensorCalculatorTest, GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/" "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*border_mode=*/BorderMode::kZero, roi); } @@ -410,8 +441,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { "tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/noop_except_range.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kReplicate, roi); } @@ -427,8 +458,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { 
"tensor/testdata/image_to_tensor/input.jpg"), GetRgb("/mediapipe/calculators/" "tensor/testdata/image_to_tensor/noop_except_range.png"), - /*range_min=*/0.0f, - /*range_max=*/1.0f, + /*float_range=*/{0.0f, 1.0f}, + /*int_range=*/{0, 255}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, BorderMode::kZero, roi); } diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index eb9681521..e81621b76 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -16,7 +16,7 @@ #include "mediapipe/framework/port.h" -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include #include @@ -339,4 +339,4 @@ CreateImageToGlTextureTensorConverter(CalculatorContext* cc, } // namespace mediapipe -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h index 269abf141..dda1c347f 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h @@ -17,7 +17,7 @@ #include "mediapipe/framework/port.h" -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include @@ -37,6 +37,6 @@ CreateImageToGlTextureTensorConverter(CalculatorContext* cc, } // namespace mediapipe -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_TEXTURE_H_ diff --git 
a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc index 6fb39e0c3..ac95de8d7 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.cc @@ -2,7 +2,7 @@ #include "mediapipe/framework/port.h" -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include #include @@ -85,4 +85,4 @@ bool IsGlClampToBorderSupported(const mediapipe::GlContext& gl_context) { } // namespace mediapipe -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h index 3105cfef1..5f77ba06c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h @@ -3,7 +3,7 @@ #include "mediapipe/framework/port.h" -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include #include @@ -40,6 +40,6 @@ bool IsGlClampToBorderSupported(const mediapipe::GlContext& gl_context); } // namespace mediapipe -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_GL_UTILS_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc index 9482cfc2a..5500f43ed 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc @@ -1,6 +1,6 @@ #include "mediapipe/framework/port.h" 
-#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" #include "mediapipe/framework/port/gtest.h" @@ -46,4 +46,4 @@ TEST(ImageToTensorConverterGlUtilsTest, GlTexParameteriOverrider) { } // namespace } // namespace mediapipe -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_20 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 22131a7e7..45e027439 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -35,7 +35,8 @@ namespace { class OpenCvProcessor : public ImageToTensorConverter { public: - OpenCvProcessor(BorderMode border_mode) { + OpenCvProcessor(BorderMode border_mode, Tensor::ElementType tensor_type) + : tensor_type_(tensor_type) { switch (border_mode) { case BorderMode::kReplicate: border_mode_ = cv::BORDER_REPLICATE; @@ -44,6 +45,7 @@ class OpenCvProcessor : public ImageToTensorConverter { border_mode_ = cv::BORDER_CONSTANT; break; } + mat_type_ = tensor_type == Tensor::ElementType::kUInt8 ? 
CV_8UC3 : CV_32FC3; } absl::StatusOr Convert(const mediapipe::Image& input, @@ -56,15 +58,20 @@ class OpenCvProcessor : public ImageToTensorConverter { absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", static_cast(input.image_format()))); } - cv::Mat src = mediapipe::formats::MatView(&input); + auto src = mediapipe::formats::MatView(&input); constexpr int kNumChannels = 3; - Tensor tensor( - Tensor::ElementType::kFloat32, - Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels}); + Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height, + output_dims.width, kNumChannels}); auto buffer_view = tensor.GetCpuWriteView(); - cv::Mat dst(output_dims.height, output_dims.width, CV_32FC3, - buffer_view.buffer()); + cv::Mat dst; + if (tensor_type_ == Tensor::ElementType::kUInt8) { + dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + buffer_view.buffer()); + } else { + dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + buffer_view.buffer()); + } const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y), cv::Size2f(roi.width, roi.height), @@ -85,7 +92,7 @@ class OpenCvProcessor : public ImageToTensorConverter { cv::Mat projection_matrix = cv::getPerspectiveTransform(src_points, dst_points); cv::Mat transformed; - cv::warpPerspective(src, transformed, projection_matrix, + cv::warpPerspective(*src, transformed, projection_matrix, cv::Size(dst_width, dst_height), /*flags=*/cv::INTER_LINEAR, /*borderMode=*/border_mode_); @@ -102,19 +109,22 @@ class OpenCvProcessor : public ImageToTensorConverter { auto transform, GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); - transformed.convertTo(dst, CV_32FC3, transform.scale, transform.offset); + transformed.convertTo(dst, mat_type_, transform.scale, transform.offset); return tensor; } private: enum cv::BorderTypes border_mode_; + Tensor::ElementType tensor_type_; + int mat_type_; }; } // namespace 
absl::StatusOr> CreateOpenCvConverter( - CalculatorContext* cc, BorderMode border_mode) { - return absl::make_unique(border_mode); + CalculatorContext* cc, BorderMode border_mode, + Tensor::ElementType tensor_type) { + return absl::make_unique(border_mode, tensor_type); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h index 3ccecc557..74e0030e6 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h @@ -25,7 +25,8 @@ namespace mediapipe { // Creates OpenCV image-to-tensor converter. absl::StatusOr> CreateOpenCvConverter( - CalculatorContext* cc, BorderMode border_mode); + CalculatorContext* cc, BorderMode border_mode, + Tensor::ElementType tensor_type); } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 46e0f928c..0311612ff 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -36,6 +36,7 @@ class InferenceCalculatorSelectorImpl Subgraph::GetOptions( subgraph_node); std::vector impls; + const bool should_use_gpu = !options.has_delegate() || // Use GPU delegate if not specified (options.has_delegate() && options.delegate().has_gpu()); diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index 7d695ad9b..d4f2224c5 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -81,6 +81,7 @@ class InferenceCalculatorCpuImpl Packet model_packet_; std::unique_ptr interpreter_; TfLiteDelegatePtr delegate_; + bool has_quantized_input_; }; absl::Status InferenceCalculatorCpuImpl::UpdateContract( @@ -109,10 +110,18 @@ absl::Status 
InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { for (int i = 0; i < input_tensors.size(); ++i) { const Tensor* input_tensor = &input_tensors[i]; auto input_tensor_view = input_tensor->GetCpuReadView(); - auto input_tensor_buffer = input_tensor_view.buffer(); - float* local_tensor_buffer = interpreter_->typed_input_tensor(i); - std::memcpy(local_tensor_buffer, input_tensor_buffer, - input_tensor->bytes()); + if (has_quantized_input_) { + // TODO: Support more quantized tensor types. + auto input_tensor_buffer = input_tensor_view.buffer(); + uint8* local_tensor_buffer = interpreter_->typed_input_tensor(i); + std::memcpy(local_tensor_buffer, input_tensor_buffer, + input_tensor->bytes()); + } else { + auto input_tensor_buffer = input_tensor_view.buffer(); + float* local_tensor_buffer = interpreter_->typed_input_tensor(i); + std::memcpy(local_tensor_buffer, input_tensor_buffer, + input_tensor->bytes()); + } } // Run inference. @@ -167,10 +176,9 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors( // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - // TODO: Support quantized tensors. 
- RET_CHECK_NE( - interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, - kTfLiteAffineQuantization); + has_quantized_input_ = + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type == + kTfLiteAffineQuantization; return absl::OkStatus(); } @@ -226,7 +234,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { #endif // defined(__EMSCRIPTEN__) if (use_xnnpack) { - TfLiteXNNPackDelegateOptions xnnpack_opts{}; + auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault(); xnnpack_opts.num_threads = GetXnnpackNumThreads(opts_has_delegate, opts_delegate); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), diff --git a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc index b6de30e3a..bb383af71 100644 --- a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc +++ b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc @@ -154,8 +154,9 @@ TEST_P(InferenceCalculatorTest, TestFaceDetection) { detection_packets[0].Get>(); #if !defined(MEDIAPIPE_PROTO_LITE) // Approximately is not available with lite protos (b/178137094). - EXPECT_THAT(dets, - ElementsAre(Approximately(EqualsProto(expected_detection)))); + constexpr float kEpison = 0.001; + EXPECT_THAT(dets, ElementsAre(Approximately(EqualsProto(expected_detection), + kEpison))); #endif } diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index dda9a1fa1..8b998d665 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -59,8 +59,6 @@ class InferenceCalculatorGlImpl // TfLite requires us to keep the model alive as long as the interpreter is. 
Packet model_packet_; - std::unique_ptr interpreter_; - TfLiteDelegatePtr delegate_; #if MEDIAPIPE_TFLITE_GL_INFERENCE mediapipe::GlCalculatorHelper gpu_helper_; @@ -72,6 +70,9 @@ class InferenceCalculatorGlImpl tflite_gpu_runner_usage_; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE + TfLiteDelegatePtr delegate_; + std::unique_ptr interpreter_; + #if MEDIAPIPE_TFLITE_GPU_SUPPORTED std::vector output_shapes_; std::vector> gpu_buffers_in_; @@ -252,12 +253,17 @@ absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { gpu_buffers_in_.clear(); gpu_buffers_out_.clear(); + // Delegate must outlive the interpreter, hence the order is important. + interpreter_ = nullptr; + delegate_ = nullptr; return absl::OkStatus(); })); + } else { + // Delegate must outlive the interpreter, hence the order is important. + interpreter_ = nullptr; + delegate_ = nullptr; } - interpreter_ = nullptr; - delegate_ = nullptr; return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 498036c12..c4e941f12 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -266,6 +266,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( auto raw_box_tensor = &input_tensors[0]; RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3); RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); + RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options"; RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); auto raw_score_tensor = &input_tensors[1]; @@ -385,6 +386,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( CalculatorContext* cc, std::vector* output_detections) { const auto& input_tensors = *kInTensors(cc); 
RET_CHECK_GE(input_tensors.size(), 2); + RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options"; #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc, @@ -563,7 +565,6 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = cc->Options<::mediapipe::TensorsToDetectionsCalculatorOptions>(); RET_CHECK(options_.has_num_classes()); - RET_CHECK(options_.has_num_boxes()); RET_CHECK(options_.has_num_coords()); num_classes_ = options_.num_classes(); diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index ffc96b2e4..a03a60189 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -355,9 +355,10 @@ absl::Status TensorsToSegmentationCalculator::ProcessCpu( std::shared_ptr mask_frame = std::make_shared( ImageFormat::VEC32F1, output_width, output_height); std::unique_ptr output_mask = absl::make_unique(mask_frame); - cv::Mat output_mat = formats::MatView(output_mask.get()); + auto output_mat = formats::MatView(output_mask.get()); // Upsample small mask into output. 
- cv::resize(small_mask_mat, output_mat, cv::Size(output_width, output_height)); + cv::resize(small_mask_mat, *output_mat, + cv::Size(output_width, output_height)); cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); return absl::OkStatus(); diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index ac058610a..f7e4260cc 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -334,6 +334,7 @@ cc_library( ":image_frame_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ] + select({ diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc index 0db193bcc..cbc9d2aa2 100644 --- a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc @@ -17,6 +17,7 @@ #include "mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_macros.h" @@ -32,7 +33,10 @@ namespace { // Convert the ImageFrame into Tensor with floating point value type. // The value will be normalized based on mean and stddev. 
std::unique_ptr ImageFrameToNormalizedTensor( - const ImageFrame& image_frame, float mean, float stddev) { + // const ImageFrame& image_frame, float mean, float stddev) { + const ImageFrame& image_frame, + const mediapipe::proto_ns::RepeatedField& mean, + const mediapipe::proto_ns::RepeatedField& stddev) { const int cols = image_frame.Width(); const int rows = image_frame.Height(); const int channels = image_frame.NumberOfChannels(); @@ -45,7 +49,20 @@ std::unique_ptr ImageFrameToNormalizedTensor( for (int row = 0; row < rows; ++row) { for (int col = 0; col < cols; ++col) { for (int channel = 0; channel < channels; ++channel) { - tensor_data(row, col, channel) = (pixel[channel] - mean) / stddev; + float mean_value = 0; + if (mean.size() > 1) { + mean_value = mean[channel]; + } else if (!mean.empty()) { + mean_value = mean[0]; + } + float stddev_value = 1; + if (stddev.size() > 1) { + stddev_value = stddev[channel]; + } else if (!stddev.empty()) { + stddev_value = stddev[0]; + } + tensor_data(row, col, channel) = + (pixel[channel] - mean_value) / stddev_value; } pixel += channels; } @@ -126,7 +143,18 @@ absl::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) { const tf::DataType data_type = options_.data_type(); RET_CHECK_EQ(data_type, tf::DT_FLOAT) << "Unsupported data type " << data_type; - RET_CHECK_GT(options_.stddev(), 0.0f); + RET_CHECK_GT(options_.stddev().size(), 0) << "You must set a stddev."; + RET_CHECK_GT(options_.stddev()[0], 0.0f) << "The stddev cannot be zero."; + if (options_.stddev().size() > 1) { + RET_CHECK_EQ(options_.stddev().size(), video_frame.NumberOfChannels()) + << "If specifying multiple stddev normalization values, " + << "the number must match the number of image channels."; + } + if (options_.mean().size() > 1) { + RET_CHECK_EQ(options_.mean().size(), video_frame.NumberOfChannels()) + << "If specifying multiple mean normalization values, " + << "the number must match the number of image channels."; + } tensor = 
ImageFrameToNormalizedTensor(video_frame, options_.mean(), options_.stddev()); } else { diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto index 0e5a47716..c48e6a869 100644 --- a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.proto @@ -32,6 +32,6 @@ message ImageFrameToTensorCalculatorOptions { // If set, the output tensor T is equal to (F - mean * J) / stddev, where F // and J are the input image frame and the all-ones matrix of the same size, // respectively. Otherwise, T is equal to F. - optional float mean = 2; - optional float stddev = 3; + repeated float mean = 2; + repeated float stddev = 3; } diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc index 86b12d6f9..5acfadd47 100644 --- a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator_test.cc @@ -454,4 +454,32 @@ TEST_F(ImageFrameToTensorCalculatorTest, FixedRGBFrameWithMeanAndStddev) { EXPECT_EQ(actual[2], 127.0f / 128.0f); // (255 - 128) / 128 } +TEST_F(ImageFrameToTensorCalculatorTest, FixedRGBFrameWithRepeatMeanAndStddev) { + runner_ = ::absl::make_unique( + "ImageFrameToTensorCalculator", + "[mediapipe.ImageFrameToTensorCalculatorOptions.ext]" + "{data_type:DT_FLOAT mean:128.0 mean:128.0 mean:128.0 " + " stddev:128.0 stddev:128.0 stddev:128.0}", + 1, 1, 0); + + // Create a single pixel image of fixed color #0080ff. 
+ auto image_frame = ::absl::make_unique(ImageFormat::SRGB, 1, 1); + const uint8 color[] = {0, 128, 255}; + SetToColor(color, image_frame.get()); + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(image_frame.release()).At(Timestamp(0))); + MP_ASSERT_OK(runner_->Run()); + + const auto& tensor = runner_->Outputs().Index(0).packets[0].Get(); + EXPECT_EQ(tensor.dtype(), tf::DT_FLOAT); + ASSERT_EQ(tensor.dims(), 3); + EXPECT_EQ(tensor.shape().dim_size(0), 1); + EXPECT_EQ(tensor.shape().dim_size(1), 1); + EXPECT_EQ(tensor.shape().dim_size(2), 3); + const float* actual = tensor.flat().data(); + EXPECT_EQ(actual[0], -1.0f); // ( 0 - 128) / 128 + EXPECT_EQ(actual[1], 0.0f); // (128 - 128) / 128 + EXPECT_EQ(actual[2], 127.0f / 128.0f); // (255 - 128) / 128 +} } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc index a8abe10d9..85ebde97a 100644 --- a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc @@ -70,10 +70,10 @@ const int kNumCoordsPerBox = 4; // image/understanding/object_detection/export_inference_graph.py // // By default, the output Detections store label ids (integers) for each -// detection. Optionally, a label map (of the form std::map +// detection. Optionally, a label map (of the form std::map // mapping label ids to label names as strings) can be made available as an // input side packet, in which case the output Detections store -// labels as their associated std::string provided by the label map. +// labels as their associated string provided by the label map. 
// // Usage example: // node { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 3991f645d..7e58329f0 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -59,7 +59,7 @@ namespace mpms = mediapipe::mediasequence; // bounding boxes from vector, and streams with the // "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's // associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints -// from flat_hash_map>>. "IMAGE_${NAME}", +// from flat_hash_map>>. "IMAGE_${NAME}", // "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of // each stream, which allows for multiple image streams to be included. However, // the default names are suppored by more tools. diff --git a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc index da85bed94..f86d34adf 100644 --- a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc +++ b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc @@ -28,7 +28,7 @@ // output_side_packet: "SEQUENCE_EXAMPLE:sequence_example" // } // -// Example converting to std::string in Close(): +// Example converting to string in Close(): // node { // calculator: "StringToSequenceExampleCalculator" // input_side_packet: "SEQUENCE_EXAMPLE:sequence_example" diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index a8ecb847d..1db886a36 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -302,10 +302,9 @@ class TensorFlowInferenceCalculator : public CalculatorBase { << "To 
use recurrent_tag_pairs, batch_size must be 1."; for (const auto& tag_pair : options_.recurrent_tag_pair()) { const std::vector tags = absl::StrSplit(tag_pair, ':'); - RET_CHECK_EQ(tags.size(), 2) - << "recurrent_tag_pair must be a colon " - "separated std::string with two components: " - << tag_pair; + RET_CHECK_EQ(tags.size(), 2) << "recurrent_tag_pair must be a colon " + "separated string with two components: " + << tag_pair; RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) << "Can't find tag '" << tags[0] << "' in signature " << options_.signature_name(); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc index 794a8a732..1d6b9417b 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc @@ -86,7 +86,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { cc->InputSidePackets() .Tag(kStringModelFilePathTag) .Set( - // Filename of std::string model. + // Filename of string model. ); } cc->OutputSidePackets() diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc index 09985bcf3..5afeeae28 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -84,7 +84,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { } else if (input_side_packets->HasTag(kStringModelFilePathTag)) { input_side_packets->Tag(kStringModelFilePathTag) .Set( - // Filename of std::string model. + // Filename of string model. 
); } output_side_packets->Tag(kSessionTag) diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index de600de31..922eb9d50 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -69,6 +69,8 @@ const std::string MaybeConvertSignatureToTag( [](unsigned char c) { return std::toupper(c); }); output = absl::StrReplaceAll(output, {{"/", "_"}}); output = absl::StrReplaceAll(output, {{"-", "_"}}); + output = absl::StrReplaceAll(output, {{".", "_"}}); + LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { return name; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 9b2e16a88..97c675920 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -71,6 +71,8 @@ const std::string MaybeConvertSignatureToTag( [](unsigned char c) { return std::toupper(c); }); output = absl::StrReplaceAll(output, {{"/", "_"}}); output = absl::StrReplaceAll(output, {{"-", "_"}}); + output = absl::StrReplaceAll(output, {{".", "_"}}); + LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { return name; diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index 8e83f3e44..027e4446d 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -939,7 +939,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { #if 
!defined(MEDIAPIPE_EDGE_TPU) if (use_xnnpack) { - TfLiteXNNPackDelegateOptions xnnpack_opts{}; + auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault(); xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), &TfLiteXNNPackDelegateDelete); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index 22a9a8d70..ff1d1aaf8 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "absl/strings/str_format.h" @@ -558,7 +559,7 @@ uniform ivec2 out_size; const int output_layer_index = int($1); const float combine_with_previous_ratio = float($2); -// Will be replaced with either '#define READ_PREVIOUS' or empty std::string +// Will be replaced with either '#define READ_PREVIOUS' or empty string $3 //DEFINE_READ_PREVIOUS void main() { diff --git a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc index 9cd460114..a9bc51f66 100644 --- a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc @@ -51,6 +51,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase { cc->InputSidePackets().Tag(kFileDirectoryTag).Get(), cc->InputSidePackets().Tag(kFileSuffixTag).Get(), &filenames_)); + std::sort(filenames_.begin(), filenames_.end()); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/packet_frequency_calculator_test.cc b/mediapipe/calculators/util/packet_frequency_calculator_test.cc index f8e7c0236..0f1a38f4d 100644 --- 
a/mediapipe/calculators/util/packet_frequency_calculator_test.cc +++ b/mediapipe/calculators/util/packet_frequency_calculator_test.cc @@ -129,8 +129,8 @@ TEST(PacketFrequencyCalculatorTest, MultiPacketTest) { // Tests packet frequency with multiple input/output streams. TEST(PacketFrequencyCalculatorTest, MultiStreamTest) { // Setup the calculator runner and provide strings as input on all streams - // (note that it doesn't have to be std::string; the calculator can take any - // type as input). + // (note that it doesn't have to be string; the calculator can take any type + // as input). CalculatorRunner runner(GetNodeWithMultipleStreams()); // Packet 1 on stream 1. diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 3b395818f..400be277d 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -37,6 +37,13 @@ RenderAnnotation::Rectangle* NewRect( annotation->mutable_color()->set_b(options.color().b()); annotation->set_thickness(options.thickness()); + if (options.has_top_left_thickness()) { + CHECK(!options.oval()); + CHECK(!options.filled()); + annotation->mutable_rectangle()->set_top_left_thickness( + options.top_left_thickness()); + } + return options.oval() ? options.filled() ? annotation->mutable_filled_oval() ->mutable_oval() @@ -136,6 +143,11 @@ absl::Status RectToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); + if (options_.has_top_left_thickness()) { + // Filled and oval don't support top_left_thickness. 
+ RET_CHECK(!options_.filled()); + RET_CHECK(!options_.oval()); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.proto b/mediapipe/calculators/util/rect_to_render_data_calculator.proto index 9b6d5e6ee..7611e3e04 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.proto @@ -35,4 +35,8 @@ message RectToRenderDataCalculatorOptions { // Whether the rendered rectangle should be an oval. optional bool oval = 4 [default = false]; + + // Radius of top left corner circle. Only supported for oval=false, + // filled=false. + optional double top_left_thickness = 5; } diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc index 42ec5715e..b6bdf2f85 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -48,8 +48,8 @@ constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; constexpr char kScoresTag[] = "SCORES"; // A calculator that takes a vector of scores and returns the indexes, scores, -// labels of the top k elements, classification protos, and summary std::string -// (in csv format). +// labels of the top k elements, classification protos, and summary string (in +// csv format). // // Usage example: // node { diff --git a/mediapipe/calculators/video/box_detector_calculator.cc b/mediapipe/calculators/video/box_detector_calculator.cc index 55b5c458b..14ac12e5e 100644 --- a/mediapipe/calculators/video/box_detector_calculator.cc +++ b/mediapipe/calculators/video/box_detector_calculator.cc @@ -76,7 +76,7 @@ constexpr char kTrackingTag[] = "TRACKING"; // IMAGE_SIZE: Input image dimension. // TRACKED_BOXES : input box tracking result (proto TimedBoxProtoList) from // BoxTrackerCalculator. 
-// ADD_INDEX: Optional std::string containing binary format proto of type +// ADD_INDEX: Optional string containing binary format proto of type // BoxDetectorIndex. Used for adding target index to the detector // search index during runtime. // CANCEL_OBJECT_ID: Optional id of box to be removed. This is recommended @@ -91,8 +91,7 @@ constexpr char kTrackingTag[] = "TRACKING"; // BOXES: Optional output stream of type TimedBoxProtoList for each lost box. // // Imput side packets: -// INDEX_PROTO_STRING: Optional std::string containing binary format proto of -// type +// INDEX_PROTO_STRING: Optional string containing binary format proto of type // BoxDetectorIndex. Used for initializing box_detector // with predefined template images. // FRAME_ALIGNMENT: Optional integer to indicate alignment_boundary for diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index d3acc322a..fb1ad3951 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -78,7 +79,7 @@ const char kOptionsTag[] = "OPTIONS"; // TrackingData and added to current set of tracked boxes. // This is recommended to be used with SyncSetInputStreamHandler. // START_POS_PROTO_STRING: Same as START_POS, but is in the form of serialized -// protobuffer std::string. When both START_POS and +// protobuffer string. When both START_POS and // START_POS_PROTO_STRING are present, START_POS is used. Suggest // to specify only one of them. // RESTART_POS: Same as START_POS, but exclusively for receiving detection @@ -99,7 +100,7 @@ const char kOptionsTag[] = "OPTIONS"; // can be in arbitrary order. // Use with SyncSetInputStreamHandler in streaming mode only. // RA_TRACK_PROTO_STRING: Same as RA_TRACK, but is in the form of serialized -// protobuffer std::string. When both RA_TRACK and +// protobuffer string. 
When both RA_TRACK and // RA_TRACK_PROTO_STRING are present, RA_TRACK is used. Suggest // to specify only one of them. // diff --git a/mediapipe/calculators/video/motion_analysis_calculator.cc b/mediapipe/calculators/video/motion_analysis_calculator.cc index 6217d3be9..544439ae8 100644 --- a/mediapipe/calculators/video/motion_analysis_calculator.cc +++ b/mediapipe/calculators/video/motion_analysis_calculator.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc index 94ddbb836..52905d837 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -79,7 +79,7 @@ ImageFormat::Format GetImageFormat(int num_channels) { // to be saved, specify an output side packet with tag "SAVED_AUDIO_PATH". // The calculator will call FFmpeg binary to save audio tracks as an aac file. // If the audio tracks can't be extracted by FFmpeg, the output side packet -// will contain an empty std::string. +// will contain an empty string. 
// // Example config: // node { diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 442d9132e..41dfb8790 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml index f7218c97c..528a03a3a 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic/AndroidManifest.xml @@ -10,6 +10,9 @@ + + + #include +#include "absl/status/status.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_state.h" @@ -41,6 +42,7 @@ constexpr char kFirstCropRect[] = "FIRST_CROP_RECT"; // Can be used to control whether an animated zoom should actually performed // (configured through option us_to_first_rect). If provided, a non-zero integer // will allow the animated zoom to be used when the first detections arrive. +// Applies to first detection only. constexpr char kAnimateZoom[] = "ANIMATE_ZOOM"; // Can be used to control the maximum zoom; note that it is re-evaluated only // upon change of input resolution. 
A value of 100 disables zooming and is the @@ -112,6 +114,16 @@ class ContentZoomingCalculator : public CalculatorBase { int* pan_offset, int* height); // Sets max_frame_value_ and target_aspect_ absl::Status UpdateAspectAndMax(); + // Smooth camera path + absl::Status SmoothAndClampPath(int target_width, int target_height, + float path_width, float path_height, + float* path_offset_x, float* path_offset_y); + // Compute box containing all detections. + absl::Status GetDetectionsBox(mediapipe::CalculatorContext* cc, float* xmin, + float* xmax, float* ymin, float* ymax, + bool* only_required_found, + bool* has_detections); + ContentZoomingCalculatorOptions options_; // Detection frame width/height. int frame_height_; @@ -537,68 +549,13 @@ absl::Status ContentZoomingCalculator::Process( UpdateForResolutionChange(cc, frame_width, frame_height)); } - bool only_required_found = false; - // Compute the box that contains all "is_required" detections. float xmin = 1, ymin = 1, xmax = 0, ymax = 0; - if (cc->Inputs().HasTag(kSalientRegions)) { - auto detection_set = cc->Inputs().Tag(kSalientRegions).Get(); - for (const auto& region : detection_set.detections()) { - if (!region.only_required()) { - continue; - } - only_required_found = true; - MP_RETURN_IF_ERROR(UpdateRanges( - region, options_.detection_shift_vertical(), - options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax)); - } - } - - if (cc->Inputs().HasTag(kDetections)) { - if (cc->Inputs().Tag(kDetections).IsEmpty()) { - if (last_only_required_detection_ == 0) { - // If no detections are available and we never had any, - // simply return the full-image rectangle as crop-rect. 
- if (cc->Outputs().HasTag(kCropRect)) { - auto default_rect = absl::make_unique(); - default_rect->set_x_center(frame_width_ / 2); - default_rect->set_y_center(frame_height_ / 2); - default_rect->set_width(frame_width_); - default_rect->set_height(frame_height_); - cc->Outputs().Tag(kCropRect).Add(default_rect.release(), - Timestamp(cc->InputTimestamp())); - } - if (cc->Outputs().HasTag(kNormalizedCropRect)) { - auto default_rect = absl::make_unique(); - default_rect->set_x_center(0.5); - default_rect->set_y_center(0.5); - default_rect->set_width(1.0); - default_rect->set_height(1.0); - cc->Outputs() - .Tag(kNormalizedCropRect) - .Add(default_rect.release(), Timestamp(cc->InputTimestamp())); - } - // Also provide a first crop rect: in this case a zero-sized one. - if (cc->Outputs().HasTag(kFirstCropRect)) { - cc->Outputs() - .Tag(kFirstCropRect) - .Add(new mediapipe::NormalizedRect(), - Timestamp(cc->InputTimestamp())); - } - return absl::OkStatus(); - } - } else { - auto raw_detections = cc->Inputs() - .Tag(kDetections) - .Get>(); - for (const auto& detection : raw_detections) { - only_required_found = true; - MP_RETURN_IF_ERROR(UpdateRanges( - detection, options_.detection_shift_vertical(), - options_.detection_shift_horizontal(), &xmin, &xmax, &ymin, &ymax)); - } - } - } + bool only_required_found = false; + bool has_detections = true; + MP_RETURN_IF_ERROR(GetDetectionsBox(cc, &xmin, &xmax, &ymin, &ymax, + &only_required_found, &has_detections)); + if (!has_detections) return absl::OkStatus(); const bool may_start_animation = (options_.us_to_first_rect() != 0) && (!cc->Inputs().HasTag(kAnimateZoom) || @@ -656,7 +613,8 @@ absl::Status ContentZoomingCalculator::Process( path_solver_zoom_->ClearHistory(); } const bool camera_active = - is_animating || pan_state || tilt_state || zoom_state; + is_animating || ((pan_state || tilt_state || zoom_state) && + !options_.disable_animations()); // Waiting for first rect before setting any value of the camera active flag // 
so we avoid setting it to false during initialization. if (cc->Outputs().HasTag(kCameraActive) && @@ -666,17 +624,26 @@ absl::Status ContentZoomingCalculator::Process( .AddPacket(MakePacket(camera_active).At(cc->InputTimestamp())); } + // Skip the path solvers to the final destination if not animating. + const bool disable_animations = + options_.disable_animations() && path_solver_zoom_->IsInitialized(); + if (disable_animations) { + MP_RETURN_IF_ERROR(path_solver_zoom_->SetState(height)); + MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(offset_y)); + MP_RETURN_IF_ERROR(path_solver_pan_->SetState(offset_x)); + } + // Compute smoothed zoom camera path. MP_RETURN_IF_ERROR(path_solver_zoom_->AddObservation( height, cc->InputTimestamp().Microseconds())); float path_height; MP_RETURN_IF_ERROR(path_solver_zoom_->GetState(&path_height)); - float path_width = path_height * target_aspect_; + const float path_width = path_height * target_aspect_; // Update pixel-per-degree value for pan/tilt. int target_height; MP_RETURN_IF_ERROR(path_solver_zoom_->GetTargetPosition(&target_height)); - int target_width = target_height * target_aspect_; + const int target_width = target_height * target_aspect_; MP_RETURN_IF_ERROR(path_solver_pan_->UpdatePixelsPerDegree( static_cast(target_width) / kFieldOfView)); MP_RETURN_IF_ERROR(path_solver_tilt_->UpdatePixelsPerDegree( @@ -692,66 +659,16 @@ absl::Status ContentZoomingCalculator::Process( float path_offset_y; MP_RETURN_IF_ERROR(path_solver_tilt_->GetState(&path_offset_y)); - float delta_height; - MP_RETURN_IF_ERROR(path_solver_zoom_->GetDeltaState(&delta_height)); - int delta_width = delta_height * target_aspect_; - - // Smooth centering when zooming out. 
- float remaining_width = target_width - path_width; - int width_space = frame_width_ - target_width; - if (abs(path_offset_x - frame_width_ / 2) > - width_space / 2 + kPixelTolerance && - remaining_width > kPixelTolerance) { - float required_width = - abs(path_offset_x - frame_width_ / 2) - width_space / 2; - if (path_offset_x < frame_width_ / 2) { - path_offset_x += delta_width * (required_width / remaining_width); - MP_RETURN_IF_ERROR(path_solver_pan_->SetState(path_offset_x)); - } else { - path_offset_x -= delta_width * (required_width / remaining_width); - MP_RETURN_IF_ERROR(path_solver_pan_->SetState(path_offset_x)); - } - } - - float remaining_height = target_height - path_height; - int height_space = frame_height_ - target_height; - if (abs(path_offset_y - frame_height_ / 2) > - height_space / 2 + kPixelTolerance && - remaining_height > kPixelTolerance) { - float required_height = - abs(path_offset_y - frame_height_ / 2) - height_space / 2; - if (path_offset_y < frame_height_ / 2) { - path_offset_y += delta_height * (required_height / remaining_height); - MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(path_offset_y)); - } else { - path_offset_y -= delta_height * (required_height / remaining_height); - MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(path_offset_y)); - } - } - - // Prevent box from extending beyond the image after camera smoothing. 
- if (path_offset_y - ceil(path_height / 2.0) < 0) { - path_offset_y = ceil(path_height / 2.0); - MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(path_offset_y)); - } else if (path_offset_y + ceil(path_height / 2.0) > frame_height_) { - path_offset_y = frame_height_ - ceil(path_height / 2.0); - MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(path_offset_y)); - } - - if (path_offset_x - ceil(path_width / 2.0) < 0) { - path_offset_x = ceil(path_width / 2.0); - MP_RETURN_IF_ERROR(path_solver_pan_->SetState(path_offset_x)); - } else if (path_offset_x + ceil(path_width / 2.0) > frame_width_) { - path_offset_x = frame_width_ - ceil(path_width / 2.0); - MP_RETURN_IF_ERROR(path_solver_pan_->SetState(path_offset_x)); - } - - // Convert to top/bottom borders to remove. - int path_top = path_offset_y - path_height / 2; - int path_bottom = frame_height_ - (path_offset_y + path_height / 2); + // Update path. + MP_RETURN_IF_ERROR(SmoothAndClampPath(target_width, target_height, path_width, + path_height, &path_offset_x, + &path_offset_y)); // Transmit result downstream to scenecroppingcalculator. if (cc->Outputs().HasTag(kDetectedBorders)) { + // Convert to top/bottom borders to remove. 
+ const int path_top = path_offset_y - path_height / 2; + const int path_bottom = frame_height_ - (path_offset_y + path_height / 2); std::unique_ptr features = absl::make_unique(); MakeStaticFeatures(path_top, path_bottom, frame_width_, frame_height_, @@ -798,8 +715,8 @@ absl::Status ContentZoomingCalculator::Process( if (cc->Outputs().HasTag(kNormalizedCropRect)) { std::unique_ptr gpu_rect = absl::make_unique(); - float float_frame_width = static_cast(frame_width_); - float float_frame_height = static_cast(frame_height_); + const float float_frame_width = static_cast(frame_width_); + const float float_frame_height = static_cast(frame_height_); if (is_animating) { auto rect = GetAnimationRect(frame_width, frame_height, cc->InputTimestamp()); @@ -829,5 +746,130 @@ absl::Status ContentZoomingCalculator::Process( return absl::OkStatus(); } +absl::Status ContentZoomingCalculator::SmoothAndClampPath( + int target_width, int target_height, float path_width, float path_height, + float* path_offset_x, float* path_offset_y) { + float delta_height; + MP_RETURN_IF_ERROR(path_solver_zoom_->GetDeltaState(&delta_height)); + const int delta_width = delta_height * target_aspect_; + + // Smooth centering when zooming out. 
+ const float remaining_width = target_width - path_width; + const int width_space = frame_width_ - target_width; + if (abs(*path_offset_x - frame_width_ / 2) > + width_space / 2 + kPixelTolerance && + remaining_width > kPixelTolerance) { + const float required_width = + abs(*path_offset_x - frame_width_ / 2) - width_space / 2; + if (*path_offset_x < frame_width_ / 2) { + *path_offset_x += delta_width * (required_width / remaining_width); + MP_RETURN_IF_ERROR(path_solver_pan_->SetState(*path_offset_x)); + } else { + *path_offset_x -= delta_width * (required_width / remaining_width); + MP_RETURN_IF_ERROR(path_solver_pan_->SetState(*path_offset_x)); + } + } + + const float remaining_height = target_height - path_height; + const int height_space = frame_height_ - target_height; + if (abs(*path_offset_y - frame_height_ / 2) > + height_space / 2 + kPixelTolerance && + remaining_height > kPixelTolerance) { + const float required_height = + abs(*path_offset_y - frame_height_ / 2) - height_space / 2; + if (*path_offset_y < frame_height_ / 2) { + *path_offset_y += delta_height * (required_height / remaining_height); + MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(*path_offset_y)); + } else { + *path_offset_y -= delta_height * (required_height / remaining_height); + MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(*path_offset_y)); + } + } + + // Prevent box from extending beyond the image after camera smoothing. 
+ if (*path_offset_y - ceil(path_height / 2.0) < 0) { + *path_offset_y = ceil(path_height / 2.0); + MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(*path_offset_y)); + } else if (*path_offset_y + ceil(path_height / 2.0) > frame_height_) { + *path_offset_y = frame_height_ - ceil(path_height / 2.0); + MP_RETURN_IF_ERROR(path_solver_tilt_->SetState(*path_offset_y)); + } + + if (*path_offset_x - ceil(path_width / 2.0) < 0) { + *path_offset_x = ceil(path_width / 2.0); + MP_RETURN_IF_ERROR(path_solver_pan_->SetState(*path_offset_x)); + } else if (*path_offset_x + ceil(path_width / 2.0) > frame_width_) { + *path_offset_x = frame_width_ - ceil(path_width / 2.0); + MP_RETURN_IF_ERROR(path_solver_pan_->SetState(*path_offset_x)); + } + + return absl::OkStatus(); +} + +absl::Status ContentZoomingCalculator::GetDetectionsBox( + mediapipe::CalculatorContext* cc, float* xmin, float* xmax, float* ymin, + float* ymax, bool* only_required_found, bool* has_detections) { + if (cc->Inputs().HasTag(kSalientRegions)) { + auto detection_set = cc->Inputs().Tag(kSalientRegions).Get(); + for (const auto& region : detection_set.detections()) { + if (!region.only_required()) { + continue; + } + *only_required_found = true; + MP_RETURN_IF_ERROR(UpdateRanges( + region, options_.detection_shift_vertical(), + options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax)); + } + } + + if (cc->Inputs().HasTag(kDetections)) { + if (cc->Inputs().Tag(kDetections).IsEmpty()) { + if (last_only_required_detection_ == 0) { + // If no detections are available and we never had any, + // simply return the full-image rectangle as crop-rect. 
+ if (cc->Outputs().HasTag(kCropRect)) { + auto default_rect = absl::make_unique(); + default_rect->set_x_center(frame_width_ / 2); + default_rect->set_y_center(frame_height_ / 2); + default_rect->set_width(frame_width_); + default_rect->set_height(frame_height_); + cc->Outputs().Tag(kCropRect).Add(default_rect.release(), + Timestamp(cc->InputTimestamp())); + } + if (cc->Outputs().HasTag(kNormalizedCropRect)) { + auto default_rect = absl::make_unique(); + default_rect->set_x_center(0.5); + default_rect->set_y_center(0.5); + default_rect->set_width(1.0); + default_rect->set_height(1.0); + cc->Outputs() + .Tag(kNormalizedCropRect) + .Add(default_rect.release(), Timestamp(cc->InputTimestamp())); + } + // Also provide a first crop rect: in this case a zero-sized one. + if (cc->Outputs().HasTag(kFirstCropRect)) { + cc->Outputs() + .Tag(kFirstCropRect) + .Add(new mediapipe::NormalizedRect(), + Timestamp(cc->InputTimestamp())); + } + *has_detections = false; + return absl::OkStatus(); + } + } else { + auto raw_detections = cc->Inputs() + .Tag(kDetections) + .Get>(); + for (const auto& detection : raw_detections) { + *only_required_found = true; + MP_RETURN_IF_ERROR(UpdateRanges( + detection, options_.detection_shift_vertical(), + options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax)); + } + } + } + return absl::OkStatus(); +} + } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto index 6516ed21f..124551304 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto @@ -19,7 +19,7 @@ package mediapipe.autoflip; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/framework/calculator.proto"; -// NextTag: 18 +// NextTag: 19 
message ContentZoomingCalculatorOptions { extend mediapipe.CalculatorOptions { optional ContentZoomingCalculatorOptions ext = 313091992; @@ -71,6 +71,12 @@ message ContentZoomingCalculatorOptions { // us_to_first_rect time budget. optional int64 us_to_first_rect_delay = 16 [default = 0]; + // When true, this flag disables animating camera motions, + // and cuts directly to final target position. + // Does not apply to the first instance (first detection will still animate). + // Use "ANIMATE_ZOOM" input stream to control the first animation. + optional bool disable_animations = 18; + // Deprecated parameters optional KinematicOptions kinematic_options = 2 [deprecated = true]; optional int64 min_motion_to_reframe = 4 [deprecated = true]; diff --git a/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc index 4e3d11cb2..dea12736f 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc @@ -56,7 +56,7 @@ constexpr char kRegionsTag[] = "REGIONS"; constexpr char kDetectionsTag[] = "DETECTIONS"; // Converts an object detection to a autoflip SignalType. Returns true if the -// std::string label has a autoflip label. +// string label has a autoflip label. 
bool MatchType(const std::string& label, SignalType* type) { if (label == "person") { type->set_standard(SignalType::HUMAN); diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc index 885753d63..89170dc6a 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc @@ -182,7 +182,7 @@ namespace { absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string, double* aspect_ratio) { std::string error_msg = - "Aspect ratio std::string must be in the format of 'width:height', e.g. " + "Aspect ratio string must be in the format of 'width:height', e.g. " "'1:1' or '5:4', your input was " + aspect_ratio_string; auto pos = aspect_ratio_string.find(':'); diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc index 9a3f53352..5d3c59aab 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc @@ -4,6 +4,7 @@ constexpr float kMinVelocity = 0.5; namespace mediapipe { namespace autoflip { + namespace { int Median(const std::deque>& positions_raw) { std::deque positions; @@ -16,6 +17,7 @@ int Median(const std::deque>& positions_raw) { return positions[n]; } } // namespace + bool KinematicPathSolver::IsMotionTooSmall(double delta_degs) { if (options_.has_min_motion_to_reframe()) { return abs(delta_degs) < options_.min_motion_to_reframe(); @@ -25,7 +27,9 @@ bool KinematicPathSolver::IsMotionTooSmall(double delta_degs) { return abs(delta_degs) < options_.min_motion_to_reframe_lower(); } } + void KinematicPathSolver::ClearHistory() { raw_positions_at_time_.clear(); } + absl::Status KinematicPathSolver::PredictMotionState(int position, const uint64 
time_us, bool* state) { @@ -48,6 +52,9 @@ absl::Status KinematicPathSolver::PredictMotionState(int position, } int filtered_position = Median(raw_positions_at_time_copy); + filtered_position = + std::clamp(filtered_position, min_location_, max_location_); + double delta_degs = (filtered_position - current_position_px_) / pixels_per_degree_; @@ -59,6 +66,9 @@ absl::Status KinematicPathSolver::PredictMotionState(int position, // If the motion is smaller than the reframe_window and camera is moving, // don't use the update. *state = false; + } else if (prior_position_px_ == current_position_px_ && motion_state_) { + // Camera isn't actually moving. Likely face is past bounds. + *state = false; } else { // Apply new position, plus the reframe window size. *state = true; @@ -66,6 +76,7 @@ absl::Status KinematicPathSolver::PredictMotionState(int position, return absl::OkStatus(); } + absl::Status KinematicPathSolver::AddObservation(int position, const uint64 time_us) { if (!initialized_) { @@ -181,18 +192,22 @@ absl::Status KinematicPathSolver::AddObservation(int position, } // Time and position updates. - double delta_t = (time_us - current_time_) / 1000000.0; + double delta_t_sec = (time_us - current_time_) / 1000000.0; + if (options_.max_delta_time_sec() > 0) { + // If updates are very infrequent, then limit the max time difference. + delta_t_sec = fmin(delta_t_sec, options_.max_delta_time_sec()); + } // Time since last state/prediction update, smoothed by // mean_period_update_rate. if (mean_delta_t_ < 0) { - mean_delta_t_ = delta_t; + mean_delta_t_ = delta_t_sec; } else { mean_delta_t_ = mean_delta_t_ * (1 - options_.mean_period_update_rate()) + - delta_t * options_.mean_period_update_rate(); + delta_t_sec * options_.mean_period_update_rate(); } - // Observed velocity and then weighted update of this velocity. - double observed_velocity = delta_degs / delta_t; + // Observed velocity and then weighted update of this velocity (deg/sec). 
+ double observed_velocity = delta_degs / delta_t_sec; double update_rate = std::min(mean_delta_t_ / options_.update_rate_seconds(), options_.max_update_rate()); double updated_velocity = current_velocity_deg_per_s_ * (1 - update_rate) + @@ -253,7 +268,8 @@ absl::Status KinematicPathSolver::GetDeltaState(float* delta_position) { absl::Status KinematicPathSolver::SetState(const float position) { RET_CHECK(initialized_) << "SetState called before first observation added."; - current_position_px_ = position; + current_position_px_ = std::clamp(position, static_cast(min_location_), + static_cast(max_location_)); return absl::OkStatus(); } diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h index cc9e04cf4..94d19ff80 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h @@ -71,6 +71,8 @@ class KinematicPathSolver { // Provides the change in position from last state. absl::Status GetDeltaState(float* delta_position); + bool IsInitialized() { return initialized_; } + private: // Tuning options. KinematicOptions options_; diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto index 2f97affa1..61b48b620 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto @@ -31,6 +31,9 @@ message KinematicOptions { optional int64 filtering_time_window_us = 7 [default = 0]; // Weighted update of average period, used for motion updates. optional float mean_period_update_rate = 8 [default = 0.25]; + // When set, caps the maximum time difference (seconds) calculated between new + // updates/observations. Useful when updates come very infrequently. 
+ optional double max_delta_time_sec = 13; // Scale factor for max velocity, to be multiplied by the distance from center // in degrees. Cannot be used with max_velocity and must be used with // max_velocity_shift. diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc index 2606cca0a..cb7ec5c94 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc @@ -419,6 +419,13 @@ TEST(KinematicPathSolverTest, PassSetPosition) { MP_ASSERT_OK(solver.SetState(400)); MP_ASSERT_OK(solver.GetState(&state)); EXPECT_FLOAT_EQ(state, 400); + // Expect to stay in bounds. + MP_ASSERT_OK(solver.SetState(600)); + MP_ASSERT_OK(solver.GetState(&state)); + EXPECT_FLOAT_EQ(state, 500); + MP_ASSERT_OK(solver.SetState(-100)); + MP_ASSERT_OK(solver.GetState(&state)); + EXPECT_FLOAT_EQ(state, 0); } TEST(KinematicPathSolverTest, PassBorderTest) { KinematicOptions options; diff --git a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc index dd30566c2..0c5d221b8 100644 --- a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc @@ -83,7 +83,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem( const double in, const double out, Problem* problem, double* a, double* b, double* c, double* d, double* k) { // Creating a cost function, with 1D residual and 5 1D parameter blocks. This - // is what the "1, 1, 1, 1, 1, 1" std::string below means. + // is what the "1, 1, 1, 1, 1, 1" string below means. 
CostFunction* cost_function = new AutoDiffCostFunction( new PolynomialResidual(in, out)); diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h index 8688e16ed..d7f06a021 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h @@ -55,7 +55,8 @@ class SceneCameraMotionAnalyzer { scene_camera_motion_analyzer_options) : options_(scene_camera_motion_analyzer_options), time_since_last_salient_region_us_(0), - has_solid_color_background_(false) {} + has_solid_color_background_(false), + total_scene_frames_(0) {} ~SceneCameraMotionAnalyzer() {} diff --git a/mediapipe/examples/desktop/hello_world/hello_world.cc b/mediapipe/examples/desktop/hello_world/hello_world.cc index d7416e12a..fde821b51 100644 --- a/mediapipe/examples/desktop/hello_world/hello_world.cc +++ b/mediapipe/examples/desktop/hello_world/hello_world.cc @@ -44,7 +44,7 @@ absl::Status PrintHelloWorld() { ASSIGN_OR_RETURN(OutputStreamPoller poller, graph.AddOutputStreamPoller("out")); MP_RETURN_IF_ERROR(graph.StartRun({})); - // Give 10 input packets that contains the same std::string "Hello World!". + // Give 10 input packets that contains the same string "Hello World!". for (int i = 0; i < 10; ++i) { MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( "in", MakePacket("Hello World!").At(Timestamp(i)))); @@ -52,7 +52,7 @@ absl::Status PrintHelloWorld() { // Close the input stream "in". MP_RETURN_IF_ERROR(graph.CloseInputStream("in")); mediapipe::Packet packet; - // Get the output packets std::string. + // Get the output packets string. 
while (poller.Next(&packet)) { LOG(INFO) << packet.Get(); } diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 5dafa93e4..87b59901c 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -72,6 +72,7 @@ objc_library( "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_landmarks.binarypb", "//mediapipe/modules/face_landmark:face_landmark.tflite", ], + features = ["-layering_check"], sdk_frameworks = [ "AVFoundation", "CoreGraphics", diff --git a/mediapipe/examples/ios/holistictrackinggpu/BUILD b/mediapipe/examples/ios/holistictrackinggpu/BUILD index f0753ab66..6d72282ed 100644 --- a/mediapipe/examples/ios/holistictrackinggpu/BUILD +++ b/mediapipe/examples/ios/holistictrackinggpu/BUILD @@ -58,6 +58,7 @@ objc_library( "//mediapipe/modules/face_detection:face_detection_short_range.tflite", "//mediapipe/modules/face_landmark:face_landmark.tflite", "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite", + "//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite", "//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index ef32c6c81..d592f9c9c 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -150,6 +150,13 @@ mediapipe_proto_library( deps = ["//mediapipe/framework:mediapipe_options_proto"], ) +config_setting( + name = "android_no_jni", + define_values = {"MEDIAPIPE_NO_JNI": "1"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + cc_library( name = "calculator_base", srcs = ["calculator_base.cc"], @@ -712,6 +719,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", ], ) @@ -916,15 +924,19 @@ cc_library( 
":packet", ":packet_set", ":type_map", + "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:map_util", - "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "//mediapipe/framework/tool:type_util", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index 20022fb89..73a3e5e5d 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -134,6 +134,7 @@ cc_test( deps = [ ":packet", "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 5408a7add..f238b653c 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -313,8 +313,8 @@ template class Node : public NodeBase { public: Node() : NodeBase(Calc::kCalculatorName) {} - // Overrides the built-in calculator type std::string with the provided - // argument. Can be used to create nodes from pure interfaces. + // Overrides the built-in calculator type string with the provided argument. + // Can be used to create nodes from pure interfaces. 
// TODO: only use this for pure interfaces Node(const std::string& type_override) : NodeBase(type_override) {} @@ -377,6 +377,29 @@ class PacketGenerator { return *options_.MutableExtension(T::ext); } + template + auto operator[](const PortCommon& port) { + using PayloadT = + typename PortCommon::PayloadT; + if constexpr (std::is_same_v) { + auto* base = &out_sides_[port.Tag()]; + if constexpr (kIsMultiple) { + return MultiSideSource(base); + } else { + return SideSource(base); + } + } else if constexpr (std::is_same_v) { + auto* base = &in_sides_[port.Tag()]; + if constexpr (kIsMultiple) { + return MultiSideDestination(base); + } else { + return SideDestination(base); + } + } else { + static_assert(dependent_false::value, "Type not supported."); + } + } + private: std::string type_; TagIndexMap in_sides_; @@ -402,7 +425,7 @@ class Graph { } // Creates a node of a specific type. Should be used for pure interfaces, - // which do not have a built-in type std::string. + // which do not have a built-in type string. template Node& AddNode(const std::string& type) { auto node = std::make_unique>(type); diff --git a/mediapipe/framework/api2/const_str.h b/mediapipe/framework/api2/const_str.h index ff5645e3d..5f9d60e91 100644 --- a/mediapipe/framework/api2/const_str.h +++ b/mediapipe/framework/api2/const_str.h @@ -6,8 +6,8 @@ namespace mediapipe { namespace api2 { -// This class stores a constant std::string that can be inspected at compile -// time in constexpr code. +// This class stores a constant string that can be inspected at compile time +// in constexpr code. 
class const_str { public: constexpr const_str(std::size_t size, const char* data) diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 426d4701d..771cfb83f 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -215,6 +215,7 @@ class Packet : public Packet { return typed_payload->data(); } const T& operator*() const { return Get(); } + const T* operator->() const { return &Get(); } template T GetOr(U&& v) const { diff --git a/mediapipe/framework/api2/packet_test.cc b/mediapipe/framework/api2/packet_test.cc index 6d8fd0015..887ba3c3e 100644 --- a/mediapipe/framework/api2/packet_test.cc +++ b/mediapipe/framework/api2/packet_test.cc @@ -1,5 +1,6 @@ #include "mediapipe/framework/api2/packet.h" +#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -18,6 +19,17 @@ class LiveCheck { bool& alive_; }; +class Base { + public: + virtual ~Base() = default; + virtual absl::string_view name() const { return "Base"; } +}; + +class Derived : public Base { + public: + absl::string_view name() const override { return "Derived"; } +}; + TEST(PacketTest, PacketBaseDefault) { PacketBase p; EXPECT_TRUE(p.IsEmpty()); @@ -242,6 +254,16 @@ TEST(PacketTest, OneOfConsume) { EXPECT_TRUE(p.IsEmpty()); } +TEST(PacketTest, Polymorphism) { + Packet base = PacketAdopting(absl::make_unique()); + EXPECT_EQ(base->name(), "Derived"); + // Since packet contents are implicitly immutable, if you need mutability the + // current recommendation is still to wrap the contents in a unique_ptr. 
+ Packet> mutable_base = + MakePacket>(absl::make_unique()); + EXPECT_EQ((**mutable_base).name(), "Derived"); +} + } // namespace } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index dbd15cc68..fc74ba609 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -172,9 +172,14 @@ inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.SetNone(); } +template +inline void SetTypeOneOf(OneOf, CalculatorContract* cc, PacketType& pt) { + pt.SetOneOf(); +} + template {}, int>::type = 0> inline void SetType(CalculatorContract* cc, PacketType& pt) { - pt.SetAny(); + SetTypeOneOf(T{}, cc, pt); } template @@ -294,14 +299,26 @@ struct SideBase { using type = SideInputBase; }; +// TODO: maybe return a PacketBase instead of a Packet? +template +struct ActualPayloadType { + using type = T; +}; + +template +struct ActualPayloadType< + T, std::enable_if_t{}, void>> { + using type = internal::Generic; +}; + } // namespace internal -// TODO: maybe return a PacketBase instead of a Packet? -template {}, int>::type = 0> -auto ActualValueT(T) -> T; +// Maps special port value types, such as AnyType, to internal::Generic. 
+template +using ActualPayloadT = typename internal::ActualPayloadType::type; -auto ActualValueT(DynamicType) -> internal::Generic; +static_assert(std::is_same_v, int>, ""); +static_assert(std::is_same_v, internal::Generic>, ""); template @@ -325,7 +342,7 @@ class PortCommon : public Base { explicit constexpr PortCommon(const char (&tag)[N]) : Base(N, tag, &get_type_hash, IsOptionalV, IsMultipleV) {} - using PayloadT = decltype(ActualValueT(std::declval())); + using PayloadT = ActualPayloadT; auto operator()(CalculatorContext* cc) const { return internal::AccessPort( @@ -385,7 +402,7 @@ class SideFallbackT : public Base { static constexpr bool kOptional = IsOptionalV; static constexpr bool kMultiple = IsMultipleV; using Optional = SideFallbackT; - using PayloadT = decltype(ActualValueT(std::declval())); + using PayloadT = ActualPayloadT; const char* Tag() const { return stream_port.Tag(); } @@ -499,6 +516,10 @@ class OutputShardAccess : public OutputShardAccessBase { Send(std::move(payload), context_.InputTimestamp()); } + void SetHeader(const PacketBase& header) { + if (output_) output_->SetHeader(ToOldPacket(header)); + } + private: OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output) : OutputShardAccessBase(cc, output) {} diff --git a/mediapipe/framework/api2/port_test.cc b/mediapipe/framework/api2/port_test.cc index 9a07061df..2bbae387d 100644 --- a/mediapipe/framework/api2/port_test.cc +++ b/mediapipe/framework/api2/port_test.cc @@ -21,6 +21,25 @@ TEST(PortTest, Tag) { EXPECT_EQ(std::string(port.Tag()), "FOO"); } +struct DeletedCopyType { + DeletedCopyType(const DeletedCopyType&) = delete; + DeletedCopyType& operator=(const DeletedCopyType&) = delete; +}; + +TEST(PortTest, DeletedCopyConstructorInput) { + static constexpr Input kInputPort{"INPUT"}; + EXPECT_EQ(std::string(kInputPort.Tag()), "INPUT"); + + static constexpr Output kOutputPort{"OUTPUT"}; + EXPECT_EQ(std::string(kOutputPort.Tag()), "OUTPUT"); + + static constexpr SideInput 
kSideInputPort{"SIDE_INPUT"}; + EXPECT_EQ(std::string(kSideInputPort.Tag()), "SIDE_INPUT"); + + static constexpr SideOutput kSideOutputPort{"SIDE_OUTPUT"}; + EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT"); +} + } // namespace } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/tag_test.cc b/mediapipe/framework/api2/tag_test.cc index 26ac632b7..0c0750cb8 100644 --- a/mediapipe/framework/api2/tag_test.cc +++ b/mediapipe/framework/api2/tag_test.cc @@ -26,8 +26,8 @@ TEST(TagTest, String) { EXPECT_EQ(kBAR.str(), "BAR"); } -// Separate invocations of MPP_TAG with the same std::string produce objects of -// the same type. +// Separate invocations of MPP_TAG with the same string produce objects of the +// same type. TEST(TagTest, SameType) { EXPECT_TRUE(same_type(kFOO, kFOO2)); } // Different tags have different types. diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h index 6bf7b163b..72f29bc03 100644 --- a/mediapipe/framework/calculator_contract.h +++ b/mediapipe/framework/calculator_contract.h @@ -95,8 +95,8 @@ class CalculatorContract { input_stream_handler_options_ = options; } - // Returns the name of this Nodes's InputStreamHandler, or empty std::string - // if none is set. + // Returns the name of this Nodes's InputStreamHandler, or empty string if + // none is set. 
std::string GetInputStreamHandler() const { return input_stream_handler_; } // Returns the MediaPipeOptions of this Node's InputStreamHandler, or empty diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 0e6d53b6a..3478375e4 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -54,15 +54,13 @@ #include "mediapipe/framework/scheduler.h" #include "mediapipe/framework/thread_pool_executor.pb.h" -#if !MEDIAPIPE_DISABLE_GPU namespace mediapipe { + +#if !MEDIAPIPE_DISABLE_GPU class GpuResources; struct GpuSharedData; -} // namespace mediapipe #endif // !MEDIAPIPE_DISABLE_GPU -namespace mediapipe { - typedef absl::StatusOr StatusOrPoller; // The class representing a DAG of calculator nodes. @@ -366,10 +364,9 @@ class CalculatorGraph { #if !MEDIAPIPE_DISABLE_GPU // Returns a pointer to the GpuResources in use, if any. // Only meant for internal use. - std::shared_ptr<::mediapipe::GpuResources> GetGpuResources() const; + std::shared_ptr GetGpuResources() const; - absl::Status SetGpuResources( - std::shared_ptr<::mediapipe::GpuResources> resources); + absl::Status SetGpuResources(std::shared_ptr resources); // Helper for PrepareForRun. If it returns a non-empty map, those packets // must be added to the existing side packets, replacing existing values @@ -532,7 +529,7 @@ class CalculatorGraph { #if !MEDIAPIPE_DISABLE_GPU // Owns the legacy GpuSharedData if we need to create one for backwards // compatibility. - std::unique_ptr<::mediapipe::GpuSharedData> legacy_gpu_shared_; + std::unique_ptr legacy_gpu_shared_; #endif // !MEDIAPIPE_DISABLE_GPU // True if the graph was initialized. @@ -598,7 +595,7 @@ class CalculatorGraph { std::unique_ptr counter_factory_; // Executors for the scheduler, keyed by the executor's name. The default - // executor's name is the empty std::string. + // executor's name is the empty string. 
std::map> executors_; // The processed input side packet map for this run. diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 26d2f484c..af3655c22 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -768,7 +768,7 @@ typedef TypedStatusHandler Uint32StatusHandler; REGISTER_STATUS_HANDLER(StringStatusHandler); REGISTER_STATUS_HANDLER(Uint32StatusHandler); -// A std::string generator that will succeed. +// A string generator that will succeed. class StaticCounterStringGenerator : public PacketGenerator { public: static absl::Status FillExpectations( @@ -1767,15 +1767,14 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { EXPECT_FALSE(graph->Run({{"a_uint64", a_uint64}}).ok()); // Should fail verification when the type of an already created packet is - // wrong. Here we give the uint64 packet instead of the std::string packet to - // the StringStatusHandler. + // wrong. Here we give the uint64 packet instead of the string packet to the + // StringStatusHandler. EXPECT_FALSE( graph->Run({{"extra_string", a_uint64}, {"a_uint64", a_uint64}}).ok()); // Should fail verification when the type of a packet generated by a base // packet factory is wrong. Everything is correct except we add a status - // handler expecting a uint32 but give it the std::string from the packet - // factory. + // handler expecting a uint32 but give it the string from the packet factory. auto* invalid_handler = config.add_status_handler(); invalid_handler->set_status_handler("Uint32StatusHandler"); invalid_handler->add_input_side_packet("created_by_factory"); @@ -1792,8 +1791,8 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { MediaPipeTypeStringOrDemangled()))); // Should fail verification when the type of a to-be-generated packet is - // wrong. The added handler now expects a std::string but will receive the - // uint32 generated by the existing generator. 
+ // wrong. The added handler now expects a string but will receive the uint32 + // generated by the existing generator. invalid_handler->set_status_handler("StringStatusHandler"); invalid_handler->set_input_side_packet(0, "generated_by_generator"); graph.reset(new CalculatorGraph()); diff --git a/mediapipe/framework/calculator_node.h b/mediapipe/framework/calculator_node.h index 368e0a557..8ecf72cfc 100644 --- a/mediapipe/framework/calculator_node.h +++ b/mediapipe/framework/calculator_node.h @@ -79,10 +79,9 @@ class CalculatorNode { // running first. If a node is not a source, this method is not called. Timestamp SourceProcessOrder(const CalculatorContext* cc) const; - // Retrieves a std::string name for the node. If the node's name was set in - // the calculator graph config, it will be returned. Otherwise, a - // human-readable std::string that uniquely identifies the node is returned, - // e.g. + // Retrieves a string name for the node. If the node's name was set in the + // calculator graph config, it will be returned. Otherwise, a human-readable + // string that uniquely identifies the node is returned, e.g. // "[FooBarCalculator with first output stream \"foo_bar_output\"]" for // non-sink nodes and "[FooBarCalculator with node ID: 42 and input streams: // \"foo_bar_input\"]" for sink nodes. This name should be used in error @@ -278,7 +277,7 @@ class CalculatorNode { void CloseInputStreams() ABSL_LOCKS_EXCLUDED(status_mutex_); void CloseOutputStreams(OutputStreamShardSet* outputs) ABSL_LOCKS_EXCLUDED(status_mutex_); - // Get a std::string describing the input streams. + // Get a string describing the input streams. std::string DebugInputStreamNames() const; // Returns true if all outputs will be identical to the previous graph run. 
diff --git a/mediapipe/framework/calculator_runner.h b/mediapipe/framework/calculator_runner.h index b680c3604..fb1020de1 100644 --- a/mediapipe/framework/calculator_runner.h +++ b/mediapipe/framework/calculator_runner.h @@ -62,7 +62,7 @@ class CalculatorRunner { // )"); explicit CalculatorRunner(const CalculatorGraphConfig::Node& node_config); #if !defined(MEDIAPIPE_PROTO_LITE) - // Convenience constructor which takes a node_config std::string directly. + // Convenience constructor which takes a node_config string directly. explicit CalculatorRunner(const std::string& node_config_string); // Convenience constructor to initialize a calculator which uses indexes // (not tags) for all its fields. diff --git a/mediapipe/framework/deps/file_path.h b/mediapipe/framework/deps/file_path.h index 4c1c15153..40cf223a3 100644 --- a/mediapipe/framework/deps/file_path.h +++ b/mediapipe/framework/deps/file_path.h @@ -51,7 +51,8 @@ std::string JoinPathImpl(bool honor_abs, // // Usage: // std::string path = file::JoinPath("/cns", dirname, filename); -// std::string path = file::JoinPath("./", filename); +// std::string path = file::JoinPath("./", +// filename); // // 0, 1, 2-path specializations exist to optimize common cases. inline std::string JoinPath() { return std::string(); } @@ -69,7 +70,7 @@ inline std::string JoinPath(absl::string_view path1, absl::string_view path2, // * If there is a single leading "/" in the path, the result will be the // leading "/". // * If there is no "/" in the path, the result is the empty prefix of the -// input std::string. +// input string. absl::string_view Dirname(absl::string_view path); // Return the parts of the path, split on the final "/". If there is no @@ -83,7 +84,7 @@ std::pair SplitPath( // "/" in the path, the result is the same as the input. // Note that this function's behavior differs from the Unix basename // command if path ends with "/". For such paths, this function returns the -// empty std::string. +// empty string. 
absl::string_view Basename(absl::string_view path); // Returns the part of the basename of path after the final ".". If diff --git a/mediapipe/framework/deps/numbers.h b/mediapipe/framework/deps/numbers.h index b19055582..e199000ee 100644 --- a/mediapipe/framework/deps/numbers.h +++ b/mediapipe/framework/deps/numbers.h @@ -15,6 +15,8 @@ #ifndef MEDIAPIPE_DEPS_NUMBERS_H_ #define MEDIAPIPE_DEPS_NUMBERS_H_ +#include + #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/port/integral_types.h" diff --git a/mediapipe/framework/deps/rectangle.h b/mediapipe/framework/deps/rectangle.h index 9ca9d7ad1..0f8f4766f 100644 --- a/mediapipe/framework/deps/rectangle.h +++ b/mediapipe/framework/deps/rectangle.h @@ -145,7 +145,7 @@ class Rectangle { void AddBorder(const T& border_size); // Debug printing. - friend std::ostream& operator<<(std::ostream&, const Rectangle&); + friend std::ostream& operator<< (std::ostream&, const Rectangle&); private: Point2 min_; diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index 07c4a44a4..fbfe2caef 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -370,7 +370,7 @@ class GlobalFactoryRegistry { GlobalFactoryRegistry() = delete; }; -// Two levels of macros are required to convert __LINE__ into a std::string +// Two levels of macros are required to convert __LINE__ into a string // containing the line number. #define REGISTRY_STATIC_VAR_INNER(var_name, line) var_name##_##line##__ #define REGISTRY_STATIC_VAR(var_name, line) \ diff --git a/mediapipe/framework/deps/singleton.h b/mediapipe/framework/deps/singleton.h index 86dcd78df..9599b24ae 100644 --- a/mediapipe/framework/deps/singleton.h +++ b/mediapipe/framework/deps/singleton.h @@ -25,7 +25,7 @@ class Singleton { public: // Returns the pointer to the singleton of type |T|. // This method is thread-safe. 
- static T *get() LOCKS_EXCLUDED(mu_) { + static T *get() ABSL_LOCKS_EXCLUDED(mu_) { absl::MutexLock lock(&mu_); if (instance_) { return instance_; @@ -46,7 +46,7 @@ class Singleton { // cannot be recreated. However, the callers of this method responsible for // making sure that no other threads are accessing (or plan to access) the // singleton any longer. - static void Destruct() LOCKS_EXCLUDED(mu_) { + static void Destruct() ABSL_LOCKS_EXCLUDED(mu_) { absl::MutexLock lock(&mu_); T *tmp_ptr = instance_; instance_ = nullptr; @@ -55,8 +55,8 @@ class Singleton { } private: - static T *instance_ GUARDED_BY(mu_); - static bool destroyed_ GUARDED_BY(mu_); + static T *instance_ ABSL_GUARDED_BY(mu_); + static bool destroyed_ ABSL_GUARDED_BY(mu_); static absl::Mutex mu_; }; diff --git a/mediapipe/framework/deps/source_location.h b/mediapipe/framework/deps/source_location.h index 59218b236..e436547a6 100644 --- a/mediapipe/framework/deps/source_location.h +++ b/mediapipe/framework/deps/source_location.h @@ -47,7 +47,7 @@ class source_location { // MEDIAPIPE_LOC macro below. // // file_name must outlive all copies of the source_location - // object, so in practice it should be a std::string literal. + // object, so in practice it should be a string literal. 
constexpr source_location(std::uint_least32_t line, const char* file_name) : line_(line), file_name_(file_name) {} diff --git a/mediapipe/framework/deps/status.cc b/mediapipe/framework/deps/status.cc index c6e7b68b5..61e21935c 100644 --- a/mediapipe/framework/deps/status.cc +++ b/mediapipe/framework/deps/status.cc @@ -29,7 +29,7 @@ std::string* MediaPipeCheckOpHelperOutOfLine(const absl::Status& v, r += msg; r += " status: "; r += v.ToString(); - // Leaks std::string but this is only to be used in a fatal error message + // Leaks string but this is only to be used in a fatal error message return new std::string(r); } diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index de46a5d8e..8358ea01a 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -15,6 +15,7 @@ #include "mediapipe/framework/deps/status_builder.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" namespace mediapipe { @@ -23,7 +24,9 @@ StatusBuilder::StatusBuilder(const StatusBuilder& sb) { file_ = sb.file_; line_ = sb.line_; no_logging_ = sb.no_logging_; - stream_ = absl::make_unique(sb.stream_->str()); + stream_ = sb.stream_ + ? absl::make_unique(sb.stream_->str()) + : nullptr; join_style_ = sb.join_style_; } @@ -32,43 +35,58 @@ StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { file_ = sb.file_; line_ = sb.line_; no_logging_ = sb.no_logging_; - stream_ = absl::make_unique(sb.stream_->str()); + stream_ = sb.stream_ + ? 
absl::make_unique(sb.stream_->str()) + : nullptr; join_style_ = sb.join_style_; return *this; } -StatusBuilder& StatusBuilder::SetAppend() { +StatusBuilder& StatusBuilder::SetAppend() & { if (status_.ok()) return *this; join_style_ = MessageJoinStyle::kAppend; return *this; } -StatusBuilder& StatusBuilder::SetPrepend() { +StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); } + +StatusBuilder& StatusBuilder::SetPrepend() & { if (status_.ok()) return *this; join_style_ = MessageJoinStyle::kPrepend; return *this; } -StatusBuilder& StatusBuilder::SetNoLogging() { +StatusBuilder&& StatusBuilder::SetPrepend() && { + return std::move(SetPrepend()); +} + +StatusBuilder& StatusBuilder::SetNoLogging() & { no_logging_ = true; return *this; } +StatusBuilder&& StatusBuilder::SetNoLogging() && { + return std::move(SetNoLogging()); +} + StatusBuilder::operator Status() const& { - if (stream_->str().empty() || no_logging_) { + if (!stream_ || stream_->str().empty() || no_logging_) { return status_; } return StatusBuilder(*this).JoinMessageToStatus(); } StatusBuilder::operator Status() && { - if (stream_->str().empty() || no_logging_) { + if (!stream_ || stream_->str().empty() || no_logging_) { return status_; } return JoinMessageToStatus(); } absl::Status StatusBuilder::JoinMessageToStatus() { + if (!stream_) { + return absl::OkStatus(); + } std::string message; if (join_style_ == MessageJoinStyle::kAnnotate) { if (!status_.ok()) { diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index dad49f11e..c9111c603 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -15,7 +15,13 @@ #ifndef MEDIAPIPE_DEPS_STATUS_BUILDER_H_ #define MEDIAPIPE_DEPS_STATUS_BUILDER_H_ +#include +#include +#include + #include "absl/base/attributes.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include 
"absl/strings/string_view.h" #include "mediapipe/framework/deps/source_location.h" @@ -27,6 +33,10 @@ class ABSL_MUST_USE_RESULT StatusBuilder { public: StatusBuilder(const StatusBuilder& sb); StatusBuilder& operator=(const StatusBuilder& sb); + + StatusBuilder(StatusBuilder&&) = default; + StatusBuilder& operator=(StatusBuilder&&) = default; + // Creates a `StatusBuilder` based on an original status. If logging is // enabled, it will use `location` as the location from which the log message // occurs. A typical user will call this with `MEDIAPIPE_LOC`. @@ -35,14 +45,14 @@ class ABSL_MUST_USE_RESULT StatusBuilder { : status_(original_status), line_(location.line()), file_(location.file_name()), - stream_(new std::ostringstream) {} + stream_(InitStream(status_)) {} StatusBuilder(absl::Status&& original_status, mediapipe::source_location location) : status_(std::move(original_status)), line_(location.line()), file_(location.file_name()), - stream_(new std::ostringstream) {} + stream_(InitStream(status_)) {} // Creates a `StatusBuilder` from a mediapipe status code. 
If logging is // enabled, it will use `location` as the location from which the log message @@ -51,29 +61,37 @@ class ABSL_MUST_USE_RESULT StatusBuilder { : status_(code, ""), line_(location.line()), file_(location.file_name()), - stream_(new std::ostringstream) {} + stream_(InitStream(status_)) {} StatusBuilder(const absl::Status& original_status, const char* file, int line) : status_(original_status), line_(line), file_(file), - stream_(new std::ostringstream) {} + stream_(InitStream(status_)) {} bool ok() const { return status_.ok(); } - StatusBuilder& SetAppend(); + StatusBuilder& SetAppend() &; + StatusBuilder&& SetAppend() &&; - StatusBuilder& SetPrepend(); + StatusBuilder& SetPrepend() &; + StatusBuilder&& SetPrepend() &&; - StatusBuilder& SetNoLogging(); + StatusBuilder& SetNoLogging() &; + StatusBuilder&& SetNoLogging() &&; template - StatusBuilder& operator<<(const T& msg) { - if (status_.ok()) return *this; + StatusBuilder& operator<<(const T& msg) & { + if (!stream_) return *this; *stream_ << msg; return *this; } + template + StatusBuilder&& operator<<(const T& msg) && { + return std::move(*this << msg); + } + operator Status() const&; operator Status() &&; @@ -88,6 +106,15 @@ class ABSL_MUST_USE_RESULT StatusBuilder { kPrepend, }; + // Conditionally creates an ostringstream if the status is not ok. + static std::unique_ptr InitStream( + const absl::Status status) { + if (status.ok()) { + return nullptr; + } + return absl::make_unique(); + } + // The status that the result will be based on. absl::Status status_; // The line to record if this file is logged. @@ -95,7 +122,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // Not-owned: The file to record if this status is logged. const char* file_; bool no_logging_ = false; - // The additional messages added with `<<`. + // The additional messages added with `<<`. This is nullptr when status_ is + // ok. std::unique_ptr stream_; // Specifies how to join the message in `status_` and `stream_`. 
MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate; diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc index 63166a106..f517bb909 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -18,6 +18,21 @@ namespace mediapipe { +TEST(StatusBuilder, OkStatusLvalue) { + StatusBuilder builder(absl::OkStatus(), MEDIAPIPE_LOC); + builder << "annotated message1 " + << "annotated message2"; + absl::Status status = builder; + ASSERT_EQ(status, absl::OkStatus()); +} + +TEST(StatusBuilder, OkStatusRvalue) { + absl::Status status = StatusBuilder(absl::OkStatus(), MEDIAPIPE_LOC) + << "annotated message1 " + << "annotated message2"; + ASSERT_EQ(status, absl::OkStatus()); +} + TEST(StatusBuilder, AnnotateMode) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, "original message"), @@ -30,7 +45,12 @@ TEST(StatusBuilder, AnnotateMode) { "original message; annotated message1 annotated message2"); } -TEST(StatusBuilder, PrependMode) { +TEST(StatusBuilder, PrependModeLvalue) { + StatusBuilder builder( + absl::Status(absl::StatusCode::kInvalidArgument, "original message"), + MEDIAPIPE_LOC); + builder.SetPrepend() << "prepended message1 " + << "prepended message2 "; absl::Status status = StatusBuilder( absl::Status(absl::StatusCode::kInvalidArgument, "original message"), @@ -44,7 +64,33 @@ TEST(StatusBuilder, PrependMode) { "prepended message1 prepended message2 original message"); } -TEST(StatusBuilder, AppendMode) { +TEST(StatusBuilder, PrependModeRvalue) { + absl::Status status = + StatusBuilder( + absl::Status(absl::StatusCode::kInvalidArgument, "original message"), + MEDIAPIPE_LOC) + .SetPrepend() + << "prepended message1 " + << "prepended message2 "; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(status.message(), + "prepended message1 prepended message2 original 
message"); +} + +TEST(StatusBuilder, AppendModeLvalue) { + StatusBuilder builder( + absl::Status(absl::StatusCode::kInternal, "original message"), + MEDIAPIPE_LOC); + builder.SetAppend() << " extra message1" + << " extra message2"; + absl::Status status = builder; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_EQ(status.message(), "original message extra message1 extra message2"); +} + +TEST(StatusBuilder, AppendModeRvalue) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kInternal, "original message"), MEDIAPIPE_LOC) @@ -56,7 +102,18 @@ TEST(StatusBuilder, AppendMode) { EXPECT_EQ(status.message(), "original message extra message1 extra message2"); } -TEST(StatusBuilder, NoLoggingMode) { +TEST(StatusBuilder, NoLoggingModeLvalue) { + StatusBuilder builder( + absl::Status(absl::StatusCode::kUnavailable, "original message"), + MEDIAPIPE_LOC); + builder.SetNoLogging() << " extra message"; + absl::Status status = builder; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); + EXPECT_EQ(status.message(), "original message"); +} + +TEST(StatusBuilder, NoLoggingModeRvalue) { absl::Status status = StatusBuilder( absl::Status(absl::StatusCode::kUnavailable, "original message"), diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index d31c81c2d..757d99392 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -150,21 +150,28 @@ #define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args -#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ - STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_)) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ + STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ + return mediapipe::StatusBuilder( \ + 
std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ + .status(), \ + __FILE__, __LINE__)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ - error_expression) -#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ - error_expression) \ - auto statusor = (rexpr); \ - if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ - mediapipe::StatusBuilder _(std::move(statusor).status(), __FILE__, \ - __LINE__); \ - (void)_; /* error_expression is allowed to not use this variable */ \ - return (error_expression); \ - } \ + mediapipe::StatusBuilder _( \ + std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ + .status(), \ + __FILE__, __LINE__); \ + (void)_; /* error_expression is allowed to not use this variable */ \ + return (error_expression)) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + error_expression; \ + } \ lhs = std::move(statusor).value() // Internal helper for concatenating macro values. 
diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f79e6aa43..9deff6542 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -332,15 +332,13 @@ cc_library( "//mediapipe/framework:port", "//mediapipe/framework:type_map", "//mediapipe/framework/port:logging", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_buffer_format", ] + select({ "//conditions:default": [ - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/gpu:gpu_buffer_format", "//mediapipe/gpu:gl_texture_buffer", ], "//mediapipe:ios": [ - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/gpu:gpu_buffer_format", ], "//mediapipe/gpu:disable_gpu": [], }) + select({ @@ -430,7 +428,10 @@ cc_test( cc_library( name = "tensor", - srcs = ["tensor.cc"], + srcs = + [ + "tensor.cc", + ], hdrs = ["tensor.h"], copts = select({ "//mediapipe:apple": [ diff --git a/mediapipe/framework/formats/image.cc b/mediapipe/framework/formats/image.cc index 2b5900daa..65345f98d 100644 --- a/mediapipe/framework/formats/image.cc +++ b/mediapipe/framework/formats/image.cc @@ -16,48 +16,15 @@ #include "mediapipe/framework/type_map.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_texture_view.h" +#endif // !MEDIAPIPE_DISABLE_GPU + namespace mediapipe { // TODO Refactor common code from GpuBufferToImageFrameCalculator bool Image::ConvertToCpu() const { - if (!use_gpu_) return true; // Already on CPU. 
-#if !MEDIAPIPE_DISABLE_GPU -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - image_frame_ = CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef()); -#else - auto gl_texture = gpu_buffer_.GetGlTextureBufferSharedPtr(); - if (!gl_texture->GetProducerContext()) return false; - gl_texture->GetProducerContext()->Run([this, &gl_texture]() { - gl_texture->WaitOnGpu(); - const auto gpu_buf = mediapipe::GpuBuffer(GetGlTextureBufferSharedPtr()); -#ifdef __ANDROID__ - glBindFramebuffer(GL_FRAMEBUFFER, 0); // b/32091368 -#endif - GLuint fb = 0; - glDisable(GL_DEPTH_TEST); - // TODO Re-use a shared framebuffer. - glGenFramebuffers(1, &fb); - glBindFramebuffer(GL_FRAMEBUFFER, fb); - glViewport(0, 0, gpu_buf.width(), gpu_buf.height()); - glActiveTexture(GL_TEXTURE0); - glBindTexture(gl_texture->target(), gl_texture->name()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - gl_texture->target(), gl_texture->name(), 0); - auto frame = std::make_shared( - mediapipe::ImageFormatForGpuBufferFormat(gpu_buf.format()), - gpu_buf.width(), gpu_buf.height(), - ImageFrame::kGlDefaultAlignmentBoundary); - const auto info = GlTextureInfoForGpuBufferFormat( - gpu_buf.format(), 0, gl_texture->GetProducerContext()->GetGlVersion()); - glReadPixels(0, 0, gpu_buf.width(), gpu_buf.height(), info.gl_format, - info.gl_type, frame->MutablePixelData()); - glDeleteFramebuffers(1, &fb); - // Cleanup - gl_texture->DidRead(gl_texture->GetProducerContext()->CreateSyncToken()); - image_frame_ = frame; - }); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -#endif // !MEDIAPIPE_DISABLE_GPU + auto view = gpu_buffer_.GetReadView(); use_gpu_ = false; return true; } @@ -67,19 +34,7 @@ bool Image::ConvertToGpu() const { #if MEDIAPIPE_DISABLE_GPU return false; #else - if (use_gpu_) return true; // Already on GPU. 
-#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - auto packet = PointToForeign(image_frame_.get()); - CFHolder buffer; - auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer); - CHECK_OK(status); - gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer)); -#else - // GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture) - auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_); - glFlush(); - gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer)); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto view = gpu_buffer_.GetReadView(0); use_gpu_ = true; return true; #endif // MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h index 441a79122..44578a7d9 100644 --- a/mediapipe/framework/formats/image.h +++ b/mediapipe/framework/formats/image.h @@ -21,20 +21,18 @@ #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/logging.h" - -#if !MEDIAPIPE_DISABLE_GPU - #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" +#include "mediapipe/gpu/image_frame_view.h" + +#if !MEDIAPIPE_DISABLE_GPU #if defined(__APPLE__) #include #include "mediapipe/objc/CFHolder.h" #include "mediapipe/objc/util.h" -#if !TARGET_OS_OSX // iOS, use CVPixelBuffer. -#define MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER 1 -#endif // TARGET_OS_OSX #endif // defined(__APPLE__) #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // OSX, use GL textures. @@ -73,15 +71,15 @@ class Image { // Creates an Image representing the same image content as the ImageFrame // the input shared pointer points to, and retaining shared ownership. 
explicit Image(ImageFrameSharedPtr image_frame) - : image_frame_(std::move(image_frame)) { + : gpu_buffer_(std::make_shared( + std::move(image_frame))) { use_gpu_ = false; - pixel_mutex_ = std::make_shared(); } // CPU getters. - const ImageFrameSharedPtr& GetImageFrameSharedPtr() const { - if (use_gpu_ == true) ConvertToCpu(); - return image_frame_; + ImageFrameSharedPtr GetImageFrameSharedPtr() const { + // Write view currently because the return type does not point to const IF. + return gpu_buffer_.GetWriteView(); } // Creates an Image representing the same image content as the input GPU @@ -99,19 +97,18 @@ class Image { explicit Image(mediapipe::GpuBuffer gpu_buffer) { use_gpu_ = true; gpu_buffer_ = gpu_buffer; - pixel_mutex_ = std::make_shared(); } // GPU getters. #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CVPixelBufferRef GetCVPixelBufferRef() const { if (use_gpu_ == false) ConvertToGpu(); - return gpu_buffer_.GetCVPixelBufferRef(); + return mediapipe::GetCVPixelBufferRef(gpu_buffer_); } #else mediapipe::GlTextureBufferSharedPtr GetGlTextureBufferSharedPtr() const { if (use_gpu_ == false) ConvertToGpu(); - return gpu_buffer_.GetGlTextureBufferSharedPtr(); + return gpu_buffer_.internal_storage(); } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // Get a GPU view. Automatically uploads from CPU if needed. @@ -128,9 +125,7 @@ class Image { int step() const; // Row size in bytes. bool UsesGpu() const { return use_gpu_; } ImageFormat::Format image_format() const; -#if !MEDIAPIPE_DISABLE_GPU mediapipe::GpuBufferFormat format() const; -#endif // !MEDIAPIPE_DISABLE_GPU // Converts to true iff valid. explicit operator bool() const { return operator!=(nullptr); } @@ -147,8 +142,8 @@ class Image { // Lock/Unlock pixel data. // Should be used exclusively by the PixelLock helper class. 
- void LockPixels() const ABSL_EXCLUSIVE_LOCK_FUNCTION(pixel_mutex_); - void UnlockPixels() const ABSL_UNLOCK_FUNCTION(pixel_mutex_); + void LockPixels() const ABSL_EXCLUSIVE_LOCK_FUNCTION(); + void UnlockPixels() const ABSL_UNLOCK_FUNCTION(); // Helper utility for GPU->CPU data transfer. bool ConvertToCpu() const; @@ -157,75 +152,32 @@ class Image { bool ConvertToGpu() const; private: -#if !MEDIAPIPE_DISABLE_GPU mutable mediapipe::GpuBuffer gpu_buffer_; -#endif // !MEDIAPIPE_DISABLE_GPU - mutable ImageFrameSharedPtr image_frame_; mutable bool use_gpu_ = false; - mutable std::shared_ptr pixel_mutex_; // ImageFrame only. }; -inline int Image::width() const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_) - return gpu_buffer_.width(); - else -#endif // !MEDIAPIPE_DISABLE_GPU - return image_frame_->Width(); -} +inline int Image::width() const { return gpu_buffer_.width(); } -inline int Image::height() const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_) - return gpu_buffer_.height(); - else -#endif // !MEDIAPIPE_DISABLE_GPU - return image_frame_->Height(); -} +inline int Image::height() const { return gpu_buffer_.height(); } inline ImageFormat::Format Image::image_format() const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_) - return mediapipe::ImageFormatForGpuBufferFormat(gpu_buffer_.format()); - else -#endif // !MEDIAPIPE_DISABLE_GPU - return image_frame_->Format(); + return mediapipe::ImageFormatForGpuBufferFormat(gpu_buffer_.format()); } -#if !MEDIAPIPE_DISABLE_GPU inline mediapipe::GpuBufferFormat Image::format() const { - if (use_gpu_) - return gpu_buffer_.format(); - else - return mediapipe::GpuBufferFormatForImageFormat(image_frame_->Format()); + return gpu_buffer_.format(); } -#endif // !MEDIAPIPE_DISABLE_GPU inline bool Image::operator==(std::nullptr_t other) const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_) - return gpu_buffer_ == other; - else -#endif // !MEDIAPIPE_DISABLE_GPU - return image_frame_ == other; + return gpu_buffer_ == other; } inline bool 
Image::operator==(const Image& other) const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_) - return gpu_buffer_ == other.gpu_buffer_; - else -#endif // !MEDIAPIPE_DISABLE_GPU - return image_frame_ == other.image_frame_; + return gpu_buffer_ == other.gpu_buffer_; } inline Image& Image::operator=(std::nullptr_t other) { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_) - gpu_buffer_ = other; - else -#endif // !MEDIAPIPE_DISABLE_GPU - image_frame_ = other; + gpu_buffer_ = other; return *this; } @@ -234,19 +186,14 @@ inline int Image::channels() const { } inline int Image::step() const { - if (use_gpu_) - return width() * channels() * - ImageFrame::ByteDepthForFormat(image_format()); - else - return image_frame_->WidthStep(); + return gpu_buffer_.GetReadView()->WidthStep(); } inline void Image::LockPixels() const { - pixel_mutex_->Lock(); ConvertToCpu(); // Download data if necessary. } -inline void Image::UnlockPixels() const { pixel_mutex_->Unlock(); } +inline void Image::UnlockPixels() const {} // Helper class for getting access to Image CPU data, // and handles automatically locking/unlocking CPU data access. 
@@ -268,7 +215,10 @@ class PixelReadLock { public: explicit PixelReadLock(const Image& image) { buffer_ = &image; - if (buffer_) buffer_->LockPixels(); + if (buffer_) { + buffer_->LockPixels(); + frame_ = buffer_->GetImageFrameSharedPtr(); + } } ~PixelReadLock() { if (buffer_) buffer_->UnlockPixels(); @@ -276,10 +226,7 @@ PixelReadLock(const PixelReadLock&) = delete; const uint8* Pixels() const { - if (buffer_ && !buffer_->UsesGpu()) { - ImageFrame* frame = buffer_->GetImageFrameSharedPtr().get(); - if (frame) return frame->PixelData(); - } + if (frame_) return frame_->PixelData(); return nullptr; } @@ -287,13 +234,17 @@ private: const Image* buffer_ = nullptr; + std::shared_ptr<ImageFrame> frame_; }; class PixelWriteLock { public: explicit PixelWriteLock(Image* image) { buffer_ = image; - if (buffer_) buffer_->LockPixels(); + if (buffer_) { + buffer_->LockPixels(); + frame_ = buffer_->GetImageFrameSharedPtr(); + } } ~PixelWriteLock() { if (buffer_) buffer_->UnlockPixels(); @@ -301,10 +252,7 @@ PixelWriteLock(const PixelWriteLock&) = delete; uint8* Pixels() { - if (buffer_ && !buffer_->UsesGpu()) { - ImageFrame* frame = buffer_->GetImageFrameSharedPtr().get(); - if (frame) return frame->MutablePixelData(); - } + if (frame_) return frame_->MutablePixelData(); return nullptr; } @@ -312,6 +260,7 @@ private: const Image* buffer_ = nullptr; + std::shared_ptr<ImageFrame> frame_; }; } // namespace mediapipe diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc index 3c1ec548f..3debbe421 100644 --- a/mediapipe/framework/formats/image_opencv.cc +++ b/mediapipe/framework/formats/image_opencv.cc @@ -77,7 +77,16 @@ int GetMatType(const mediapipe::ImageFormat::Format format) { namespace mediapipe { namespace formats { -cv::Mat MatView(const mediapipe::Image* image) { +std::shared_ptr<cv::Mat> MatView(const mediapipe::Image* image) { + // Used to hold the lock through the Mat's 
lifetime. + struct MatWithPixelLock { + // Constructor needed because you cannot use aggregate initialization with + // std::make_shared. + MatWithPixelLock(mediapipe::Image* image) : lock(image) {} + mediapipe::PixelWriteLock lock; + cv::Mat mat; + }; + const int dims = 2; const int sizes[] = {image->height(), image->width()}; const int type = @@ -85,18 +94,22 @@ cv::Mat MatView(const mediapipe::Image* image) { const size_t steps[] = {static_cast(image->step()), static_cast(ImageFrame::ByteDepthForFormat( image->image_format()))}; - mediapipe::PixelWriteLock dst_lock(const_cast(image)); - uint8* data_ptr = dst_lock.Pixels(); + auto owner = + std::make_shared(const_cast(image)); + uint8* data_ptr = owner->lock.Pixels(); CHECK(data_ptr != nullptr); // Use Image to initialize in-place. Image still owns memory. if (steps[0] == sizes[1] * image->channels() * ImageFrame::ByteDepthForFormat(image->image_format())) { // Contiguous memory optimization. See b/78570764 - return cv::Mat(dims, sizes, type, data_ptr); + owner->mat = cv::Mat(dims, sizes, type, data_ptr); } else { // Custom width step. - return cv::Mat(dims, sizes, type, data_ptr, steps); + owner->mat = cv::Mat(dims, sizes, type, data_ptr, steps); } + // Aliasing constructor makes a shared_ptr which keeps the whole + // MatWithPixelLock alive. + return std::shared_ptr(owner, &owner->mat); } } // namespace formats } // namespace mediapipe diff --git a/mediapipe/framework/formats/image_opencv.h b/mediapipe/framework/formats/image_opencv.h index 48824a4dd..b1bc4954d 100644 --- a/mediapipe/framework/formats/image_opencv.h +++ b/mediapipe/framework/formats/image_opencv.h @@ -29,7 +29,9 @@ namespace formats { // the const modifier is lost. The caller must be careful // not to use the returned object to modify the data in a const Image, // even though the returned data is mutable. 
-cv::Mat MatView(const mediapipe::Image* image); +// Note: this returns a shared_ptr so it can keep the CPU memory referenced +// by the Mat alive. +std::shared_ptr<cv::Mat> MatView(const mediapipe::Image* image); } // namespace formats } // namespace mediapipe diff --git a/mediapipe/framework/formats/matrix.h b/mediapipe/framework/formats/matrix.h index 5f7c76a83..99729aedc 100644 --- a/mediapipe/framework/formats/matrix.h +++ b/mediapipe/framework/formats/matrix.h @@ -39,7 +39,7 @@ void MatrixDataProtoFromMatrix(const Matrix& matrix, MatrixData* matrix_data); void MatrixFromMatrixDataProto(const MatrixData& matrix_data, Matrix* matrix); #if !defined(MEDIAPIPE_MOBILE) && !defined(MEDIAPIPE_LITE) -// Produce a Text format MatrixData std::string. Mainly useful for test code. +// Produce a Text format MatrixData string. Mainly useful for test code. std::string MatrixAsTextProto(const Matrix& matrix); // Produce a Matrix from a text format MatrixData proto representation. void MatrixFromTextProto(const std::string& text_proto, Matrix* matrix); diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 9277dc7a2..453f2c659 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -76,7 +76,7 @@ class Tensor { public: // No resources are allocated here.
- enum class ElementType { kNone, kFloat16, kFloat32 }; + enum class ElementType { kNone, kFloat16, kFloat32, kUInt8 }; struct Shape { Shape() = default; Shape(std::initializer_list dimensions) : dims(dimensions) {} @@ -215,6 +215,8 @@ class Tensor { return 2; case ElementType::kFloat32: return sizeof(float); + case ElementType::kUInt8: + return 1; } } int bytes() const { return shape_.num_elements() * element_size(); } @@ -278,6 +280,11 @@ class Tensor { #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 }; +int BhwcBatchFromShape(const Tensor::Shape& shape); +int BhwcHeightFromShape(const Tensor::Shape& shape); +int BhwcWidthFromShape(const Tensor::Shape& shape); +int BhwcDepthFromShape(const Tensor::Shape& shape); + } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_H_ diff --git a/mediapipe/framework/formats/tensor_internal.h b/mediapipe/framework/formats/tensor_internal.h index 86e366a06..1231a991c 100644 --- a/mediapipe/framework/formats/tensor_internal.h +++ b/mediapipe/framework/formats/tensor_internal.h @@ -16,14 +16,22 @@ #define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_ #include +#include + +#include "mediapipe/framework/tool/type_util.h" namespace mediapipe { // Generates unique view id at compile-time using FILE and LINE. -#define TENSOR_UNIQUE_VIEW_TYPE_ID() \ - static constexpr uint64_t kId = tensor_internal::FnvHash64( \ +#define TENSOR_UNIQUE_VIEW_TYPE_ID() \ + static inline uint64_t kId = tensor_internal::FnvHash64( \ __FILE__, tensor_internal::FnvHash64(TENSOR_INT_TO_STRING(__LINE__))) +// Generates unique view id at compile-time using FILE and LINE and Type of the +// template view's argument. 
+#define TENSOR_UNIQUE_VIEW_TYPE_ID_T(T) \ + static inline uint64_t kId = tool::GetTypeHash(); + namespace tensor_internal { #define TENSOR_INT_TO_STRING2(x) #x @@ -36,6 +44,21 @@ constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325; constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) { return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime); } + +template +struct TypeList { + static constexpr std::size_t size{sizeof...(Ts)}; +}; +template +struct TypeInList {}; +template +struct TypeInList> + : std::integral_constant {}; +template +struct TypeInList> + : std::integral_constant>::value> {}; + } // namespace tensor_internal } // namespace mediapipe diff --git a/mediapipe/framework/graph_validation_test.cc b/mediapipe/framework/graph_validation_test.cc index b2d38fc0e..c98983838 100644 --- a/mediapipe/framework/graph_validation_test.cc +++ b/mediapipe/framework/graph_validation_test.cc @@ -355,8 +355,8 @@ TEST(GraphValidationTest, OptionalInputNotProvidedForSubgraphCalculator) { output_stream: "OUTPUT:output_0" node { calculator: "OptionalSideInputTestCalculator" - input_side_packet: "SIDEINPUT:input_0" # std::string - output_stream: "OUTPUT:output_0" # std::string + input_side_packet: "SIDEINPUT:input_0" # string + output_stream: "OUTPUT:output_0" # string } )pb"); @@ -366,7 +366,7 @@ TEST(GraphValidationTest, OptionalInputNotProvidedForSubgraphCalculator) { output_stream: "OUTPUT:foo_out" node { calculator: "PassThroughGraph" - output_stream: "OUTPUT:foo_out" # std::string + output_stream: "OUTPUT:foo_out" # string } )pb"); @@ -406,10 +406,10 @@ TEST(GraphValidationTest, MultipleOptionalInputsForSubgraph) { output_stream: "OUTPUT:output_0" node { calculator: "OptionalSideInputTestCalculator" - input_side_packet: "SIDEINPUT:input_0" # std::string + input_side_packet: "SIDEINPUT:input_0" # string input_stream: "SELECT:select" input_stream: "ENABLE:enable" - output_stream: "OUTPUT:output_0" # std::string + 
output_stream: "OUTPUT:output_0" # string } )pb"); @@ -421,7 +421,7 @@ TEST(GraphValidationTest, MultipleOptionalInputsForSubgraph) { node { calculator: "PassThroughGraph" input_stream: "SELECT:foo_select" - output_stream: "OUTPUT:foo_out" # std::string + output_stream: "OUTPUT:foo_out" # string } )pb"); diff --git a/mediapipe/framework/input_stream_handler.h b/mediapipe/framework/input_stream_handler.h index 1aa319438..798f89f36 100644 --- a/mediapipe/framework/input_stream_handler.h +++ b/mediapipe/framework/input_stream_handler.h @@ -147,8 +147,7 @@ class InputStreamHandler { void Close(); - // Returns a std::string that concatenates the stream names of all managed - // streams. + // Returns a string that concatenates the stream names of all managed streams. std::string DebugStreamNames() const; // Keeps scheduling new invocations until 1) the node is not ready or 2) the diff --git a/mediapipe/framework/input_stream_shard.h b/mediapipe/framework/input_stream_shard.h index 8e5951b14..375ca83eb 100644 --- a/mediapipe/framework/input_stream_shard.h +++ b/mediapipe/framework/input_stream_shard.h @@ -51,7 +51,7 @@ class InputStreamShard : public InputStream { return !packet_queue_.empty() ? packet_queue_.front() : empty_packet_; } - // Returns a reference to the name std::string of the InputStreamManager. + // Returns a reference to the name string of the InputStreamManager. const std::string& Name() const { return *name_; } bool IsDone() const override { return is_done_; } @@ -75,7 +75,7 @@ class InputStreamShard : public InputStream { std::queue packet_queue_; Packet empty_packet_; - // Pointer to the name std::string of the InputStreamManager. + // Pointer to the name string of the InputStreamManager. 
const std::string* name_; bool is_done_; diff --git a/mediapipe/framework/lifetime_tracker.h b/mediapipe/framework/lifetime_tracker.h index 4a90d470b..b5dabada2 100644 --- a/mediapipe/framework/lifetime_tracker.h +++ b/mediapipe/framework/lifetime_tracker.h @@ -18,6 +18,7 @@ #include #include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" namespace mediapipe { @@ -33,9 +34,13 @@ class LifetimeTracker { class Object { public: explicit Object(LifetimeTracker* tracker) : tracker_(tracker) { + absl::MutexLock lock(&tracker_->mutex_); ++tracker_->live_count_; } - ~Object() { --tracker_->live_count_; } + ~Object() { + absl::MutexLock lock(&tracker_->mutex_); + --tracker_->live_count_; + } private: LifetimeTracker* const tracker_; @@ -47,10 +52,26 @@ class LifetimeTracker { } // Returns the number of tracked objects currently alive. - int live_count() { return live_count_; } + int live_count() { + absl::MutexLock lock(&mutex_); + return live_count_; + } + + // Waits for all instances of Object to be destroyed / live_count to reach + // zero. Returns true if this occurred within the timeout, false otherwise. + bool WaitForAllObjectsToDie( + absl::Duration timeout = absl::InfiniteDuration()) { + // Condition takes a function pointer. Prefixing the lambda with a + + // resolves it to a pointer. 
+ absl::Condition check_count( + +[](int* value) { return *value == 0; }, &live_count_); + absl::MutexLock lock(&mutex_); + return mutex_.AwaitWithTimeout(check_count, timeout); + } private: - std::atomic live_count_ = ATOMIC_VAR_INIT(0); + absl::Mutex mutex_; + int live_count_ ABSL_GUARDED_BY(mutex_) = 0; }; } // namespace mediapipe diff --git a/mediapipe/framework/mediapipe_cc_test.bzl b/mediapipe/framework/mediapipe_cc_test.bzl index 15e691440..c953004d9 100644 --- a/mediapipe/framework/mediapipe_cc_test.bzl +++ b/mediapipe/framework/mediapipe_cc_test.bzl @@ -10,10 +10,17 @@ def mediapipe_cc_test( size = None, tags = [], timeout = None, + args = [], additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS, + # ios_unit_test arguments + ios_minimum_os_version = "9.0", + # android_cc_test arguments + open_gl_driver = None, + emulator_mini_boot = True, + requires_full_emulation = True, + # wasm_web_test arguments + browsers = None, **kwargs): - # Note: additional_deps are MediaPipe-specific test support deps added by default. - # They are provided as a default argument so they can be disabled if desired. native.cc_library( name = name + "_lib", testonly = 1, diff --git a/mediapipe/framework/more_selects.bzl b/mediapipe/framework/more_selects.bzl new file mode 100644 index 000000000..0b321bb81 --- /dev/null +++ b/mediapipe/framework/more_selects.bzl @@ -0,0 +1,76 @@ +"""More utilities to help with selects.""" + +load("@bazel_skylib//lib:selects.bzl", "selects") + +# From selects.bzl, but it's not public there. +def _config_setting_always_true(name, visibility): + """Returns a config_setting with the given name that's always true. + + This is achieved by constructing a two-entry OR chain where each + config_setting takes opposite values of a boolean flag. 
+ """ + name_on = name + "_stamp_binary_on_check" + name_off = name + "_stamp_binary_off_check" + native.config_setting( + name = name_on, + values = {"stamp": "1"}, + ) + native.config_setting( + name = name_off, + values = {"stamp": "0"}, + ) + return selects.config_setting_group( + name = name, + visibility = visibility, + match_any = [ + ":" + name_on, + ":" + name_off, + ], + ) + +def _config_setting_always_false(name, visibility): + """Returns a config_setting with the given name that's always false. + + This is achieved by constructing a two-entry AND chain where each + config_setting takes opposite values of a boolean flag. + """ + name_on = name + "_stamp_binary_on_check" + name_off = name + "_stamp_binary_off_check" + native.config_setting( + name = name_on, + values = {"stamp": "1"}, + ) + native.config_setting( + name = name_off, + values = {"stamp": "0"}, + ) + return selects.config_setting_group( + name = name, + visibility = visibility, + match_all = [ + ":" + name_on, + ":" + name_off, + ], + ) + +def _config_setting_negation(name, negate, visibility = None): + _config_setting_always_true( + name = name + "_true", + visibility = visibility, + ) + _config_setting_always_false( + name = name + "_false", + visibility = visibility, + ) + native.alias( + name = name, + actual = select({ + "//conditions:default": ":%s_true" % name, + negate: ":%s_false" % name, + }), + visibility = visibility, + ) + +more_selects = struct( + config_setting_negation = _config_setting_negation, +) diff --git a/mediapipe/framework/packet.cc b/mediapipe/framework/packet.cc index 878192a24..1fbd55e97 100644 --- a/mediapipe/framework/packet.cc +++ b/mediapipe/framework/packet.cc @@ -113,7 +113,8 @@ absl::Status Packet::ValidateAsType(const tool::TypeInfo& type_info) const { MediaPipeTypeStringOrDemangled(type_info), ", but received an empty Packet.")); } - bool holder_is_right_type = holder_->GetTypeId() == type_info.hash_code(); + bool holder_is_right_type = + 
holder_->GetTypeInfo().hash_code() == type_info.hash_code(); if (ABSL_PREDICT_FALSE(!holder_is_right_type)) { return absl::InvalidArgumentError(absl::StrCat( "The Packet stores \"", holder_->DebugTypeName(), "\", but \"", diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index c0de3a03d..4b0e48fbc 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -189,7 +189,11 @@ class Packet { // Get the type id for the underlying type stored in the Packet. // Crashes if IsEmpty() == true. - size_t GetTypeId() const; + size_t GetTypeId() const { return GetTypeInfo().hash_code(); } + + // Get the type info for the underlying type stored in the Packet. + // Crashes if IsEmpty() == true. + const tool::TypeInfo& GetTypeInfo() const; // Returns the timestamp. class Timestamp Timestamp() const; @@ -201,9 +205,9 @@ class Packet { // Returns the type name. If the packet is empty or the type is not // registered (with MEDIAPIPE_REGISTER_TYPE or companion macros) then - // the empty std::string is returned. + // the empty string is returned. std::string RegisteredTypeName() const; - // Returns a std::string with the best guess at the type name. + // Returns a string with the best guess at the type name. std::string DebugTypeName() const; private: @@ -220,6 +224,7 @@ class Packet { friend std::shared_ptr packet_internal::GetHolderShared(Packet&& packet); + friend class PacketType; absl::Status ValidateAsType(const tool::TypeInfo& type_info) const; std::shared_ptr holder_; @@ -364,15 +369,15 @@ class HolderBase { virtual ~HolderBase(); template bool PayloadIsOfType() const { - return GetTypeId() == tool::GetTypeHash(); + return GetTypeInfo().hash_code() == tool::GetTypeHash(); } - // Returns a printable std::string identifying the type stored in the holder. + // Returns a printable string identifying the type stored in the holder. 
virtual const std::string DebugTypeName() const = 0; // Returns the registered type name if it's available, otherwise the - // empty std::string. + // empty string. virtual const std::string RegisteredTypeName() const = 0; // Get the type id of the underlying data type. - virtual size_t GetTypeId() const = 0; + virtual const tool::TypeInfo& GetTypeInfo() const = 0; // Downcasts this to Holder. Returns nullptr if deserialization // failed or if the requested type is not what is stored. template @@ -440,7 +445,7 @@ ConvertToVectorOfProtoMessageLitePtrs(const T* data, } // This registry is used to create Holders of the right concrete C++ type given -// a proto type std::string (which is used as the registration key). +// a proto type string (which is used as the registration key). class MessageHolderRegistry : public GlobalFactoryRegistry> {}; @@ -505,7 +510,7 @@ class Holder : public HolderBase { HolderSupport::EnsureStaticInit(); return *ptr_; } - size_t GetTypeId() const final { return tool::GetTypeHash(); } + const tool::TypeInfo& GetTypeInfo() const final { return tool::TypeId(); } // Releases the underlying data pointer and transfers the ownership to a // unique pointer. 
// This method is dangerous and is only used by Packet::Consume() if the @@ -741,9 +746,9 @@ inline Packet& Packet::operator=(Packet&& packet) { inline bool Packet::IsEmpty() const { return holder_ == nullptr; } -inline size_t Packet::GetTypeId() const { +inline const tool::TypeInfo& Packet::GetTypeInfo() const { CHECK(holder_); - return holder_->GetTypeId(); + return holder_->GetTypeInfo(); } template diff --git a/mediapipe/framework/packet_type.cc b/mediapipe/framework/packet_type.cc index bbcd84d80..c633d17a8 100644 --- a/mediapipe/framework/packet_type.cc +++ b/mediapipe/framework/packet_type.cc @@ -19,57 +19,55 @@ #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/tool/type_util.h" #include "mediapipe/framework/tool/validate_name.h" +#include "mediapipe/framework/type_map.h" namespace mediapipe { -PacketType::PacketType() - : initialized_(false), - no_packets_allowed_(true), - validate_method_(nullptr), - type_name_("[Undefined Type]"), - same_as_(nullptr) {} +absl::Status PacketType::AcceptAny(const TypeSpec& type) { + return absl::OkStatus(); +} + +absl::Status PacketType::AcceptNone(const TypeSpec& type) { + auto* special = absl::get_if(&type); + if (special && + (special->accept_fn_ == AcceptNone || special->accept_fn_ == AcceptAny)) + return absl::OkStatus(); + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "No packets are allowed for type: [No Type]"; +} PacketType& PacketType::SetAny() { - no_packets_allowed_ = false; - validate_method_ = nullptr; - same_as_ = nullptr; - 
type_name_ = "[Any Type]"; - initialized_ = true; + type_spec_ = SpecialType{"[Any Type]", &AcceptAny}; return *this; } PacketType& PacketType::SetNone() { - no_packets_allowed_ = true; - validate_method_ = nullptr; - same_as_ = nullptr; - type_name_ = "[No Type]"; - initialized_ = true; + type_spec_ = SpecialType{"[No Type]", &AcceptNone}; return *this; } PacketType& PacketType::SetSameAs(const PacketType* type) { // TODO Union sets together when SetSameAs is called multiple times. - no_packets_allowed_ = false; - validate_method_ = nullptr; - same_as_ = type->GetSameAs(); - type_name_ = ""; - - if (same_as_ == this) { + auto same_as = type->GetSameAs(); + if (same_as == this) { // We're the root of the union-find tree. There's a cycle, which // means we might as well be an "Any" type. - same_as_ = nullptr; + return SetAny(); } - - initialized_ = true; + type_spec_ = SameAs{same_as}; return *this; } @@ -78,10 +76,19 @@ PacketType& PacketType::Optional() { return *this; } -bool PacketType::IsInitialized() const { return initialized_; } +bool PacketType::IsInitialized() const { + return !absl::holds_alternative(type_spec_); +} + +const PacketType* PacketType::SameAsPtr() const { + auto* same_as = absl::get_if(&type_spec_); + if (same_as) return same_as->other; + return nullptr; +} PacketType* PacketType::GetSameAs() { - if (!same_as_) { + auto* same_as = SameAsPtr(); + if (!same_as) { return this; } // Don't optimize the union-find algorithm, since updating the pointer @@ -91,89 +98,174 @@ PacketType* PacketType::GetSameAs() { // make the current set point to the root of the other tree. // TODO Remove const_cast by making SetSameAs take a non-const // PacketType*. - return const_cast(same_as_->GetSameAs()); + return const_cast(same_as->GetSameAs()); } const PacketType* PacketType::GetSameAs() const { - if (!same_as_) { + auto* same_as = SameAsPtr(); + if (!same_as) { return this; } // See comments in non-const variant. 
- return same_as_->GetSameAs(); + return same_as->GetSameAs(); } bool PacketType::IsAny() const { - return !no_packets_allowed_ && validate_method_ == nullptr && - same_as_ == nullptr; + auto* special = absl::get_if(&type_spec_); + return special && special->accept_fn_ == AcceptAny; } -bool PacketType::IsNone() const { return no_packets_allowed_; } +bool PacketType::IsNone() const { + auto* special = absl::get_if(&type_spec_); + // The tests currently require that an uninitialized PacketType return true + // for IsNone. TODO: change it? + return !IsInitialized() || (special && special->accept_fn_ == AcceptNone); +} + +bool PacketType::IsOneOf() const { + return absl::holds_alternative(type_spec_); +} + +bool PacketType::IsExactType() const { + return absl::holds_alternative(type_spec_); +} const std::string* PacketType::RegisteredTypeName() const { - if (same_as_) { - return GetSameAs()->RegisteredTypeName(); - } - return registered_type_name_ptr_; + if (auto* same_as = SameAsPtr()) return same_as->RegisteredTypeName(); + if (auto* type_info = absl::get_if(&type_spec_)) + return MediaPipeTypeStringFromTypeId((**type_info).hash_code()); + if (auto* multi_type = absl::get_if(&type_spec_)) + return multi_type->registered_type_name; + return nullptr; } -const std::string PacketType::DebugTypeName() const { - if (same_as_) { +namespace internal { + +struct TypeInfoFormatter { + void operator()(std::string* out, const tool::TypeInfo& t) const { + absl::StrAppend(out, MediaPipeTypeStringOrDemangled(t)); + } +}; + +template +class QuoteFormatter { + public: + explicit QuoteFormatter(Formatter&& f) : f_(std::forward(f)) {} + + template + void operator()(std::string* out, const T& t) const { + absl::StrAppend(out, "\""); + f_(out, t); + absl::StrAppend(out, "\""); + } + + private: + Formatter f_; +}; +template +explicit QuoteFormatter(Formatter f) -> QuoteFormatter; + +} // namespace internal + +std::string PacketType::TypeNameForOneOf(TypeInfoSpan types) { + return 
absl::StrCat( + "OneOf<", + absl::StrJoin(types, ", ", + absl::DereferenceFormatter(internal::TypeInfoFormatter())), + ">"); +} + +std::string PacketType::DebugTypeName() const { + if (auto* same_as = absl::get_if(&type_spec_)) { // Construct a name based on the current chain of same_as_ links // (which may change when the framework expands out Any-type). - return absl::StrCat("[Same Type As ", GetSameAs()->DebugTypeName(), "]"); + return absl::StrCat("[Same Type As ", + same_as->other->GetSameAs()->DebugTypeName(), "]"); } - return type_name_; + if (auto* special = absl::get_if(&type_spec_)) { + return special->name_; + } + if (auto* type_info = absl::get_if(&type_spec_)) { + return MediaPipeTypeStringOrDemangled(**type_info); + } + if (auto* multi_type = absl::get_if(&type_spec_)) { + return TypeNameForOneOf(multi_type->types); + } + return "[Undefined Type]"; +} + +static bool HaveCommonType(absl::Span types1, + absl::Span types2) { + for (const auto& first : types1) { + for (const auto& second : types2) { + if (first->hash_code() == second->hash_code()) { + return true; + } + } + } + return false; } absl::Status PacketType::Validate(const Packet& packet) const { - if (!initialized_) { + if (!IsInitialized()) { return absl::InvalidArgumentError( "Uninitialized PacketType was used for validation."); } - if (same_as_) { + if (SameAsPtr()) { // Cycles are impossible at this stage due to being checked for // in SetSameAs(). return GetSameAs()->Validate(packet); } - if (no_packets_allowed_) { - return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "No packets are allowed for type: " << type_name_; + if (auto* type_info = absl::get_if(&type_spec_)) { + return packet.ValidateAsType(**type_info); } - if (validate_method_ != nullptr) { - return (packet.*validate_method_)(); - } - // The PacketType is the Any Type. 
if (packet.IsEmpty()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "Empty packets are not allowed for type: " << type_name_; + << "Empty packets are not allowed for type: " << DebugTypeName(); + } + if (auto* multi_type = absl::get_if(&type_spec_)) { + auto* packet_type = &packet.GetTypeInfo(); + if (HaveCommonType(multi_type->types, absl::MakeSpan(&packet_type, 1))) { + return absl::OkStatus(); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "The Packet stores \"", packet.DebugTypeName(), "\", but one of ", + absl::StrJoin(multi_type->types, ", ", + absl::DereferenceFormatter(internal::QuoteFormatter( + internal::TypeInfoFormatter()))), + " was requested.")); + } + } + if (auto* special = absl::get_if(&type_spec_)) { + return special->accept_fn_(&packet.GetTypeInfo()); } return absl::OkStatus(); } +PacketType::TypeInfoSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) { + if (auto* type_info = absl::get_if(&type_spec)) + return absl::MakeSpan(type_info, 1); + if (auto* multi_type = absl::get_if(&type_spec)) + return multi_type->types; + return {}; +} + bool PacketType::IsConsistentWith(const PacketType& other) const { const PacketType* type1 = GetSameAs(); const PacketType* type2 = other.GetSameAs(); - if (type1->validate_method_ == nullptr || - type2->validate_method_ == nullptr) { - // type1 or type2 either accepts anything or nothing. - if (type1->validate_method_ == nullptr && !type1->no_packets_allowed_) { - // type1 accepts anything. - return true; - } - if (type2->validate_method_ == nullptr && !type2->no_packets_allowed_) { - // type2 accepts anything. - return true; - } - if (type1->no_packets_allowed_ && type2->no_packets_allowed_) { - // type1 and type2 both accept nothing. - return true; - } - // The only special case left is that only one of "type1" or "type2" - // accepts nothing, which means there is no match. 
- return false; + TypeInfoSpan types1 = GetTypeSpan(type1->type_spec_); + TypeInfoSpan types2 = GetTypeSpan(type2->type_spec_); + if (!types1.empty() && !types2.empty()) { + return HaveCommonType(types1, types2); } - return type1->validate_method_ == type2->validate_method_; + if (auto* special1 = absl::get_if(&type1->type_spec_)) { + return special1->accept_fn_(type2->type_spec_).ok(); + } + if (auto* special2 = absl::get_if(&type2->type_spec_)) { + return special2->accept_fn_(type1->type_spec_).ok(); + } + return false; } absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set) { diff --git a/mediapipe/framework/packet_type.h b/mediapipe/framework/packet_type.h index 676119f28..738141a29 100644 --- a/mediapipe/framework/packet_type.h +++ b/mediapipe/framework/packet_type.h @@ -23,12 +23,16 @@ #include #include "absl/base/macros.h" +#include "absl/status/status.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "mediapipe/framework/collection.h" +#include "mediapipe/framework/deps/no_destructor.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/type_util.h" #include "mediapipe/framework/tool/validate_name.h" #include "mediapipe/framework/type_map.h" @@ -41,7 +45,7 @@ namespace mediapipe { class PacketType { public: // Creates an uninitialized PacketType. - PacketType(); + PacketType() = default; // PacketType can be passed by value. PacketType(const PacketType&) = default; @@ -63,6 +67,9 @@ class PacketType { // Specifically, using SetAny() still means that the stream has a type // but this particular calculator just doesn't care what it is. PacketType& SetAny(); + // Sets the packet type to accept any of the provided types. + template + PacketType& SetOneOf(); // Sets the packet type to not accept any packets. 
PacketType& SetNone(); // Sets the PacketType to be the same as type. This actually stores @@ -80,6 +87,11 @@ class PacketType { bool IsAny() const; // Returns true if this PacketType allows nothing. bool IsNone() const; + // Returns true if this PacketType allows a set of types. + bool IsOneOf() const; + // Returns true if this PacketType allows one specific type. + bool IsExactType() const; + // Returns true if this port has been marked as optional. bool IsOptional() const { return optional_; } // Returns true iff this and other are consistent, meaning they do @@ -101,26 +113,38 @@ class PacketType { const std::string* RegisteredTypeName() const; // Returns the type name. Do not use this for validation, use // Validate() instead. - const std::string DebugTypeName() const; + std::string DebugTypeName() const; private: - // Typedef for the ValidateAsType() method in Packet that is used for - // type validation and identification. - typedef absl::Status (Packet::*ValidateMethodType)() const; + struct SameAs { + // This PacketType is the same as other. + // We don't do union-find optimizations in order to avoid a mutex. + const PacketType* other; + }; + using TypeInfoSpan = absl::Span; + struct MultiType { + TypeInfoSpan types; + // TODO: refactor RegisteredTypeName, remove. + const std::string* registered_type_name; + }; + struct SpecialType; + using TypeSpec = absl::variant; + typedef absl::Status (*AcceptsTypeFn)(const TypeSpec& type); + struct SpecialType { + std::string name_; + AcceptsTypeFn accept_fn_; + }; + + static absl::Status AcceptAny(const TypeSpec& type); + static absl::Status AcceptNone(const TypeSpec& type); + + const PacketType* SameAsPtr() const; + static TypeInfoSpan GetTypeSpan(const TypeSpec& type_spec); + static std::string TypeNameForOneOf(TypeInfoSpan types); + + TypeSpec type_spec_; - // Records whether the packet type was set in any way. - bool initialized_; - // Don't allow any packets through. 
- bool no_packets_allowed_; - // Pointer to Packet::ValidateAsType. - ValidateMethodType validate_method_; - // Type name as std::string. - std::string type_name_; - // The Registered type name or nullptr if the type isn't registered. - const std::string* registered_type_name_ptr_ = nullptr; - // If this is non-null then this PacketType is the same as same_as_. - // We don't do union-find optimizations in order to avoid a mutex. - const PacketType* same_as_; // Whether the corresponding port is optional. bool optional_ = false; }; @@ -164,7 +188,7 @@ class PacketTypeSetErrorHandler { for (const auto& entry : missing_->entries) { // Optional entries that were missing are not considered errors. if (!entry.second.IsOptional()) { - // Split them to keep the error std::string unchanged. + // Split them to keep the error string unchanged. std::pair tag_idx = absl::StrSplit(entry.first, ':'); missing_->errors.push_back(absl::StrCat("Failed to get tag \"", @@ -235,12 +259,16 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set); template PacketType& PacketType::Set() { - initialized_ = true; - no_packets_allowed_ = false; - validate_method_ = &Packet::ValidateAsType; - type_name_ = MediaPipeTypeStringOrDemangled(); - registered_type_name_ptr_ = MediaPipeTypeString(); - same_as_ = nullptr; + type_spec_ = &tool::TypeId(); + return *this; +} + +template +PacketType& PacketType::SetOneOf() { + static const NoDestructor> types{ + {&tool::TypeId()...}}; + static const NoDestructor name{TypeNameForOneOf(*types)}; + type_spec_ = MultiType{*types, &*name}; return *this; } diff --git a/mediapipe/framework/port.h b/mediapipe/framework/port.h index fc034ce26..e8fde0f39 100644 --- a/mediapipe/framework/port.h +++ b/mediapipe/framework/port.h @@ -60,18 +60,23 @@ #define MEDIAPIPE_OPENGL_ES_30 300 #define MEDIAPIPE_OPENGL_ES_31 310 +// NOTE: MEDIAPIPE_OPENGL_ES_VERSION macro represents the maximum OpenGL ES +// version to build for. 
Runtime availability is _not_ guaranteed; in +// particular, uses of OpenGL ES 3.1 should be guarded by a runtime check. +// TODO: identify and fix code where macro is used incorrectly. #if MEDIAPIPE_DISABLE_GPU #define MEDIAPIPE_OPENGL_ES_VERSION 0 #define MEDIAPIPE_METAL_ENABLED 0 #else #if defined(MEDIAPIPE_ANDROID) #if defined(MEDIAPIPE_DISABLE_GL_COMPUTE) -#define MEDIAPIPE_OPENGL_ES_VERSION MEDIAPIPE_OPENGL_ES_20 +#define MEDIAPIPE_OPENGL_ES_VERSION MEDIAPIPE_OPENGL_ES_30 #else #define MEDIAPIPE_OPENGL_ES_VERSION MEDIAPIPE_OPENGL_ES_31 #endif #define MEDIAPIPE_METAL_ENABLED 0 #elif defined(MEDIAPIPE_IOS) +// TODO: use MEDIAPIPE_OPENGL_ES_30 for iOS as max version. #define MEDIAPIPE_OPENGL_ES_VERSION MEDIAPIPE_OPENGL_ES_20 #define MEDIAPIPE_METAL_ENABLED 1 #elif defined(MEDIAPIPE_OSX) diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 4d928bb42..0dc492975 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -89,9 +89,11 @@ cc_library( cc_library( name = "graph_profiler_real", srcs = [ - "gl_context_profiler.cc", "graph_profiler.cc", - ], + ] + select({ + "//conditions:default": ["gl_context_profiler.cc"], + "//mediapipe/gpu:disable_gpu": [], + }), hdrs = [ "graph_profiler.h", ], @@ -100,31 +102,37 @@ cc_library( ], visibility = ["//visibility:private"], deps = [ - ":graph_tracer", ":profiler_resource_util", - ":sharded_map", + ":graph_tracer", ":trace_buffer", + ":sharded_map", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_profile_cc_proto", + "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "//mediapipe/framework/deps:clock", + "//mediapipe/framework:calculator_context", "//mediapipe/framework:executor", 
"//mediapipe/framework:validated_graph_config", - "//mediapipe/framework/deps:clock", - "//mediapipe/framework/port:advanced_proto_lite", - "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:validate_name", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:re2", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/framework/port:advanced_proto_lite", "//mediapipe/framework/tool:name_util", - "//mediapipe/framework/tool:tag_map", - "//mediapipe/framework/tool:validate_name", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:optional", - ], + ] + select({ + "//conditions:default": [], + }) + select({ + "//conditions:default": [ + ], + "//mediapipe/gpu:disable_gpu": [], + }), ) cc_library( @@ -270,6 +278,7 @@ cc_library( name = "profiler_resource_util", srcs = ["profiler_resource_util_common.cc"] + select({ "//conditions:default": ["profiler_resource_util.cc"], + "//mediapipe/framework:android_no_jni": ["profiler_resource_util_android_hal.cc"], "//mediapipe:android": ["profiler_resource_util_android.cc"], "//mediapipe:ios": ["profiler_resource_util_ios.cc"], }), @@ -295,6 +304,9 @@ cc_library( "//conditions:default": [ "//mediapipe/framework/port:file_helpers", ], + "//mediapipe/framework:android_no_jni": [ + "//mediapipe/framework/port:file_helpers", + ], "//mediapipe:android": [ "//mediapipe/java/com/google/mediapipe/framework/jni:jni_util", "//mediapipe/framework/port:file_helpers", diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index 519c720b4..a5c3254b3 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -600,12 +600,17 @@ void AssignNodeNames(GraphProfile* profile) { if (graph_trace) { 
graph_trace->clear_calculator_name(); } + std::vector canonical_names; + canonical_names.reserve(graph_config->node().size()); for (int i = 0; i < graph_config->node().size(); ++i) { - std::string node_name = CanonicalNodeName(*graph_config, i); - graph_config->mutable_node(i)->set_name(node_name); - if (graph_trace) { - graph_trace->add_calculator_name(node_name); - } + canonical_names.push_back(CanonicalNodeName(*graph_config, i)); + } + for (int i = 0; i < graph_config->node().size(); ++i) { + graph_config->mutable_node(i)->set_name(canonical_names[i]); + } + if (graph_trace) { + graph_trace->mutable_calculator_name()->Assign(canonical_names.begin(), + canonical_names.end()); } } @@ -646,7 +651,8 @@ absl::StatusOr GraphProfiler::GetTraceLogPath() { } } -absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) { +absl::Status GraphProfiler::CaptureProfile( + GraphProfile* result, PopulateGraphConfig populate_config) { // Record the GraphTrace events since the previous WriteProfile. // The end_time is chosen to be trace_log_margin_usec in the past, // providing time for events to be appended to the TraceBuffer. @@ -674,6 +680,10 @@ absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) { } this->Reset(); CleanCalculatorProfiles(result); + if (populate_config == PopulateGraphConfig::kFull) { + *result->mutable_config() = validated_graph_->Config(); + AssignNodeNames(result); + } return status; } @@ -686,7 +696,7 @@ absl::Status GraphProfiler::WriteProfile() { int log_interval_count = GetLogIntervalCount(profiler_config_); int log_file_count = GetLogFileCount(profiler_config_); GraphProfile profile; - MP_RETURN_IF_ERROR(CaptureProfile(&profile)); + MP_RETURN_IF_ERROR(CaptureProfile(&profile, PopulateGraphConfig::kNo)); // If there are no trace events, skip log writing. 
const GraphTrace& trace = *profile.graph_trace().rbegin(); diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index f13831908..29969af2e 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -71,6 +71,9 @@ struct PacketInfo { // For testing class GraphProfilerTestPeer; +// GraphProfiler::CaptureProfile option, see the method for details. +enum class PopulateGraphConfig { kNo, kFull }; + // GraphProfiler keeps track of the following in microseconds based on the // profiler clock, for each calculator // - Open(), Process(), and Close() runtime. @@ -145,7 +148,14 @@ class GraphProfiler : public std::enable_shared_from_this { // Records recent profiling and tracing data. Includes events since the // previous call to CaptureProfile. - absl::Status CaptureProfile(GraphProfile* result); + // + // If `populate_config` is `kFull`, `config` field of the resulting profile + // will contain canonicalized config of the profiled graph, and + // `graph_trace.calculator_name` will contain node names referring to that + // config. Both fields are left empty if the option is set to `kNo`. + absl::Status CaptureProfile( + GraphProfile* result, + PopulateGraphConfig populate_config = PopulateGraphConfig::kNo); // Writes recent profiling and tracing data to a file specified in the // ProfilerConfig. Includes events since the previous call to WriteProfile. @@ -356,6 +366,85 @@ class ProfilingContext : public GraphProfiler { // For now, OSS always uses GlContextProfilerStub. // TODO: Switch to GlContextProfiler when GlContext is moved to OSS. +#define MEDIAPIPE_DISABLE_GPU_PROFILER 1 + +// GlContextProfiler keeps track of all timestamp queries within a specific +// GlContext object. When created, the GlContextProfiler must be initialized +// before marking timestamps. 
Finally, when GlContext is no longer interested +// in marking timestamps or is about to be destroyed, Finish() must be called +// to complete all pending time queries and detach the timer from the GlContext. +// Note that the GlContextProfiler must be created and initialized within a +// valid GlContext object. +#if !MEDIAPIPE_DISABLE_GPU_PROFILER +class GlContextProfiler { + public: + explicit GlContextProfiler( + std::shared_ptr profiling_context) + : profiling_context_(profiling_context) {} + + // Not copyable or movable. + GlContextProfiler(const GlContextProfiler&) = delete; + GlContextProfiler& operator=(const GlContextProfiler&) = delete; + + // Add a GlTimingInfo object to the collection of pending timestamp queries + // associated with a specific graph node_id, packet input_timestamp and mark + // if it is a start or stop event. When a stop event is marked, this function + // blocks on the corresponding start event to complete. + void MarkTimestamp(int node_id, Timestamp input_timestamp, bool is_finish); + + // Complete all pending timing queries and detach the timer from the + // GlContext. + void LogAllTimestamps(); + + private: + // Store GlTimeQuery and the corresponding TraceEvent object that should be + // populated when the query completes together. + struct GlTimingInfo { + GlTimeQuery time_query; + TraceEvent trace_event; + }; + + // Setup the timer for marking GPU timestamps. If successful in setup, return + // true otherwise return false to indicate that timing measurement is not + // supported. + bool Initialize(); + + absl::Time TimeNow(); + + // Calibrate the GPU timer w.r.t. the CPU clock. If calibration fails, + // timing_measurement_supported_ is set to false. + void CalibrateTimer(bool recalibrate); + + // Log a TraceEvent object to represent if the GPU calibration period has + // started or just ended. + void LogCalibrationEvent(bool started, absl::Time time); + + // Log TraceEvent objects for completed time queries.
If the parameter wait is + // set to true, wait for all time queries to complete before returning. + void RetireReadyGlTimings(bool wait = false); + + // Get the TraceEvent object containing the timestamp recorded by the GPU if + // the provided query was fulfilled. If it is still pending and wait is false, + // return absl::nullopt. + absl::optional GetTimeFromQuery( + std::unique_ptr& query, bool wait); + + std::shared_ptr profiling_context_; + GlSimpleTimer gl_timer_; + bool checked_timing_supported_ = false; + bool timing_measurement_supported_ = false; + std::deque> pending_gl_times_; + std::unique_ptr gl_start_query_; +}; + +// The API class used to access the preferred GlContext profiler, such as +// GlContextProfiler or GlContextProfilerStub. GlProfilingHelper is defined as +// a class rather than a typedef in order to support clients that refer +// to it only as a forward declaration. +class GlProfilingHelper : public GlContextProfiler { + using GlContextProfiler::GlContextProfiler; +}; +#else // MEDIAPIPE_DISABLE_GPU_PROFILER class GlContextProfilerStub { public: explicit GlContextProfilerStub( @@ -370,7 +459,8 @@ class GlContextProfilerStub { class GlProfilingHelper : public GlContextProfilerStub { using GlContextProfilerStub::GlContextProfilerStub; }; - +#endif // !MEDIAPIPE_DISABLE_GPU_PROFILER +#undef MEDIAPIPE_DISABLE_GPU_PROFILER } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_PROFILER_GRAPH_PROFILER_H_ diff --git a/mediapipe/framework/profiler/graph_profiler_stub.h b/mediapipe/framework/profiler/graph_profiler_stub.h index 6621c0192..12a024fe8 100644 --- a/mediapipe/framework/profiler/graph_profiler_stub.h +++ b/mediapipe/framework/profiler/graph_profiler_stub.h @@ -74,6 +74,9 @@ class TraceEvent { inline TraceEvent& set_event_data(int64 data) { return *this; } }; +// GraphProfiler::CaptureProfile option, see the method for details. 
+enum class PopulateGraphConfig { kNo, kFull }; + // Empty implementation of ProfilingContext to be used in place of the // GraphProfiler when the main implementation is disabled. class GraphProfilerStub { @@ -85,6 +88,11 @@ class GraphProfilerStub { std::vector*) const { return absl::OkStatus(); } + absl::Status CaptureProfile( + GraphProfile* result, + PopulateGraphConfig populate_config = PopulateGraphConfig::kNo) { + return absl::OkStatus(); + } inline void Pause() {} inline void Resume() {} inline void Reset() {} diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index c4c12beb9..81ba90cda 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -1267,5 +1267,49 @@ TEST(GraphProfilerTest, CalculatorProfileFilter) { EXPECT_EQ(GetCalculatorNames(config), expected_names); } +TEST(GraphProfilerTest, CaptureProfilePopulateConfig) { + CalculatorGraphConfig config; + QCHECK(proto2::TextFormat::ParseFromString(R"( + profiler_config { + enable_profiler: true + trace_enabled: true + } + input_stream: "input_stream" + node { + calculator: "DummyTestCalculator" + input_stream: "input_stream" + } + node { + calculator: "DummyTestCalculator" + input_stream: "input_stream" + } + )", + &config)); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + GraphProfile profile; + MP_ASSERT_OK( + graph.profiler()->CaptureProfile(&profile, PopulateGraphConfig::kFull)); + EXPECT_THAT(profile.config(), Partially(EqualsProto(R"pb( + input_stream: "input_stream" + node { + name: "DummyTestCalculator_1" + calculator: "DummyTestCalculator" + input_stream: "input_stream" + } + node { + name: "DummyTestCalculator_2" + calculator: "DummyTestCalculator" + input_stream: "input_stream" + } + )pb"))); + EXPECT_THAT(profile.graph_trace(), + ElementsAre(Partially(EqualsProto( + R"pb( + calculator_name: "DummyTestCalculator_1" + calculator_name: 
"DummyTestCalculator_2" + )pb")))); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/profiler/graph_tracer_test.cc b/mediapipe/framework/profiler/graph_tracer_test.cc index 80af064aa..c1cc819c1 100644 --- a/mediapipe/framework/profiler/graph_tracer_test.cc +++ b/mediapipe/framework/profiler/graph_tracer_test.cc @@ -1053,8 +1053,8 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { calculator_name: "LambdaCalculator_1" calculator_name: "FlowLimiterCalculator" calculator_name: "RoundRobinDemuxCalculator" - calculator_name: "LambdaCalculator_1" - calculator_name: "LambdaCalculator" + calculator_name: "LambdaCalculator_2" + calculator_name: "LambdaCalculator_3" calculator_name: "ImmediateMuxCalculator" stream_name: "" stream_name: "input_packets_0" @@ -1198,14 +1198,14 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { output_stream: "OUTPUT:1:input_1" } node { - name: "LambdaCalculator_1" + name: "LambdaCalculator_2" calculator: "LambdaCalculator" input_stream: "input_0" output_stream: "output_0" input_side_packet: "callback_0" } node { - name: "LambdaCalculator" + name: "LambdaCalculator_3" calculator: "LambdaCalculator" input_stream: "input_1" output_stream: "output_1" diff --git a/mediapipe/framework/profiler/profiler_resource_util_android_hal.cc b/mediapipe/framework/profiler/profiler_resource_util_android_hal.cc new file mode 100644 index 000000000..27aaf2641 --- /dev/null +++ b/mediapipe/framework/profiler/profiler_resource_util_android_hal.cc @@ -0,0 +1,9 @@ +#include "mediapipe/framework/port/statusor.h" + +namespace mediapipe { + +StatusOr GetDefaultTraceLogDirectory() { + return "/data/local/tmp"; +} + +} // namespace mediapipe diff --git a/mediapipe/framework/profiler/reporter/print_profile.cc b/mediapipe/framework/profiler/reporter/print_profile.cc index dd6e0846d..629b8f608 100644 --- a/mediapipe/framework/profiler/reporter/print_profile.cc +++ b/mediapipe/framework/profiler/reporter/print_profile.cc @@ -12,9 +12,9 @@ // See 
the License for the specific language governing permissions and // limitations under the License. // -// This program takes one input file and encodes its contents as a C++ -// std::string, which can be included in a C++ source file. It is similar to -// filewrapper (and borrows some of its code), but simpler. +// This program takes one input file and encodes its contents as a C++ string, +// which can be included in a C++ source file. It is similar to filewrapper +// (and borrows some of its code), but simpler. #include #include diff --git a/mediapipe/framework/profiler/reporter/reporter.cc b/mediapipe/framework/profiler/reporter/reporter.cc index b61afa363..25628a63b 100644 --- a/mediapipe/framework/profiler/reporter/reporter.cc +++ b/mediapipe/framework/profiler/reporter/reporter.cc @@ -215,7 +215,7 @@ void CompleteCalculatorData( } void Reporter::Accumulate(const mediapipe::GraphProfile& profile) { - // Cache nodeID to its std::string name. + // Cache nodeID to its string name. NameLookup name_lookup; CacheNodeNameLookup(profile, &name_lookup); @@ -363,8 +363,8 @@ class ReportImpl : public Report { // Values for each calculator, corresponding to the label in headers(). std::vector> lines_impl; - // The longest std::string of any value in a given column (including the - // header for that column). Used for formatting the output. + // The longest string of any value in a given column (including the header + // for that column). Used for formatting the output. std::vector char_counts_impl; bool compact_flag = false; @@ -377,7 +377,7 @@ void ReportImpl::Print(std::ostream& output) { // fill space up to char_counts[column] + 1. The strings in the output // are mutable to support padding, hence no const in the for loops. int column_number = 0; - // Make a copy of the column std::string because we might be adding spaces. + // Make a copy of the column string because we might be adding spaces. 
for (auto column : headers_impl) { int padding_needed = char_counts_impl[column_number] + 1 - column.length(); if (compact_flag) { diff --git a/mediapipe/framework/profiler/trace_builder.cc b/mediapipe/framework/profiler/trace_builder.cc index 7381072e2..10ce879ff 100644 --- a/mediapipe/framework/profiler/trace_builder.cc +++ b/mediapipe/framework/profiler/trace_builder.cc @@ -90,11 +90,11 @@ void BasicTraceEventTypes(TraceEventRegistry* result) { } } -// A map defining int32 identifiers for std::string object pointers. -// Lookup is fast when the same std::string object is used frequently. +// A map defining int32 identifiers for string object pointers. +// Lookup is fast when the same string object is used frequently. class StringIdMap { public: - // Returns the int32 identifier for a std::string object pointer. + // Returns the int32 identifier for a string object pointer. int32 operator[](const std::string* id) { if (id == nullptr) { return 0; diff --git a/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc index e61322834..e93f806be 100644 --- a/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc @@ -47,8 +47,8 @@ std::tuple> CommandTuple( return std::make_tuple(stream, timestamp, expected); } -// Function to take the inputs and produce a diagnostic output std::string -// and output a packet with a diagnostic output std::string which includes +// Function to take the inputs and produce a diagnostic output string +// and output a packet with a diagnostic output string which includes // the input timestamp and the ids of each input which is present. 
absl::Status InputsToDebugString(const InputStreamShardSet& inputs, OutputStreamShardSet* outputs) { diff --git a/mediapipe/framework/test_calculators.cc b/mediapipe/framework/test_calculators.cc index 543d53125..3b6602978 100644 --- a/mediapipe/framework/test_calculators.cc +++ b/mediapipe/framework/test_calculators.cc @@ -279,7 +279,7 @@ class StdDevCalculator : public CalculatorBase { REGISTER_CALCULATOR(StdDevCalculator); // A calculator that receives some number of input streams carrying ints. -// Outputs, for each input timestamp, a space separated std::string containing +// Outputs, for each input timestamp, a space separated string containing // the timestamp and all the inputs for that timestamp (Empty inputs // will be denoted with "empty"). Sets the header to be a space-separated // concatenation of the input stream headers. @@ -368,7 +368,7 @@ REGISTER_CALCULATOR(SaverCalculator); #ifndef MEDIAPIPE_MOBILE // Source Calculator that produces matrices on the output stream with -// each coefficient from a normal gaussian. A std::string seed must be given +// each coefficient from a normal gaussian. A string seed must be given // as an input side packet. class RandomMatrixCalculator : public CalculatorBase { public: diff --git a/mediapipe/framework/timestamp.h b/mediapipe/framework/timestamp.h index 03f41597f..b8c3a69a2 100644 --- a/mediapipe/framework/timestamp.h +++ b/mediapipe/framework/timestamp.h @@ -86,7 +86,7 @@ class Timestamp { // in microseconds, but this function should be preferred over Value() in case // the underlying representation changes. int64 Microseconds() const { return Value(); } - // This provides a human readable std::string for the special values. + // This provides a human readable string for the special values. std::string DebugString() const; // For use by framework. 
Clients or Calculator implementations should not call diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index afc24fecc..1d8b6a88c 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -43,6 +43,7 @@ bzl_library( "//mediapipe/framework:transitive_protos_bzl", "//mediapipe/framework/deps:descriptor_set_bzl", "//mediapipe/framework/deps:expand_template_bzl", + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite_bzl", ], ) @@ -52,6 +53,7 @@ bzl_library( "build_defs.bzl", ], visibility = [ + "//mediapipe/app/xeno/catalog:__subpackages__", "//mediapipe/framework:__subpackages__", ], ) @@ -286,6 +288,7 @@ cc_library( mediapipe_cc_test( name = "options_util_test", size = "small", + timeout = "moderate", srcs = ["options_util_test.cc"], # A non-empty "data" param is needed to build the "_test_wasm" target. data = [":node_chain_subgraph.proto"], @@ -858,11 +861,9 @@ cc_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:subgraph", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto", "//mediapipe/framework/tool:switch_container_cc_proto", ], alwayslink = 1, diff --git a/mediapipe/framework/tool/encode_as_c_string.cc b/mediapipe/framework/tool/encode_as_c_string.cc index a202deb09..48f6c72ed 100644 --- a/mediapipe/framework/tool/encode_as_c_string.cc +++ b/mediapipe/framework/tool/encode_as_c_string.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This program takes one input file and encodes its contents as a C++ -// std::string, which can be included in a C++ source file. 
It is similar to -// filewrapper (and borrows some of its code), but simpler. +// This program takes one input file and encodes its contents as a C++ string, +// which can be included in a C++ source file. It is similar to filewrapper +// (and borrows some of its code), but simpler. #include #include diff --git a/mediapipe/framework/tool/mediapipe_graph.bzl b/mediapipe/framework/tool/mediapipe_graph.bzl index e3fe2b825..45d98b1eb 100644 --- a/mediapipe/framework/tool/mediapipe_graph.bzl +++ b/mediapipe/framework/tool/mediapipe_graph.bzl @@ -20,6 +20,7 @@ load("//mediapipe/framework:transitive_protos.bzl", "transitive_protos") load("//mediapipe/framework/deps:expand_template.bzl", "expand_template") load("//mediapipe/framework/tool:build_defs.bzl", "clean_dep") load("//mediapipe/framework/deps:descriptor_set.bzl", "direct_descriptor_set", "transitive_descriptor_set") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") def mediapipe_binary_graph(name, graph = None, output_name = None, deps = [], testonly = False, **kwargs): """Converts a graph from text format to binary format.""" @@ -98,6 +99,7 @@ def mediapipe_simple_subgraph( register_as, graph, deps = [], + tflite_deps = None, visibility = None, testonly = None, **kwargs): @@ -109,6 +111,7 @@ def mediapipe_simple_subgraph( CamelCase. graph: the BUILD label of a text-format MediaPipe graph. deps: any calculators or subgraphs used by this graph. + tflite_deps: any calculators or subgraphs used by this graph that may use different TFLite implementation. visibility: The list of packages the subgraph should be visible to. testonly: pass 1 if the graph is to be used only for tests. **kwargs: Remaining keyword args, forwarded to cc_library. 
@@ -138,21 +141,39 @@ def mediapipe_simple_subgraph( }, testonly = testonly, ) - native.cc_library( - name = name, - srcs = [ - name + "_linked.cc", - graph_base_name + ".inc", - ], - deps = [ - clean_dep("//mediapipe/framework:calculator_framework"), - clean_dep("//mediapipe/framework:subgraph"), - ] + deps, - alwayslink = 1, - visibility = visibility, - testonly = testonly, - **kwargs - ) + if not tflite_deps: + native.cc_library( + name = name, + srcs = [ + name + "_linked.cc", + graph_base_name + ".inc", + ], + deps = [ + clean_dep("//mediapipe/framework:calculator_framework"), + clean_dep("//mediapipe/framework:subgraph"), + ] + deps, + alwayslink = 1, + visibility = visibility, + testonly = testonly, + **kwargs + ) + else: + cc_library_with_tflite( + name = name, + srcs = [ + name + "_linked.cc", + graph_base_name + ".inc", + ], + tflite_deps = tflite_deps, + deps = [ + clean_dep("//mediapipe/framework:calculator_framework"), + clean_dep("//mediapipe/framework:subgraph"), + ] + deps, + alwayslink = 1, + visibility = visibility, + testonly = testonly, + **kwargs + ) def mediapipe_reexport_library( name, diff --git a/mediapipe/framework/tool/name_util.h b/mediapipe/framework/tool/name_util.h index 207fe162c..b3a9338c1 100644 --- a/mediapipe/framework/tool/name_util.h +++ b/mediapipe/framework/tool/name_util.h @@ -85,7 +85,7 @@ std::pair ParseTagIndexFromStream(const std::string& stream); // Formats to "tag:index". std::string CatTag(const std::string& tag, int index); -// Concatenates "tag:index:name" into a single std::string. +// Concatenates "tag:index:name" into a single string. 
std::string CatStream(const std::pair& tag_index, const std::string& name); diff --git a/mediapipe/framework/tool/options_syntax_util.h b/mediapipe/framework/tool/options_syntax_util.h index 09c3fbb78..e33699569 100644 --- a/mediapipe/framework/tool/options_syntax_util.h +++ b/mediapipe/framework/tool/options_syntax_util.h @@ -34,7 +34,7 @@ class OptionsSyntaxUtil { FieldPath OptionFieldPath(absl::string_view tag, const Descriptor* descriptor); - // Splits a std::string into "tag" and "name" delimited by a single colon. + // Splits a string into "tag" and "name" delimited by a single colon. std::vector StrSplitTags(absl::string_view tag_and_name); private: diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index dddde09d1..b9649ce5b 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -228,7 +228,7 @@ absl::Status SyntaxStatus(bool ok, const std::string& text, T* result) { " for type: ", MediaPipeTypeStringOrDemangled(), ".")); } -// Templated parsing of a std::string value. +// Templated parsing of a string value. template absl::Status ParseValue(const std::string& text, T* result) { return SyntaxStatus(absl::SimpleAtoi(text, result), text, result); diff --git a/mediapipe/framework/tool/simple_subgraph_template.cc b/mediapipe/framework/tool/simple_subgraph_template.cc index 606694632..57cd36cd3 100644 --- a/mediapipe/framework/tool/simple_subgraph_template.cc +++ b/mediapipe/framework/tool/simple_subgraph_template.cc @@ -31,7 +31,7 @@ class {{SUBGRAPH_CLASS_NAME}} : public Subgraph { const SubgraphOptions& /*options*/) { CalculatorGraphConfig config; // Note: this is a binary protobuf serialization, and may include NUL - // bytes. The trailing NUL added to the std::string literal should be excluded. + // bytes. The trailing NUL added to the string literal should be excluded. 
if (config.ParseFromArray(binary_graph, sizeof(binary_graph) - 1)) { return config; } else { diff --git a/mediapipe/framework/tool/status_util.h b/mediapipe/framework/tool/status_util.h index 039f55609..8b4bc02d2 100644 --- a/mediapipe/framework/tool/status_util.h +++ b/mediapipe/framework/tool/status_util.h @@ -40,7 +40,7 @@ absl::Status StatusInvalid(const std::string& error_message); ABSL_DEPRECATED("Use absl::UnknownError(error_message) instead.") absl::Status StatusFail(const std::string& error_message); -// Prefixes the given std::string to the error message in status. +// Prefixes the given string to the error message in status. // This function should be considered internal to the framework. // TODO Replace usage of AddStatusPrefix with util::Annotate(). absl::Status AddStatusPrefix(const std::string& prefix, diff --git a/mediapipe/framework/tool/switch_container.cc b/mediapipe/framework/tool/switch_container.cc index 164016ab9..5470f33c6 100644 --- a/mediapipe/framework/tool/switch_container.cc +++ b/mediapipe/framework/tool/switch_container.cc @@ -22,8 +22,6 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#include "mediapipe/framework/stream_handler.pb.h" -#include "mediapipe/framework/stream_handler/sync_set_input_stream_handler.pb.h" #include "mediapipe/framework/tool/container_util.h" #include "mediapipe/framework/tool/name_util.h" #include "mediapipe/framework/tool/subgraph_expansion.h" @@ -88,12 +86,6 @@ CalculatorGraphConfig::Node* BuildDemuxNode( CalculatorGraphConfig* config) { CalculatorGraphConfig::Node* result = config->add_node(); *result->mutable_calculator() = "SwitchDemuxCalculator"; - *result->mutable_input_stream_handler()->mutable_input_stream_handler() = - "ImmediateInputStreamHandler"; - if (container_node.has_input_stream_handler()) { - *result->mutable_input_stream_handler() = - container_node.input_stream_handler(); - } return 
result; } @@ -103,8 +95,6 @@ CalculatorGraphConfig::Node* BuildMuxNode( CalculatorGraphConfig* config) { CalculatorGraphConfig::Node* result = config->add_node(); *result->mutable_calculator() = "SwitchMuxCalculator"; - *result->mutable_input_stream_handler()->mutable_input_stream_handler() = - "ImmediateInputStreamHandler"; return result; } diff --git a/mediapipe/framework/tool/switch_container.proto b/mediapipe/framework/tool/switch_container.proto index ac3995006..a9c2d9094 100644 --- a/mediapipe/framework/tool/switch_container.proto +++ b/mediapipe/framework/tool/switch_container.proto @@ -24,4 +24,7 @@ message SwitchContainerOptions { // Activates channel 1 for enable = true, channel 0 otherwise. optional bool enable = 4; + + // Use DefaultInputStreamHandler for muxing & demuxing. + optional bool synchronize_io = 5; } diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index 8f91d878c..5abf9fb03 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -66,8 +66,7 @@ REGISTER_CALCULATOR(TripleIntCalculator); // A testing example of a SwitchContainer containing two subnodes. // Note that the input and output tags supplied to the container node, // must match the input and output tags required by the subnodes. 
-CalculatorGraphConfig SubnodeContainerExample( - const std::string& input_stream_handler = "") { +CalculatorGraphConfig SubnodeContainerExample(const std::string& options = "") { std::string config = R"pb( input_stream: "foo" input_stream: "enable" @@ -80,9 +79,9 @@ CalculatorGraphConfig SubnodeContainerExample( options { [mediapipe.SwitchContainerOptions.ext] { contained_node: { calculator: "TripleIntCalculator" } - contained_node: { calculator: "PassThroughCalculator" } + contained_node: { calculator: "PassThroughCalculator" } $options } - } $input_stream_handler + } } node { calculator: "PassThroughCalculator" @@ -94,8 +93,7 @@ CalculatorGraphConfig SubnodeContainerExample( )pb"; return mediapipe::ParseTextProtoOrDie( - absl::StrReplaceAll(config, - {{"$input_stream_handler", input_stream_handler}})); + absl::StrReplaceAll(config, {{"$options", options}})); } // A testing example of a SwitchContainer containing two subnodes. @@ -248,9 +246,6 @@ TEST(SwitchContainerTest, ApplyToSubnodes) { options { [mediapipe.SwitchContainerOptions.ext] {} } - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" - } } node { name: "switchcontainer__TripleIntCalculator" @@ -274,9 +269,6 @@ TEST(SwitchContainerTest, ApplyToSubnodes) { options { [mediapipe.SwitchContainerOptions.ext] {} } - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" - } } node { calculator: "PassThroughCalculator" @@ -322,9 +314,7 @@ TEST(SwitchContainerTest, ValidateInputStreamHandler) { options { [mediapipe.SwitchContainerOptions.ext] {} } - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" - } + input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" } } node { name: "switchcontainer__TripleIntCalculator" @@ -350,9 +340,7 @@ TEST(SwitchContainerTest, ValidateInputStreamHandler) { options { [mediapipe.SwitchContainerOptions.ext] {} } - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" 
- } + input_stream_handler { input_stream_handler: "DefaultInputStreamHandler" } } node { calculator: "PassThroughCalculator" @@ -371,83 +359,12 @@ TEST(SwitchContainerTest, ValidateInputStreamHandler) { EXPECT_THAT(graph.Config(), mediapipe::EqualsProto(expected_graph)); } -// Expands the SwitchContainer with a node-level input_stream_handler. -TEST(SwitchContainerTest, OverrideInputStreamHandler) { - EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer")); - CalculatorGraph graph; - CalculatorGraphConfig supergraph = SubnodeContainerExample( - R"pb(input_stream_handler { - input_stream_handler: "DefaultInputStreamHandler" - })pb"); - *supergraph.mutable_node(0) - ->mutable_input_stream_handler() - ->mutable_input_stream_handler() = "DefaultInputStreamHandler"; - MP_ASSERT_OK(graph.Initialize(supergraph, {})); - CalculatorGraphConfig expected_graph = - mediapipe::ParseTextProtoOrDie(R"pb( - node { - name: "switchcontainer__SwitchDemuxCalculator" - calculator: "SwitchDemuxCalculator" - input_stream: "ENABLE:enable" - input_stream: "foo" - output_stream: "C0__:switchcontainer__c0__foo" - output_stream: "C1__:switchcontainer__c1__foo" - options { - [mediapipe.SwitchContainerOptions.ext] {} - } - input_stream_handler { - input_stream_handler: "DefaultInputStreamHandler" - } - } - node { - name: "switchcontainer__TripleIntCalculator" - calculator: "TripleIntCalculator" - input_stream: "switchcontainer__c0__foo" - output_stream: "switchcontainer__c0__bar" - } - node { - name: "switchcontainer__PassThroughCalculator" - calculator: "PassThroughCalculator" - input_stream: "switchcontainer__c1__foo" - output_stream: "switchcontainer__c1__bar" - } - node { - name: "switchcontainer__SwitchMuxCalculator" - calculator: "SwitchMuxCalculator" - input_stream: "ENABLE:enable" - input_stream: "C0__:switchcontainer__c0__bar" - input_stream: "C1__:switchcontainer__c1__bar" - output_stream: "bar" - options { - [mediapipe.SwitchContainerOptions.ext] {} - } - input_stream_handler 
{ - input_stream_handler: "ImmediateInputStreamHandler" - } - } - node { - calculator: "PassThroughCalculator" - input_stream: "foo" - input_stream: "bar" - output_stream: "output_foo" - output_stream: "output_bar" - } - input_stream: "foo" - input_stream: "enable" - executor {} - input_side_packet: "timezone" - )pb"); - EXPECT_THAT(graph.Config(), mediapipe::EqualsProto(expected_graph)); -} - -// Runs the SwitchContainer with a node-level input_stream_handler. TEST(SwitchContainerTest, RunsWithInputStreamHandler) { EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer")); - CalculatorGraphConfig supergraph = SubnodeContainerExample( - R"pb(input_stream_handler { - input_stream_handler: "DefaultInputStreamHandler" - })pb"); + CalculatorGraphConfig supergraph = + SubnodeContainerExample(R"pb(synchronize_io: true)pb"); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); + LOG(INFO) << supergraph.DebugString(); RunTestContainer(supergraph, true); } @@ -470,9 +387,6 @@ TEST(SwitchContainerTest, ApplyToSideSubnodes) { options { [mediapipe.SwitchContainerOptions.ext] {} } - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" - } } node { name: "switchcontainer__TripleIntCalculator" @@ -496,9 +410,6 @@ TEST(SwitchContainerTest, ApplyToSideSubnodes) { options { [mediapipe.SwitchContainerOptions.ext] {} } - input_stream_handler { - input_stream_handler: "ImmediateInputStreamHandler" - } } node { calculator: "PassThroughCalculator" diff --git a/mediapipe/framework/tool/switch_demux_calculator.cc b/mediapipe/framework/tool/switch_demux_calculator.cc index 46e6c358e..c4352c871 100644 --- a/mediapipe/framework/tool/switch_demux_calculator.cc +++ b/mediapipe/framework/tool/switch_demux_calculator.cc @@ -26,6 +26,7 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/tool/container_util.h" +#include "mediapipe/framework/tool/switch_container.pb.h" namespace mediapipe { @@ 
-113,7 +114,10 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { } } } - cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + auto& options = cc->Options(); + if (!options.synchronize_io()) { + cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + } cc->SetProcessTimestampBounds(true); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/switch_mux_calculator.cc b/mediapipe/framework/tool/switch_mux_calculator.cc index ffa611239..9982ae4f6 100644 --- a/mediapipe/framework/tool/switch_mux_calculator.cc +++ b/mediapipe/framework/tool/switch_mux_calculator.cc @@ -28,6 +28,7 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/tool/container_util.h" +#include "mediapipe/framework/tool/switch_container.pb.h" namespace mediapipe { @@ -68,6 +69,17 @@ class SwitchMuxCalculator : public CalculatorBase { private: int channel_index_; std::set channel_tags_; + mediapipe::SwitchContainerOptions options_; + // This is used to keep around packets that we've received but not + // relayed yet (because we may not know which channel we should yet be using + // when synchronize_io flag is set). + std::map> packet_history_; + // Historical channel index values for timestamps where we don't have all + // packets available yet (when synchronize_io flag is set). + std::map channel_history_; + // Number of output streams that we already processed for the current output + // timestamp. 
+ int current_processed_stream_count_ = 0; }; REGISTER_CALCULATOR(SwitchMuxCalculator); @@ -122,6 +134,7 @@ absl::Status SwitchMuxCalculator::GetContract(CalculatorContract* cc) { } absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Inputs().TagMap()); @@ -141,13 +154,79 @@ absl::Status SwitchMuxCalculator::Process(CalculatorContext* cc) { // Update the input channel index if specified. channel_index_ = tool::GetChannelIndex(*cc, channel_index_); - // Relay packets and timestamps only from channel_index_. - for (const std::string& tag : channel_tags_) { - for (int index = 0; index < cc->Outputs().NumEntries(tag); ++index) { - auto& output = cc->Outputs().Get(tag, index); - std::string input_tag = tool::ChannelTag(tag, channel_index_); - auto& input = cc->Inputs().Get(input_tag, index); - tool::Relay(input, &output); + if (options_.synchronize_io()) { + // Start with adding input signals into channel_history_ and packet_history_ + if (cc->Inputs().HasTag("ENABLE") && + !cc->Inputs().Tag("ENABLE").IsEmpty()) { + channel_history_[cc->Inputs().Tag("ENABLE").Value().Timestamp()] = + channel_index_; + } + if (cc->Inputs().HasTag("SELECT") && + !cc->Inputs().Tag("SELECT").IsEmpty()) { + channel_history_[cc->Inputs().Tag("SELECT").Value().Timestamp()] = + channel_index_; + } + for (auto input_id = cc->Inputs().BeginId(); + input_id < cc->Inputs().EndId(); ++input_id) { + auto& entry = cc->Inputs().Get(input_id); + if (entry.IsEmpty()) { + continue; + } + packet_history_[entry.Value().Timestamp()][input_id] = entry.Value(); + } + // Now check if we have enough information to produce any outputs. + while (!channel_history_.empty()) { + // Look at the oldest unprocessed timestamp. 
+ auto it = channel_history_.begin(); + auto& packets = packet_history_[it->first]; + int total_streams = 0; + // Loop over all outputs to see if we have anything new that we can relay. + for (const std::string& tag : channel_tags_) { + for (int index = 0; index < cc->Outputs().NumEntries(tag); ++index) { + ++total_streams; + auto input_id = + cc->Inputs().GetId(tool::ChannelTag(tag, it->second), index); + auto packet_it = packets.find(input_id); + if (packet_it != packets.end()) { + cc->Outputs().Get(tag, index).AddPacket(packet_it->second); + ++current_processed_stream_count_; + } else if (it->first < + cc->Inputs().Get(input_id).Value().Timestamp()) { + // Getting here means that input stream that corresponds to this + // output at the timestamp we're trying to process right now has + // already advanced beyond this timestamp. This means that we will + // shouldn't expect a packet for this timestamp anymore, and we can + // safely advance timestamp on the output. + cc->Outputs() + .Get(tag, index) + .SetNextTimestampBound(it->first.NextAllowedInStream()); + ++current_processed_stream_count_; + } + } + } + if (current_processed_stream_count_ == total_streams) { + // There's nothing else to wait for at the current timestamp, do the + // cleanup and move on to the next one. + packet_history_.erase(it->first); + channel_history_.erase(it); + current_processed_stream_count_ = 0; + } else { + // We're still missing some packets for the current timestamp. Clean up + // those that we just relayed and let the rest wait until the next + // Process() call. + packets.clear(); + break; + } + } + } else { + // Relay packets and timestamps only from channel_index_. 
+ for (const std::string& tag : channel_tags_) { + for (int index = 0; index < cc->Outputs().NumEntries(tag); ++index) { + auto& output = cc->Outputs().Get(tag, index); + std::string input_tag = tool::ChannelTag(tag, channel_index_); + auto& input = cc->Inputs().Get(input_tag, index); + tool::Relay(input, &output); + } } } return absl::OkStatus(); diff --git a/mediapipe/framework/tool/tag_map.h b/mediapipe/framework/tool/tag_map.h index c7f3134d9..6ac23df08 100644 --- a/mediapipe/framework/tool/tag_map.h +++ b/mediapipe/framework/tool/tag_map.h @@ -51,8 +51,8 @@ class TagMap { int count; }; - // Create a TagMap from a repeated std::string proto field of - // TAG::name. This is the most common usage: + // Create a TagMap from a repeated string proto field of TAG::name. + // This is the most common usage: // ASSIGN_OR_RETURN(std::shared_ptr tag_map, // tool::TagMap::Create(node.input_streams())); static absl::StatusOr> Create( @@ -87,7 +87,7 @@ class TagMap { // Returns canonicalized strings describing the TagMap. proto_ns::RepeatedPtrField CanonicalEntries() const; - // Returns a std::string description for debug purposes. + // Returns a string description for debug purposes. std::string DebugString() const; // Returns a shorter description for debug purposes (doesn't include // stream/side packet names). diff --git a/mediapipe/framework/tool/tag_map_test.cc b/mediapipe/framework/tool/tag_map_test.cc index 39b2e1921..20a9be966 100644 --- a/mediapipe/framework/tool/tag_map_test.cc +++ b/mediapipe/framework/tool/tag_map_test.cc @@ -318,8 +318,8 @@ TEST(TagMapTest, SameAs) { } } -// A helper function to test that a TagMap's debug std::string and short -// debug std::string each satisfy a matcher. +// A helper function to test that a TagMap's debug string and short +// debug string each satisfy a matcher. 
template void TestDebugString( const absl::StatusOr>& statusor_tag_map, diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index bd5cd97a0..150f252fb 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -433,7 +433,7 @@ class TemplateExpanderImpl { return result; } - // Converts a TemplateArgument to std::string. + // Converts a TemplateArgument to string. std::string AsString(const TemplateArgument& value) { std::string result; if (value.has_num()) { diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index 90034091d..1d81e7a78 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -260,7 +260,7 @@ class TemplateParser::Parser::ParserImpl { typedef proto_ns::TextFormat::ParseLocation ParseLocation; // Determines if repeated values for non-repeated fields and - // oneofs are permitted, e.g., the std::string "foo: 1 foo: 2" for a + // oneofs are permitted, e.g., the string "foo: 1 foo: 2" for a // required/optional field named "foo", or "baz: 1 qux: 2" // where "baz" and "qux" are members of the same oneof. enum SingularOverwritePolicy { @@ -401,7 +401,7 @@ class TemplateParser::Parser::ParserImpl { } #ifndef PROTO2_OPENSOURCE - // Consumes a std::string value and parses it as a packed repeated field into + // Consumes a string value and parses it as a packed repeated field into // the given field of the given message. bool ConsumePackedFieldAsString(const std::string& field_name, const FieldDescriptor* field, @@ -409,7 +409,7 @@ class TemplateParser::Parser::ParserImpl { std::string packed; DO(ConsumeString(&packed)); - // Prepend field tag and varint-encoded std::string length to turn into + // Prepend field tag and varint-encoded string length to turn into // encoded message. 
std::string tagged; { @@ -428,7 +428,7 @@ class TemplateParser::Parser::ParserImpl { io::CodedInputStream coded_input(&array_input); if (!message->MergePartialFromCodedStream(&coded_input)) { ReportError("Could not parse packed field \"" + field_name + - "\" as wire-encoded std::string."); + "\" as wire-encoded string."); return false; } @@ -607,7 +607,7 @@ class TemplateParser::Parser::ParserImpl { bool consumed_semicolon = TryConsume(":"); if (consumed_semicolon && field->options().weak() && LookingAtType(io::Tokenizer::TYPE_STRING)) { - // we are getting a bytes std::string for a weak field. + // we are getting a bytes string for a weak field. std::string tmp; DO(ConsumeString(&tmp)); reflection->MutableMessage(message, field)->ParseFromString(tmp); @@ -640,8 +640,8 @@ class TemplateParser::Parser::ParserImpl { #ifndef PROTO2_OPENSOURCE } else if (field->is_packable() && LookingAtType(io::Tokenizer::TYPE_STRING)) { - // Packable field printed as wire-formatted std::string: "foo: "abc\123"". - // Fields of type std::string cannot be packed themselves, so this is + // Packable field printed as wire-formatted string: "foo: "abc\123"". + // Fields of type string cannot be packed themselves, so this is // unambiguous. DO(ConsumePackedFieldAsString(field_name, field, message)); #endif // !PROTO2_OPENSOURCE @@ -908,7 +908,7 @@ class TemplateParser::Parser::ParserImpl { } return true; } - // Possible field values other than std::string: + // Possible field values other than string: // 12345 => TYPE_INTEGER // -12345 => TYPE_SYMBOL + TYPE_INTEGER // 1.2345 => TYPE_FLOAT @@ -992,7 +992,7 @@ class TemplateParser::Parser::ParserImpl { return false; } - // Consume a std::string of form ".....". + // Consume a string of form ".....". 
bool ConsumeFullTypeName(std::string* name) { DO(ConsumeIdentifier(name)); while (TryConsume(".")) { @@ -1013,11 +1013,11 @@ class TemplateParser::Parser::ParserImpl { return true; } - // Consumes a std::string and saves its value in the text parameter. + // Consumes a string and saves its value in the text parameter. // Returns false if the token is not of type STRING. bool ConsumeString(std::string* text) { if (!LookingAtType(io::Tokenizer::TYPE_STRING)) { - ReportError("Expected std::string, got: " + tokenizer_.current().text); + ReportError("Expected string, got: " + tokenizer_.current().text); return false; } @@ -1391,7 +1391,7 @@ void StowFieldValue(Message* message, TemplateExpression* expression) { } } -// Strips first and last quotes from a std::string. +// Strips first and last quotes from a string. static void StripQuotes(std::string* str) { // Strip off the leading and trailing quotation marks from the value, if // there are any. @@ -1585,7 +1585,7 @@ class TemplateParser::Parser::MediaPipeParserImpl return true; } - // Parses a numeric or a std::string literal. + // Parses a numeric or a string literal. 
bool ConsumeLiteral(TemplateExpression* result) { std::string token = tokenizer_.current().text; StripQuotes(&token); diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 4a7ac1570..77e2d6fbd 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -215,6 +215,11 @@ std::string GetTestRootDir() { std::string GetTestOutputsDir() { const char* output_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); if (!output_dir) { +#ifdef __APPLE__ + char path[PATH_MAX]; + size_t n = confstr(_CS_DARWIN_USER_TEMP_DIR, path, sizeof(path)); + if (n > 0 && n < sizeof(path)) return path; +#endif // __APPLE__ output_dir = "/tmp"; } return output_dir; diff --git a/mediapipe/framework/tool/validate_name.h b/mediapipe/framework/tool/validate_name.h index 8a21f1fbb..274d06721 100644 --- a/mediapipe/framework/tool/validate_name.h +++ b/mediapipe/framework/tool/validate_name.h @@ -66,7 +66,7 @@ absl::Status SetFromTagAndNameInfo( const TagAndNameInfo& info, proto_ns::RepeatedPtrField* tags_and_names); -// The std::string is a valid name for an input stream, output stream, +// The string is a valid name for an input stream, output stream, // side packet, and input collection. Names use only lower case letters, // numbers, and underscores. // @@ -77,18 +77,18 @@ absl::Status SetFromTagAndNameInfo( // (3) Because input side packet names end up in model directory names, // where lower case naming is the norm. absl::Status ValidateName(const std::string& name); -// The std::string is a valid tag name. Tags use only upper case letters, +// The string is a valid tag name. Tags use only upper case letters, // numbers, and underscores. absl::Status ValidateTag(const std::string& tag); -// Parse a "Tag and Name" std::string into a tag and a name. +// Parse a "Tag and Name" string into a tag and a name. // The format is an optional tag and colon, followed by a name. 
// Example 1: "VIDEO:frames2" -> tag: "VIDEO", name: "frames2" // Example 2: "video_frames_1" -> tag: "", name: "video_frames_1" absl::Status ParseTagAndName(const std::string& tag_and_name, std::string* tag, std::string* name); -// Parse a generic TAG:index:name std::string. The format is a tag, then an +// Parse a generic TAG:index:name string. The format is a tag, then an // index, then a name. The tag and index are optional. If the index // is included, then the tag must be included. If no tag is used then // index is set to -1 (and should be assigned by argument position). @@ -99,7 +99,7 @@ absl::Status ParseTagAndName(const std::string& tag_and_name, std::string* tag, absl::Status ParseTagIndexName(const std::string& tag_and_name, std::string* tag, int* index, std::string* name); -// Parse a generic TAG:index std::string. The format is a tag, then an index +// Parse a generic TAG:index string. The format is a tag, then an index // with both being optional. If the tag is missing it is assumed to be // "" and if the index is missing then it is assumed to be 0. If the // index is provided then a colon (':') must be used. diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h index 693262c2b..d07ad1024 100644 --- a/mediapipe/framework/type_map.h +++ b/mediapipe/framework/type_map.h @@ -13,8 +13,8 @@ // limitations under the License. // This header defines static maps to store the mappings from type hash id and -// name std::string to MediaPipeTypeData. It also provides code to inspect -// types of packets and access registered serialize and deserialize functions. +// name string to MediaPipeTypeData. It also provides code to inspect types of +// packets and access registered serialize and deserialize functions. // Calculators can use this to infer types of packets and adjust accordingly. 
// // Register a type: @@ -242,7 +242,7 @@ class StaticMap { class MapName : public type_map_internal::StaticMap {}; // Defines a map from unique typeid number to MediaPipeTypeData. DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeIdToMediaPipeTypeData, size_t); -// Defines a map from unique type std::string to MediaPipeTypeData. +// Defines a map from unique type string to MediaPipeTypeData. DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); // MEDIAPIPE_REGISTER_TYPE can be used to register a type. @@ -267,7 +267,7 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); // #undef MY_MAP_TYPE // // MEDIAPIPE_REGISTER_TYPE( -// std::string, "string", StringSerializeFn, StringDeserializeFn); +// std::string, "std::string", StringSerializeFn, StringDeserializeFn); // #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \ SET_MEDIAPIPE_TYPE_MAP_VALUE( \ @@ -293,7 +293,7 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); // typedef is used, the name should be prefixed with the namespace(s), // seperated by double colons. // -// Example 1: register type with non-std::string proxy. +// Example 1: register type with non-string proxy. // absl::Status ToProxyFn( // const ClassType& obj, ProxyType* proxy) // { @@ -315,15 +315,15 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); // ::mediapipe::DeserializeUsingGenericFn, ToProxyFn, FromProxyFn); // -// Example 2: register type with std::string proxy. -// absl::Status ToProxyFn(const ClassType& obj, std::string* encoding) +// Example 2: register type with string proxy. +// absl::Status ToProxyFn(const ClassType& obj, string* encoding) // { // ... // return absl::OkStatus(); // } // // absl::Status FromProxyFn( -// const ProxyType& proxy, std::string* encoding) { +// const ProxyType& proxy, string* encoding) { // ... 
// return absl::OkStatus(); // } @@ -367,7 +367,7 @@ inline const std::string* MediaPipeTypeStringFromTypeId(const size_t type_id) { return (value) ? &value->type_string : nullptr; } -// Returns std::string identifier of type or NULL if not registered. +// Returns string identifier of type or NULL if not registered. template inline const std::string* MediaPipeTypeString() { return MediaPipeTypeStringFromTypeId(tool::GetTypeHash()); diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 2a9566fbe..8057acec6 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -49,7 +49,7 @@ namespace mediapipe { namespace { -// Create a debug std::string name for a set of edge. An edge can be either +// Create a debug string name for a set of edge. An edge can be either // a stream or a side packet. std::string DebugEdgeNames( const std::string& edge_type, @@ -413,8 +413,11 @@ absl::Status ValidatedGraphConfig::Initialize( // Set Any types based on what they connect to. MP_RETURN_IF_ERROR(ResolveAnyTypes(&input_streams_, &output_streams_)); + MP_RETURN_IF_ERROR(ResolveOneOfTypes(&input_streams_, &output_streams_)); MP_RETURN_IF_ERROR( ResolveAnyTypes(&input_side_packets_, &output_side_packets_)); + MP_RETURN_IF_ERROR( + ResolveOneOfTypes(&input_side_packets_, &output_side_packets_)); // Validate consistency of side packets and streams. 
MP_RETURN_IF_ERROR(ValidateSidePacketTypes()); @@ -908,6 +911,29 @@ absl::Status ValidatedGraphConfig::ResolveAnyTypes( return absl::OkStatus(); } +absl::Status ValidatedGraphConfig::ResolveOneOfTypes( + std::vector* input_edges, std::vector* output_edges) { + for (EdgeInfo& input_edge : *input_edges) { + if (input_edge.upstream == -1) { + continue; + } + EdgeInfo& output_edge = (*output_edges)[input_edge.upstream]; + PacketType* input_root = input_edge.packet_type->GetSameAs(); + PacketType* output_root = output_edge.packet_type->GetSameAs(); + if (!input_root->IsConsistentWith(*output_root)) continue; + // We narrow down OneOf types here if the other side is a single type. + // We do not currently intersect multiple OneOf types. + // Note that this is sensitive to the order edges are examined. + // TODO: we should be more thorough. + if (input_root->IsOneOf() && output_root->IsExactType()) { + input_root->SetSameAs(output_edge.packet_type); + } else if (output_root->IsOneOf() && input_root->IsExactType()) { + output_root->SetSameAs(input_edge.packet_type); + } + } + return absl::OkStatus(); +} + absl::Status ValidatedGraphConfig::ValidateStreamTypes() { for (const EdgeInfo& stream : input_streams_) { RET_CHECK_NE(stream.upstream, -1); diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index c5ffa45b5..aee605f98 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -133,12 +133,11 @@ class NodeTypeInfo { // This function is only valid for a NodeTypeInfo of NodeType CALCULATOR. bool AddSource(int index) { return ancestor_sources_.insert(index).second; } - // Convert the NodeType enum into a std::string (generally for error - // messaging). + // Convert the NodeType enum into a string (generally for error messaging). 
static std::string NodeTypeToString(NodeType node_type); - // Returns the name of the specified InputStreamHandler, or empty std::string - // if none set. + // Returns the name of the specified InputStreamHandler, or empty string if + // none set. std::string GetInputStreamHandler() const { return contract_.GetInputStreamHandler(); } @@ -383,6 +382,9 @@ class ValidatedGraphConfig { // Infer the type of types set to "Any" by what they are connected to. absl::Status ResolveAnyTypes(std::vector* input_edges, std::vector* output_edges); + // Narrow down OneOf types if they other end is a single type. + absl::Status ResolveOneOfTypes(std::vector* input_edges, + std::vector* output_edges); // Returns an error if the generator graph does not have consistent // type specifications for side packets. diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 2d163b4b6..3782e1eee 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@bazel_skylib//lib:selects.bzl", "selects") load("//mediapipe/gpu:metal.bzl", "metal_library") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") +load("//mediapipe/framework:more_selects.bzl", "more_selects") licenses(["notice"]) @@ -168,6 +170,8 @@ cc_library( ":gl_base", ":gl_thread_collector", ":gpu_buffer_format", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "//mediapipe/framework:executor", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -185,6 +189,11 @@ cc_library( "//mediapipe:apple": [ "//mediapipe/objc:CFHolder", ], + }) + select({ + "//conditions:default": [ + ], + "//mediapipe:ios": [], + "//mediapipe:macos": [], }), ) @@ -199,6 +208,7 @@ cc_library( ":gl_texture_view", ":gpu_buffer_format", ":gpu_buffer_storage", + ":gpu_buffer_storage_image_frame", # TODO: remove this dependency. Some other teams' tests # depend on having an indirect image_frame dependency, need to be # fixed first. @@ -215,6 +225,29 @@ cc_library( deps = [ ":gl_base", ":gl_context", + ":gpu_buffer_storage", + ], +) + +# Workaround for "Multiple matches are not allowed unless one is unambiguously more specialized". 
+more_selects.config_setting_negation( + name = "not_disable_gpu", + negate = ":disable_gpu", +) + +selects.config_setting_group( + name = "platform_ios_with_gpu", + match_all = [ + ":not_disable_gpu", + "//mediapipe:ios", + ], +) + +selects.config_setting_group( + name = "platform_macos_with_gpu", + match_all = [ + ":not_disable_gpu", + "//mediapipe:macos", ], ) @@ -224,25 +257,28 @@ cc_library( hdrs = ["gpu_buffer.h"], visibility = ["//visibility:public"], deps = [ - ":gl_base", - ":gl_context", ":gpu_buffer_format", ":gpu_buffer_storage", - ":gl_texture_view", "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:logging", + ":gpu_buffer_storage_image_frame", ] + select({ "//conditions:default": [ + ":gl_texture_view", ":gl_texture_buffer", ], - "//mediapipe:ios": [ + ":platform_ios_with_gpu": [ + ":gl_texture_view", ":gpu_buffer_storage_cv_pixel_buffer", "//mediapipe/objc:util", "//mediapipe/objc:CFHolder", ], - "//mediapipe:macos": [ + ":platform_macos_with_gpu": [ "//mediapipe/objc:CFHolder", + ":gl_texture_view", ":gl_texture_buffer", ], + ":disable_gpu": [], }), ) @@ -252,24 +288,28 @@ cc_library( hdrs = ["gpu_buffer_format.h"], visibility = ["//visibility:public"], deps = [ - ":gl_base", - "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/port:logging", "@com_google_absl//absl/container:flat_hash_map", - ], + "//mediapipe/framework/deps:no_destructor", + "//mediapipe/framework/port:logging", + ] + select({ + "//conditions:default": [ + ":gl_base", + ], + "//mediapipe/gpu:disable_gpu": [], + }), ) cc_library( name = "gpu_buffer_storage", + srcs = ["gpu_buffer_storage.cc"], hdrs = ["gpu_buffer_storage.h"], visibility = ["//visibility:public"], deps = [ - ":gl_base", ":gpu_buffer_format", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/port:logging", + "//mediapipe/framework/tool:type_util", 
"@com_google_absl//absl/container:flat_hash_map", ], ) @@ -284,11 +324,36 @@ cc_library( ":gl_context", ":gl_texture_view", ":gpu_buffer_storage", + ":gpu_buffer_storage_image_frame", + ":image_frame_view", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], ) +cc_library( + name = "gpu_buffer_storage_image_frame", + hdrs = ["gpu_buffer_storage_image_frame.h"], + visibility = ["//visibility:public"], + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage", + ":image_frame_view", + "//mediapipe/framework/formats:image_frame", + ], +) + +cc_library( + name = "image_frame_view", + hdrs = ["image_frame_view.h"], + visibility = ["//visibility:public"], + deps = [ + ":gpu_buffer_format", + ":gpu_buffer_storage", + "//mediapipe/framework/formats:image_frame", + ], +) + mediapipe_proto_library( name = "gpu_origin_proto", srcs = ["gpu_origin.proto"], @@ -326,6 +391,7 @@ objc_library( "-Wno-shorten-64-to-32", "-std=c++17", ], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -418,9 +484,6 @@ alias( cc_library( name = "gpu_shared_data_internal_stub", - hdrs = [ - "gpu_shared_data_internal.h", - ], visibility = ["//visibility:private"], deps = [ ":graph_support", @@ -587,12 +650,14 @@ cc_library( deps = [ ":gl_base", ":gl_context", + ":gl_texture_buffer_pool", ":gpu_buffer", ":gpu_buffer_format", ":gpu_buffer_multi_pool", ":gpu_shared_data_internal", ":gpu_service", ":graph_support", + ":image_frame_view", ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_cc_proto", @@ -644,6 +709,7 @@ objc_library( "-Wno-shorten-64-to-32", "-std=c++17", ], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -735,6 +801,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gl_calculator_helper", + ":gpu_buffer_storage_image_frame", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:status", @@ -861,6 +928,7 
@@ objc_library( name = "metal_copy_calculator", srcs = ["MetalCopyCalculator.mm"], copts = ["-std=c++17"], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -879,6 +947,7 @@ objc_library( name = "metal_rgb_weight_calculator", srcs = ["MetalRgbWeightCalculator.mm"], copts = ["-std=c++17"], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -896,6 +965,7 @@ objc_library( name = "metal_sobel_calculator", srcs = ["MetalSobelCalculator.mm"], copts = ["-std=c++17"], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -913,6 +983,7 @@ objc_library( name = "metal_sobel_compute_calculator", srcs = ["MetalSobelComputeCalculator.mm"], copts = ["-std=c++17"], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -930,6 +1001,7 @@ objc_library( name = "mps_sobel_calculator", srcs = ["MPSSobelCalculator.mm"], copts = ["-std=c++17"], + features = ["-layering_check"], sdk_frameworks = [ "CoreVideo", "Metal", @@ -986,6 +1058,7 @@ objc_library( data = [ "//mediapipe/objc:testdata/googlelogo_color_272x92dp.png", ], + features = ["-layering_check"], deps = [ ":MPPGraphGPUData", ":gl_scaler_calculator", diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm index 4bae8cdc7..001d4e888 100644 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ b/mediapipe/gpu/MPPGraphGPUData.mm @@ -40,8 +40,8 @@ typedef CVOpenGLTextureCacheRef CVTextureCacheType; typedef CVOpenGLESTextureCacheRef CVTextureCacheType; #endif // TARGET_OS_OSX -- (instancetype)initWithContext:(mediapipe::GlContext*)context - multiPool:(mediapipe::GpuBufferMultiPool*)pool { +- (instancetype)initWithContext:(mediapipe::GlContext *)context + multiPool:(mediapipe::GpuBufferMultiPool *)pool { self = [super init]; if (self) { _gpuBufferPool = pool; diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index aeb6cd58c..ce6620972 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ 
b/mediapipe/gpu/MPPMetalHelper.mm @@ -122,8 +122,7 @@ class MetalHelperLegacySupport { - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer plane:(size_t)plane { - - CVPixelBufferRef pixel_buffer = gpuBuffer.GetCVPixelBufferRef(); + CVPixelBufferRef pixel_buffer = mediapipe::GetCVPixelBufferRef(gpuBuffer); OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); MTLPixelFormat metalPixelFormat = MTLPixelFormatInvalid; @@ -170,7 +169,7 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( - NULL, _gpuShared.mtlTextureCache, gpuBuffer.GetCVPixelBufferRef(), NULL, + NULL, _gpuShared.mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, &texture); CHECK_EQ(err, kCVReturnSuccess); return texture; diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index d015c59e7..ba1423977 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -35,9 +35,13 @@ GlCalculatorHelper::~GlCalculatorHelper() {} absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { CHECK(cc); + auto gpu_service = cc->Service(kGpuService); + RET_CHECK(gpu_service.IsAvailable()) + << "GPU service not available. 
Did you forget to call " + "GlCalculatorHelper::UpdateContract?"; // TODO return error from impl_ (needs two-stage init) - impl_ = absl::make_unique( - cc, &cc->Service(kGpuService).GetObject()); + impl_ = + absl::make_unique(cc, &gpu_service.GetObject()); return absl::OkStatus(); } @@ -114,6 +118,16 @@ GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer, return impl_->CreateSourceTexture(pixel_buffer, plane); } +GpuBuffer GlCalculatorHelper::GpuBufferWithImageFrame( + std::shared_ptr image_frame) { + return impl_->GpuBufferWithImageFrame(std::move(image_frame)); +} + +GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( + const ImageFrame& image_frame) { + return impl_->GpuBufferCopyingImageFrame(image_frame); +} + void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width, int* height) { CHECK(width); diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 5ac0ab1bb..e44523202 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -15,6 +15,8 @@ #ifndef MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_H_ #define MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_H_ +#include + #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" @@ -128,8 +130,19 @@ class GlCalculatorHelper { // Convenience function for converting an ImageFrame to GpuBuffer and then // accessing it as a texture. + // This is deprecated because: 1) it encourages the use of GlTexture as a + // long-lived object; 2) it requires copying the ImageFrame's contents, + // which may not always be necessary. + ABSL_DEPRECATED("Use `GpuBufferWithImageFrame`.") GlTexture CreateSourceTexture(const ImageFrame& image_frame); + // Creates a GpuBuffer sharing ownership of image_frame. The contents of + // image_frame should not be modified after calling this. 
+ GpuBuffer GpuBufferWithImageFrame(std::shared_ptr image_frame); + + // Creates a GpuBuffer copying the contents of image_frame. + GpuBuffer GpuBufferCopyingImageFrame(const ImageFrame& image_frame); + // Extracts GpuBuffer dimensions without creating a texture. ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead") void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width, @@ -170,13 +183,13 @@ class GlCalculatorHelper { // memory. class GlTexture { public: - GlTexture() {} - ~GlTexture() { Release(); } + GlTexture() : view_(std::make_shared()) {} + ~GlTexture() = default; - int width() const { return view_.width(); } - int height() const { return view_.height(); } - GLenum target() const { return view_.target(); } - GLuint name() const { return view_.name(); } + int width() const { return view_->width(); } + int height() const { return view_->height(); } + GLenum target() const { return view_->target(); } + GLuint name() const { return view_->name(); } // Returns a buffer that can be sent to another calculator. 
// & manages sync token @@ -185,12 +198,13 @@ class GlTexture { std::unique_ptr GetFrame() const; // Releases texture memory & manages sync token - void Release() { view_.Release(); } + void Release() { view_ = std::make_shared(); } private: - explicit GlTexture(GlTextureView view) : view_(std::move(view)) {} + explicit GlTexture(GlTextureView view) + : view_(std::make_shared(std::move(view))) {} friend class GlCalculatorHelperImpl; - GlTextureView view_; + std::shared_ptr view_; }; // Returns the entry with the given tag if the collection uses tags, with the diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h index c17c724ad..72b3265fe 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ b/mediapipe/gpu/gl_calculator_helper_impl.h @@ -50,6 +50,9 @@ class GlCalculatorHelperImpl { GlTexture CreateDestinationTexture(int output_width, int output_height, GpuBufferFormat format); + GpuBuffer GpuBufferWithImageFrame(std::shared_ptr image_frame); + GpuBuffer GpuBufferCopyingImageFrame(const ImageFrame& image_frame); + GLuint framebuffer() const { return framebuffer_; } void BindFramebuffer(const GlTexture& dst); diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc index 0bcf089b0..8dd03bfde 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_common.cc @@ -14,10 +14,12 @@ #include +#include "absl/memory/memory.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gl_calculator_helper_impl.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" +#include "mediapipe/gpu/image_frame_view.h" namespace mediapipe { @@ -117,18 +119,42 @@ GlTexture GlCalculatorHelperImpl::CreateSourceTexture( GlTexture GlCalculatorHelperImpl::CreateSourceTexture( const ImageFrame& image_frame) { - auto gpu_buffer = GpuBuffer::CopyingImageFrame(image_frame); + 
auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); } +GpuBuffer GlCalculatorHelperImpl::GpuBufferWithImageFrame( + std::shared_ptr image_frame) { + return GpuBuffer( + std::make_shared(std::move(image_frame))); +} + +GpuBuffer GlCalculatorHelperImpl::GpuBufferCopyingImageFrame( + const ImageFrame& image_frame) { +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); + // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only + // deals with absl::Status in MediaPipe OSS. + CHECK_OK(maybe_buffer.status()); + return GpuBuffer(std::move(maybe_buffer).value()); +#else + return GpuBuffer(GlTextureBuffer::Create(image_frame)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +} + template <> std::unique_ptr GlTexture::GetFrame() const { - return view_.gpu_buffer().AsImageFrame(); + view_->DoneWriting(); + std::shared_ptr view = + view_->gpu_buffer().GetReadView(); + auto copy = absl::make_unique(); + copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); + return copy; } template <> std::unique_ptr GlTexture::GetFrame() const { - auto gpu_buffer = view_.gpu_buffer(); + auto gpu_buffer = view_->gpu_buffer(); #ifdef __EMSCRIPTEN__ // When WebGL is used, the GL context may be spontaneously lost which can // cause GpuBuffer allocations to fail. 
In that case, return a dummy buffer @@ -137,7 +163,7 @@ std::unique_ptr GlTexture::GetFrame() const { return std::make_unique(); } #endif // __EMSCRIPTEN__ - view_.DoneWriting(); + view_->DoneWriting(); return absl::make_unique(gpu_buffer); } diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index c27a8b44e..25e969d2f 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -304,7 +304,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { glGetIntegerv(GL_MINOR_VERSION, &gl_minor_version_); } else { // GL_MAJOR_VERSION is not supported on GL versions below 3. We have to - // parse the version std::string. + // parse the version string. if (!ParseGlVersion(version_string, &gl_major_version_, &gl_minor_version_)) { LOG(WARNING) << "invalid GL_VERSION format: '" << version_string @@ -344,7 +344,8 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { #if GL_ES_VERSION_2_0 // This actually means "is GLES available". // No linear float filtering by default, check extensions. can_linear_filter_float_textures_ = - HasGlExtension("OES_texture_float_linear"); + HasGlExtension("OES_texture_float_linear") || + HasGlExtension("GL_OES_texture_float_linear"); #else // Desktop GL should always allow linear filtering. can_linear_filter_float_textures_ = true; @@ -548,7 +549,11 @@ class GlFenceSyncPoint : public GlSyncPoint { : GlSyncPoint(gl_context) { gl_context_->Run([this] { sync_ = glFenceSync(GL_SYNC_GPU_COMMANDS_COMPLETE, 0); + // Defer the flush for WebGL until the glClientWaitSync call as it's a + // costly IPC call in Chrome's WebGL implementation. 
+#ifndef __EMSCRIPTEN__ glFlush(); +#endif }); } @@ -565,8 +570,17 @@ class GlFenceSyncPoint : public GlSyncPoint { void Wait() override { if (!sync_) return; gl_context_->Run([this] { - GLenum result = - glClientWaitSync(sync_, 0, std::numeric_limits::max()); + GLuint flags = 0; + uint64_t timeout = std::numeric_limits::max(); +#ifdef __EMSCRIPTEN__ + // Setting GL_SYNC_FLUSH_COMMANDS_BIT ensures flush happens before we wait + // on the fence. This is necessary since we defer the flush on WebGL. + flags = GL_SYNC_FLUSH_COMMANDS_BIT; + // WebGL only supports small implementation dependent timeout values. In + // particular, Chrome only supports a timeout of 0. + timeout = 0; +#endif + GLenum result = glClientWaitSync(sync_, flags, timeout); if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { glDeleteSync(sync_); sync_ = nullptr; @@ -592,7 +606,13 @@ class GlFenceSyncPoint : public GlSyncPoint { bool ready = false; // TODO: we should not block on the original context if possible. gl_context_->Run([this, &ready] { - GLenum result = glClientWaitSync(sync_, 0, 0); + GLuint flags = 0; +#ifdef __EMSCRIPTEN__ + // Setting GL_SYNC_FLUSH_COMMANDS_BIT ensures flush happens before we wait + // on the fence. This is necessary since we defer the flush on WebGL. + flags = GL_SYNC_FLUSH_COMMANDS_BIT; +#endif + GLenum result = glClientWaitSync(sync_, flags, 0); if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { glDeleteSync(sync_); sync_ = nullptr; diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 6cab706a5..9e798f98a 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -228,6 +228,12 @@ class GlContext : public std::enable_shared_from_this { CVOpenGLTextureCacheRef cv_texture_cache() const { return *texture_cache_; } #endif // HAS_EGL + // Returns whatever the current platform's native context handle is. 
+ // Prefer the explicit *_context methods above, unless you're going to use + // this in a context that you are sure will work with whatever definition of + // PlatformGlContext is in use. + PlatformGlContext native_context() const { return context_; } + // Check if the context is current on this thread. Mainly for test purposes. bool IsCurrent() const; @@ -432,4 +438,5 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, int plane); } // namespace mediapipe + #endif // MEDIAPIPE_GPU_GL_CONTEXT_H_ diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 9386f2ce2..44ddd9314 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -15,6 +15,8 @@ #include #include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -30,6 +32,8 @@ namespace mediapipe { +namespace { + static pthread_key_t egl_release_thread_key; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; @@ -67,6 +71,29 @@ static void EnsureEglThreadRelease() { reinterpret_cast(0xDEADBEEF)); } +static absl::StatusOr GetInitializedDefaultEglDisplay() { + EGLDisplay display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + RET_CHECK(display != EGL_NO_DISPLAY) + << "eglGetDisplay() returned error " << std::showbase << std::hex + << eglGetError(); + + EGLint major = 0; + EGLint minor = 0; + EGLBoolean egl_initialized = eglInitialize(display, &major, &minor); + RET_CHECK(egl_initialized) << "Unable to initialize EGL"; + LOG(INFO) << "Successfully initialized EGL. 
Major : " << major + << " Minor: " << minor; + + return display; +} + +static absl::StatusOr GetInitializedEglDisplay() { + auto status_or_display = GetInitializedDefaultEglDisplay(); + return status_or_display; +} + +} // namespace + GlContext::StatusOrGlContext GlContext::Create(std::nullptr_t nullp, bool create_thread) { return Create(EGL_NO_CONTEXT, create_thread); @@ -149,18 +176,7 @@ absl::Status GlContext::CreateContextInternal(EGLContext share_context, } absl::Status GlContext::CreateContext(EGLContext share_context) { - EGLint major = 0; - EGLint minor = 0; - - display_ = eglGetDisplay(EGL_DEFAULT_DISPLAY); - RET_CHECK(display_ != EGL_NO_DISPLAY) - << "eglGetDisplay() returned error " << std::showbase << std::hex - << eglGetError(); - - EGLBoolean success = eglInitialize(display_, &major, &minor); - RET_CHECK(success) << "Unable to initialize EGL"; - LOG(INFO) << "Successfully initialized EGL. Major : " << major - << " Minor: " << minor; + ASSIGN_OR_RETURN(display_, GetInitializedEglDisplay()); auto status = CreateContextInternal(share_context, 3); if (!status.ok()) { diff --git a/mediapipe/gpu/gl_quad_renderer.h b/mediapipe/gpu/gl_quad_renderer.h index 7e2c44f1c..4ef6c7669 100644 --- a/mediapipe/gpu/gl_quad_renderer.h +++ b/mediapipe/gpu/gl_quad_renderer.h @@ -81,7 +81,7 @@ class QuadRenderer { GLuint program_ = 0; GLint scale_unif_ = -1; std::vector frame_unifs_; - GLuint vao_; // vertex array object + GLuint vao_ = 0; // vertex array object GLuint vbo_[2] = {0, 0}; // for vertex buffer storage }; diff --git a/mediapipe/gpu/gl_simple_shaders.cc b/mediapipe/gpu/gl_simple_shaders.cc index 5cb718dc3..72e3cfc18 100644 --- a/mediapipe/gpu/gl_simple_shaders.cc +++ b/mediapipe/gpu/gl_simple_shaders.cc @@ -16,7 +16,7 @@ namespace mediapipe { -// This macro converts everything between its parentheses to a std::string. +// This macro converts everything between its parentheses to a string. 
// Using this instead of R"()" preserves C-like syntax coloring in most // editors, which is desirable for shaders. #if !defined(_STRINGIFY) @@ -28,15 +28,15 @@ namespace mediapipe { // for a type. The macro strips out the precision declaration on desktop GL, // where it's not supported. // -// Note: this does not use a raw std::string because some compilers don't handle -// raw strings inside macros correctly. It uses a macro because we want to be -// able to concatenate strings by juxtaposition. We want to concatenate strings -// by juxtaposition so we can export const char* static data containing the +// Note: this does not use a raw string because some compilers don't handle raw +// strings inside macros correctly. It uses a macro because we want to be able +// to concatenate strings by juxtaposition. We want to concatenate strings by +// juxtaposition so we can export const char* static data containing the // pre-expanded strings. // // TODO: this was written before we could rely on C++11 support. -// Consider replacing it with constexpr std::string concatenation, or replacing -// the static variables with functions. +// Consider replacing it with constexpr string concatenation, or replacing the +// static variables with functions. 
#define PRECISION_COMPAT \ GLES_VERSION_COMPAT \ "#ifdef GL_ES \n" \ diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index eab0f59fe..d48b35a05 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gl_texture_view.h" +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" namespace mediapipe { @@ -228,14 +229,14 @@ void GlTextureBuffer::WaitForConsumersOnGpu() { } GlTextureView GlTextureBuffer::GetReadView( - mediapipe::internal::types, - std::shared_ptr gpu_buffer, int plane) const { + internal::types, std::shared_ptr gpu_buffer, + int plane) const { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); // Insert wait call to sync with the producer. WaitOnGpu(); - GlTextureView::DetachFn detach = [this](mediapipe::GlTextureView& texture) { + GlTextureView::DetachFn detach = [this](GlTextureView& texture) { // Inform the GlTextureBuffer that we have finished accessing its // contents, and create a consumer sync point. DidRead(texture.gl_context()->CreateSyncToken()); @@ -246,8 +247,8 @@ GlTextureView GlTextureBuffer::GetReadView( } GlTextureView GlTextureBuffer::GetWriteView( - mediapipe::internal::types, - std::shared_ptr gpu_buffer, int plane) { + internal::types, std::shared_ptr gpu_buffer, + int plane) { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); @@ -256,9 +257,7 @@ GlTextureView GlTextureBuffer::GetWriteView( Reuse(); // TODO: the producer wait should probably be part of Reuse in the // case when there are no consumers. 
GlTextureView::DoneWritingFn done_writing = - [this](const mediapipe::GlTextureView& texture) { - ViewDoneWriting(texture); - }; + [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), std::move(gpu_buffer), plane, nullptr, std::move(done_writing)); @@ -311,46 +310,52 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format, view.plane(), view.gl_context()->GetGlVersion()); - GLint current_fbo; - glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo); - CHECK_NE(current_fbo, 0); + GLint previous_fbo; + glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo); - GLint color_attachment_name; - glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME, - &color_attachment_name); - if (color_attachment_name != view.name()) { - // Save the viewport. Note that we assume that the color attachment is a - // GL_TEXTURE_2D texture. - GLint viewport[4]; - glGetIntegerv(GL_VIEWPORT, viewport); - - // Set the data from GLTextureView object. - glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); - glReadPixels(0, 0, view.width(), view.height(), info.gl_format, - info.gl_type, output); - - // Restore from the saved viewport and color attachment name. - glViewport(viewport[0], viewport[1], viewport[2], viewport[3]); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, - color_attachment_name, 0); - } else { - glReadPixels(0, 0, view.width(), view.height(), info.gl_format, - info.gl_type, output); - } + // We use a temp fbo to avoid depending on the app having an existing one. + // TODO: keep a utility fbo around in the context? 
+ GLuint fbo = 0; + glGenFramebuffers(1, &fbo); + glBindFramebuffer(GL_FRAMEBUFFER, fbo); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glReadPixels(0, 0, view.width(), view.height(), info.gl_format, info.gl_type, + output); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0, + 0); + // TODO: just set the binding to 0 to avoid the get call? + glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo); + glDeleteFramebuffers(1, &fbo); } -std::unique_ptr GlTextureBuffer::AsImageFrame() const { - ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format()); - auto output = absl::make_unique( - image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary); - auto view = - GetReadView(mediapipe::internal::types{}, nullptr, 0); - ReadTexture(view, format(), output->MutablePixelData(), - output->PixelDataSize()); - return output; +static std::shared_ptr ConvertToImageFrame( + std::shared_ptr buf) { + ImageFormat::Format image_format = + ImageFormatForGpuBufferFormat(buf->format()); + auto output = + absl::make_unique(image_format, buf->width(), buf->height(), + ImageFrame::kGlDefaultAlignmentBoundary); + buf->GetProducerContext()->Run([buf, &output] { + auto view = buf->GetReadView(internal::types{}, nullptr, 0); + ReadTexture(view, buf->format(), output->MutablePixelData(), + output->PixelDataSize()); + }); + return std::make_shared(std::move(output)); } +static std::shared_ptr ConvertFromImageFrame( + std::shared_ptr frame) { + return GlTextureBuffer::Create(*frame->image_frame()); +} + +static auto kConverterRegistration = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertToImageFrame); +static auto kConverterRegistration2 = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertFromImageFrame); + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 878255fc2..124a0ec2f 
100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -24,6 +24,7 @@ #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_context.h" +#include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" @@ -33,8 +34,8 @@ class GlCalculatorHelperImpl; // Implements a GPU memory buffer as an OpenGL texture. For internal use. class GlTextureBuffer - : public mediapipe::internal::GpuBufferStorageImpl< - GlTextureBuffer, mediapipe::internal::ViewProvider> { + : public internal::GpuBufferStorageImpl< + GlTextureBuffer, internal::ViewProvider> { public: // This is called when the texture buffer is deleted. It is passed a sync // token created at that time on the GlContext. If the GlTextureBuffer has @@ -88,13 +89,12 @@ class GlTextureBuffer int height() const { return height_; } GpuBufferFormat format() const { return format_; } - GlTextureView GetReadView(mediapipe::internal::types, + GlTextureView GetReadView(internal::types, std::shared_ptr gpu_buffer, int plane) const override; - GlTextureView GetWriteView(mediapipe::internal::types, + GlTextureView GetWriteView(internal::types, std::shared_ptr gpu_buffer, int plane) override; - std::unique_ptr AsImageFrame() const override; // If this texture is going to be used outside of the context that produced // it, this method should be called to ensure that its updated contents are diff --git a/mediapipe/gpu/gl_texture_view.cc b/mediapipe/gpu/gl_texture_view.cc index adb642153..5d1862ddc 100644 --- a/mediapipe/gpu/gl_texture_view.cc +++ b/mediapipe/gpu/gl_texture_view.cc @@ -3,6 +3,7 @@ namespace mediapipe { void GlTextureView::Release() { + DoneWriting(); if (detach_) detach_(*this); detach_ = nullptr; gl_context_ = nullptr; @@ -13,4 +14,11 @@ void GlTextureView::Release() { height_ = 0; } +void GlTextureView::DoneWriting() const { + if (done_writing_) { + 
done_writing_(*this); + done_writing_ = nullptr; + } +} + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 76116f63d..1f0a23f31 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -20,6 +20,7 @@ #include #include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" namespace mediapipe { @@ -31,7 +32,23 @@ class GlTextureView { public: GlTextureView() {} ~GlTextureView() { Release(); } - // TODO: make this class move-only. + GlTextureView(const GlTextureView&) = delete; + GlTextureView(GlTextureView&& other) { *this = std::move(other); } + GlTextureView& operator=(const GlTextureView&) = delete; + GlTextureView& operator=(GlTextureView&& other) { + DoneWriting(); + if (detach_) detach_(*this); + gl_context_ = other.gl_context_; + target_ = other.target_; + name_ = other.name_; + width_ = other.width_; + height_ = other.height_; + gpu_buffer_ = std::move(other.gpu_buffer_); + plane_ = other.plane_; + detach_ = std::exchange(other.detach_, nullptr); + done_writing_ = std::exchange(other.done_writing_, nullptr); + return *this; + } GlContext* gl_context() const { return gl_context_; } int width() const { return width_; } @@ -63,11 +80,11 @@ class GlTextureView { // TODO: remove this friend declaration. friend class GlTexture; + void Release(); + // TODO: make this non-const. 
- void DoneWriting() const { - if (done_writing_) done_writing_(*this); - } + void DoneWriting() const; GlContext* gl_context_ = nullptr; GLenum target_ = GL_TEXTURE_2D; @@ -78,9 +95,31 @@ class GlTextureView { std::shared_ptr gpu_buffer_; // using shared_ptr temporarily int plane_ = 0; DetachFn detach_; - DoneWritingFn done_writing_; + mutable DoneWritingFn done_writing_; }; +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + // Note that the view type is encoded in an argument to allow overloading, + // so a storage class can implement GetRead/WriteView for multiple view types. + // We cannot use a template function because it cannot be virtual; we want to + // have a virtual function here to enforce that different storages supporting + // the same view implement the same signature. + // Note that we allow different views to have custom signatures, providing + // additional view-specific arguments that may be needed. + virtual GlTextureView GetReadView(types, + std::shared_ptr gpu_buffer, + int plane) const = 0; + virtual GlTextureView GetWriteView(types, + std::shared_ptr gpu_buffer, + int plane) = 0; +}; + +} // namespace internal } // namespace mediapipe #endif // MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_ diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index a9b68f82f..bb215dbbd 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -1,6 +1,8 @@ #include "mediapipe/gpu/gpu_buffer.h" -#include "mediapipe/gpu/gl_context.h" +#include + +#include "mediapipe/framework/port/logging.h" #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #include "mediapipe/objc/util.h" @@ -8,22 +10,61 @@ namespace mediapipe { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +internal::GpuBufferStorage& GpuBuffer::GetStorageForView( + TypeRef view_provider_type, bool for_writing) const { + const std::shared_ptr* chosen_storage = nullptr; -GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& 
image_frame) { - auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); - // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only - // deals with absl::Status in MediaPipe OSS. - CHECK_OK(maybe_buffer.status()); - return GpuBuffer(std::move(maybe_buffer).value()); + // First see if any current storage supports the view. + for (const auto& s : storages_) { + if (s->can_down_cast_to(view_provider_type)) { + chosen_storage = &s; + break; + } + } + + // Then try to convert existing storages to one that does. + // TODO: choose best conversion. + if (!chosen_storage) { + for (const auto& s : storages_) { + auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider(view_provider_type, + s->storage_type()); + if (converter) { + storages_.push_back(converter(s)); + chosen_storage = &storages_.back(); + } + } + } + + if (for_writing) { + if (!chosen_storage) { + // Allocate a new storage supporting the requested view. + auto factory = internal::GpuBufferStorageRegistry::Get() + .StorageFactoryForViewProvider(view_provider_type); + if (factory) { + storages_ = {factory(width(), height(), format())}; + chosen_storage = &storages_.back(); + } + } else { + // Discard all other storages. 
+ storages_ = {*chosen_storage}; + chosen_storage = &storages_.back(); + } + } + + CHECK(chosen_storage) << "no view provider found"; + DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); + return **chosen_storage; +} + +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer) { + auto p = buffer.internal_storage(); + if (p) return **p; + return nullptr; } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) { - return GpuBuffer(GlTextureBuffer::Create(image_frame)); -} - -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // !MEDIAPIPE_DISABLE_GPU } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index dfeebdb32..352ffe1e9 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -19,11 +19,11 @@ #include #include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/gpu/gl_base.h" -#include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_texture_view.h" // Note: these headers are needed for the legacy storage APIs. Do not add more // storage-specific headers here. See WebGpuTextureBuffer/View for an example // of adding a new storage and view. @@ -39,19 +39,26 @@ #else #include "mediapipe/gpu/gl_texture_buffer.h" #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // MEDIAPIPE_DISABLE_GPU namespace mediapipe { -class GlContext; - // This class wraps a platform-specific buffer of GPU data. // An instance of GpuBuffer acts as an opaque reference to the underlying // data object. class GpuBuffer { public: + using Format = GpuBufferFormat; + // Default constructor creates invalid object. 
GpuBuffer() = default; + // Creates an empty buffer of a given size and format. It will be allocated + // when a view is requested. + GpuBuffer(int width, int height, Format format) + : GpuBuffer(std::make_shared(width, height, + format)) {} + // Copy and move constructors and assignment operators are supported. GpuBuffer(const GpuBuffer& other) = default; GpuBuffer(GpuBuffer&& other) = default; @@ -63,30 +70,17 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. - explicit GpuBuffer( - std::shared_ptr storage) - : storage_(std::move(storage)) {} - - // Note: these constructors and accessors for specific storage types exist - // for backwards compatibility reasons. Do not add new ones. -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - explicit GpuBuffer(CFHolder pixel_buffer) - : storage_(std::make_shared( - std::move(pixel_buffer))) {} - explicit GpuBuffer(CVPixelBufferRef pixel_buffer) - : storage_( - std::make_shared(pixel_buffer)) {} - - CVPixelBufferRef GetCVPixelBufferRef() const { - auto p = storage_->down_cast(); - if (p) return **p; - return nullptr; + explicit GpuBuffer(std::shared_ptr storage) { + storages_.push_back(std::move(storage)); } -#else - GlTextureBufferSharedPtr GetGlTextureBufferSharedPtr() const { - return internal_storage(); - } -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + + // This is used to support backward-compatible construction of GpuBuffer from + // some platform-specific types without having to make those types visible in + // this header. + template ()))>> + explicit GpuBuffer(T&& storage_convertible) + : GpuBuffer(internal::AsGpuBufferStorage(storage_convertible)) {} int width() const { return current_storage().width(); } int height() const { return current_storage().height(); } @@ -108,35 +102,19 @@ class GpuBuffer { // Gets a read view of the specified type. 
The arguments depend on the // specific view type; see the corresponding ViewProvider. template - auto GetReadView(Args... args) const { - return current_storage() - .down_cast>() - ->GetReadView(mediapipe::internal::types{}, - std::make_shared(*this), - std::forward(args)...); + decltype(auto) GetReadView(Args... args) const { + return GetViewProvider(false)->GetReadView( + internal::types{}, std::make_shared(*this), + std::forward(args)...); } // Gets a write view of the specified type. The arguments depend on the // specific view type; see the corresponding ViewProvider. template - auto GetWriteView(Args... args) { - return current_storage() - .down_cast>() - ->GetWriteView(mediapipe::internal::types{}, - std::make_shared(*this), - std::forward(args)...); - } - - // Make a GpuBuffer copying the data from an ImageFrame. - static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame); - - // Make an ImageFrame, possibly sharing the same data. The data is shared if - // the GpuBuffer's storage supports memory sharing; otherwise, it is copied. - // In order to work correctly across platforms, callers should always treat - // the returned ImageFrame as if it shares memory with the GpuBuffer, i.e. - // treat it as immutable if the GpuBuffer must not be modified. - std::unique_ptr AsImageFrame() const { - return current_storage().AsImageFrame(); + decltype(auto) GetWriteView(Args... args) { + return GetViewProvider(true)->GetWriteView( + internal::types{}, std::make_shared(*this), + std::forward(args)...); } // Attempts to access an underlying storage object of the specified type. @@ -144,55 +122,79 @@ class GpuBuffer { // using views. 
template std::shared_ptr internal_storage() const { - if (storage_->down_cast()) return std::static_pointer_cast(storage_); + for (const auto& s : storages_) + if (s->down_cast()) return std::static_pointer_cast(s); return nullptr; } private: + using TypeRef = internal::TypeRef; + class PlaceholderGpuBufferStorage - : public mediapipe::internal::GpuBufferStorageImpl< - PlaceholderGpuBufferStorage> { + : public internal::GpuBufferStorageImpl { public: - int width() const override { return 0; } - int height() const override { return 0; } - virtual GpuBufferFormat format() const override { - return GpuBufferFormat::kUnknown; - } - std::unique_ptr AsImageFrame() const override { - return nullptr; - } + PlaceholderGpuBufferStorage(int width, int height, Format format) + : width_(width), height_(height), format_(format) {} + int width() const override { return width_; } + int height() const override { return height_; } + GpuBufferFormat format() const override { return format_; } + + private: + int width_ = 0; + int height_ = 0; + GpuBufferFormat format_ = GpuBufferFormat::kUnknown; }; - std::shared_ptr& no_storage() const { + internal::GpuBufferStorage& GetStorageForView(TypeRef view_provider_type, + bool for_writing) const; + + template + internal::ViewProvider* GetViewProvider(bool for_writing) const { + using VP = internal::ViewProvider; + return GetStorageForView(TypeRef::Get(), for_writing) + .template down_cast(); + } + + std::shared_ptr& no_storage() const { static auto placeholder = - std::static_pointer_cast( - std::make_shared()); + std::static_pointer_cast( + std::make_shared( + 0, 0, GpuBufferFormat::kUnknown)); return placeholder; } - const mediapipe::internal::GpuBufferStorage& current_storage() const { - return *storage_; + const internal::GpuBufferStorage& current_storage() const { + return storages_.empty() ? 
*no_storage() : *storages_[0]; } - mediapipe::internal::GpuBufferStorage& current_storage() { return *storage_; } + internal::GpuBufferStorage& current_storage() { + return storages_.empty() ? *no_storage() : *storages_[0]; + } - std::shared_ptr storage_ = - no_storage(); + // This is mutable because view methods that do not change the contents may + // still need to allocate new storages. + mutable std::vector> storages_; }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { - return storage_ == no_storage(); + return storages_.empty(); } inline bool GpuBuffer::operator==(const GpuBuffer& other) const { - return storage_ == other.storage_; + return storages_ == other.storages_; } inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { - storage_ = no_storage(); + storages_.clear(); return *this; } +// Note: these constructors and accessors for specific storage types exist +// for backwards compatibility reasons. Do not add new ones. +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_H_ diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 278ec444e..41c98ba43 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -28,6 +28,7 @@ namespace mediapipe { #define GL_HALF_FLOAT 0x140B #endif // GL_HALF_FLOAT +#if !MEDIAPIPE_DISABLE_GPU #ifdef GL_ES_VERSION_2_0 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { switch (info->gl_internal_format) { @@ -184,6 +185,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, CHECK_LT(plane, planes.size()) << "invalid plane number"; return planes[plane]; } +#endif // MEDIAPIPE_DISABLE_GPU ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { switch (format) { @@ -202,6 +204,8 @@ ImageFormat::Format 
ImageFormatForGpuBufferFormat(GpuBufferFormat format) { return ImageFormat::SRGB; case GpuBufferFormat::kTwoComponentFloat32: return ImageFormat::VEC32F2; + case GpuBufferFormat::kRGBA32: + // TODO: this likely maps to ImageFormat::SRGBA case GpuBufferFormat::kGrayHalf16: case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kTwoComponent8: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 66999f755..45f054d31 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -23,7 +23,9 @@ #endif // defined(__APPLE__) #include "mediapipe/framework/formats/image_format.pb.h" +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_base.h" +#endif // !MEDIAPIPE_DISABLE_GPU // The behavior of multi-char constants is implementation-defined, so out of an // excess of caution we define them in this portable way. @@ -32,9 +34,12 @@ namespace mediapipe { +using mediapipe::ImageFormat; + enum class GpuBufferFormat : uint32_t { kUnknown = 0, kBGRA32 = MEDIAPIPE_FOURCC('B', 'G', 'R', 'A'), + kRGBA32 = MEDIAPIPE_FOURCC('R', 'G', 'B', 'A'), kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'), kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'), kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'), @@ -49,6 +54,7 @@ enum class GpuBufferFormat : uint32_t { kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'), }; +#if !MEDIAPIPE_DISABLE_GPU // TODO: make this more generally applicable. 
enum class GlVersion { kGL = 1, @@ -68,6 +74,7 @@ struct GlTextureInfo { const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, int plane, GlVersion gl_version); +#endif // !MEDIAPIPE_DISABLE_GPU ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format); GpuBufferFormat GpuBufferFormatForImageFormat(ImageFormat::Format format); @@ -78,6 +85,8 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { switch (format) { case GpuBufferFormat::kBGRA32: return kCVPixelFormatType_32BGRA; + case GpuBufferFormat::kRGBA32: + return kCVPixelFormatType_32RGBA; case GpuBufferFormat::kGrayHalf16: return kCVPixelFormatType_OneComponent16Half; case GpuBufferFormat::kGrayFloat32: @@ -112,6 +121,8 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) { switch (format) { case kCVPixelFormatType_32BGRA: return GpuBufferFormat::kBGRA32; + case kCVPixelFormatType_32RGBA: + return GpuBufferFormat::kRGBA32; case kCVPixelFormatType_DepthFloat32: return GpuBufferFormat::kGrayFloat32; case kCVPixelFormatType_OneComponent16Half: diff --git a/mediapipe/gpu/gpu_buffer_storage.cc b/mediapipe/gpu/gpu_buffer_storage.cc new file mode 100644 index 000000000..2f0687653 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage.cc @@ -0,0 +1,47 @@ +#include "mediapipe/gpu/gpu_buffer_storage.h" + +namespace mediapipe { +namespace internal { + +using StorageFactory = GpuBufferStorageRegistry::StorageFactory; +using StorageConverter = GpuBufferStorageRegistry::StorageConverter; +using RegistryToken = GpuBufferStorageRegistry::RegistryToken; + +StorageFactory GpuBufferStorageRegistry::StorageFactoryForViewProvider( + TypeRef view_provider_type) { + auto it = factory_for_view_provider_.find(view_provider_type); + if (it == factory_for_view_provider_.end()) return nullptr; + return it->second; +} + +StorageConverter GpuBufferStorageRegistry::StorageConverterForViewProvider( + TypeRef view_provider_type, TypeRef existing_storage_type) { + 
auto it = converter_for_view_provider_and_existing_storage_.find( + {view_provider_type, existing_storage_type}); + if (it == converter_for_view_provider_and_existing_storage_.end()) + return nullptr; + return it->second; +} + +RegistryToken GpuBufferStorageRegistry::Register( + StorageFactory factory, std::vector provider_hashes) { + // TODO: choose between multiple factories for same provider type. + for (const auto p : provider_hashes) { + factory_for_view_provider_[p] = factory; + } + return {}; +} + +RegistryToken GpuBufferStorageRegistry::Register( + StorageConverter converter, std::vector provider_hashes, + TypeRef source_storage) { + // TODO: choose between multiple converters for same provider type. + for (const auto p : provider_hashes) { + converter_for_view_provider_and_existing_storage_[{p, source_storage}] = + converter; + } + return {}; +} + +} // namespace internal +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 03b2442c3..e8ad3f367 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -1,15 +1,17 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_ -#include "mediapipe/framework/formats/image_frame.h" +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mediapipe/framework/deps/no_destructor.h" +#include "mediapipe/framework/tool/type_util.h" #include "mediapipe/gpu/gpu_buffer_format.h" namespace mediapipe { -class GlTextureView; class GpuBuffer; -} // namespace mediapipe - -namespace mediapipe { namespace internal { template @@ -18,69 +20,181 @@ struct types {}; template class ViewProvider; -// Note: this specialization temporarily lives here for backwards compatibility -// reasons. New specializations should be put in the same file as their view. -template <> -class ViewProvider { +// An identifier for a type. 
We have often used size_t holding a hash for this +// purpose in MediaPipe, but a non-primitive type makes the code more readable. +// Ideally we should clean up the various ways this is handled throughout the +// framework and consolidate the utilities in type_util. When that is done, this +// type can be replaced. +class TypeRef { public: - virtual ~ViewProvider() = default; - // Note that the view type is encoded in an argument to allow overloading, - // so a storage class can implement GetRead/WriteView for multiple view types. - // We cannot use a template function because it cannot be virtual; we want to - // have a virtual function here to enforce that different storages supporting - // the same view implement the same signature. - // Note that we allow different views to have custom signatures, providing - // additional view-specific arguments that may be needed. - virtual GlTextureView GetReadView(types, - std::shared_ptr gpu_buffer, - int plane) const = 0; - virtual GlTextureView GetWriteView(types, - std::shared_ptr gpu_buffer, - int plane) = 0; + template + static TypeRef Get() { + return TypeRef{tool::GetTypeHash()}; + } + + bool operator==(const TypeRef& other) const { return hash_ == other.hash_; } + + template + friend H AbslHashValue(H h, const TypeRef& r) { + return H::combine(std::move(h), r.hash_); + } + + private: + explicit TypeRef(size_t hash) : hash_(hash) {} + + size_t hash_; }; +// Interface for a backing storage for GpuBuffer. class GpuBufferStorage { public: virtual ~GpuBufferStorage() = default; virtual int width() const = 0; virtual int height() const = 0; virtual GpuBufferFormat format() const = 0; - virtual std::unique_ptr AsImageFrame() const = 0; // We can't use dynamic_cast since we want to support building without RTTI. // The public methods delegate to the type-erased private virtual method. 
template T* down_cast() { - return static_cast( - const_cast(down_cast(tool::GetTypeHash()))); + return static_cast(const_cast(down_cast(TypeRef::Get()))); } template const T* down_cast() const { - return static_cast(down_cast(tool::GetTypeHash())); + return static_cast(down_cast(TypeRef::Get())); } + bool can_down_cast_to(TypeRef to) const { return down_cast(to) != nullptr; } + virtual TypeRef storage_type() const = 0; + private: - virtual const void* down_cast(size_t type_hash) const = 0; - virtual size_t storage_type_hash() const = 0; + virtual const void* down_cast(TypeRef to) const = 0; }; +// Used to disambiguate between overloads by manually specifying their priority. +// Higher Ns will be picked first. The caller should pass overload_priority +// where M is >= the largest N used in overloads (e.g. 10). +template +struct overload_priority : public overload_priority {}; +template <> +struct overload_priority<0> {}; + +// Manages the available GpuBufferStorage implementations. The list of available +// implementations is built at runtime using a registration mechanism, so that +// it can be determined by the program's dependencies. 
+class GpuBufferStorageRegistry { + public: + struct RegistryToken {}; + + using StorageFactory = std::function( + int, int, GpuBufferFormat)>; + using StorageConverter = std::function( + std::shared_ptr)>; + + static GpuBufferStorageRegistry& Get() { + static NoDestructor registry; + return *registry; + } + + template + RegistryToken Register() { + return Register( + [](int width, int height, + GpuBufferFormat format) -> std::shared_ptr { + return CreateStorage(overload_priority<10>{}, width, height, + format); + }, + Storage::GetProviderTypes()); + } + + template + RegistryToken RegisterConverter(F&& converter) { + return Register( + [converter](std::shared_ptr source) + -> std::shared_ptr { + return converter(std::static_pointer_cast(source)); + }, + StorageTo::GetProviderTypes(), TypeRef::Get()); + } + + // Returns a factory function for a storage that implements + // view_provider_type. + StorageFactory StorageFactoryForViewProvider(TypeRef view_provider_type); + + // Returns a conversion function that, given a storage of + // existing_storage_type, converts its contents to a new storage that + // implements view_provider_type. + StorageConverter StorageConverterForViewProvider( + TypeRef view_provider_type, TypeRef existing_storage_type); + + private: + template + static auto CreateStorage(overload_priority<1>, Args... args) + -> decltype(Storage::Create(args...)) { + return Storage::Create(args...); + } + + template + static auto CreateStorage(overload_priority<0>, Args... 
args) { + return std::make_shared(args...); + } + + RegistryToken Register(StorageFactory factory, + std::vector provider_hashes); + RegistryToken Register(StorageConverter converter, + std::vector provider_hashes, + TypeRef source_storage); + + absl::flat_hash_map factory_for_view_provider_; + absl::flat_hash_map, StorageConverter> + converter_for_view_provider_and_existing_storage_; +}; + +// Defining a member of this type causes P to be ODR-used, which forces its +// instantiation if it's a static member of a template. +template +struct ForceStaticInstantiation { +#ifdef _MSC_VER + // Just having it as the template argument does not count as a use for + // MSVC. + static constexpr bool Use() { return P != nullptr; } + char force_static[Use()]; +#endif // _MSC_VER +}; + +// T: storage type +// U...: ViewProvider template class GpuBufferStorageImpl : public GpuBufferStorage, public U... { - private: - virtual const void* down_cast(size_t type_hash) const override { - return down_cast_impl(type_hash, types{}); + public: + static const std::vector& GetProviderTypes() { + static std::vector kHashes{TypeRef::Get()...}; + return kHashes; } - size_t storage_type_hash() const override { return tool::GetTypeHash(); } - const void* down_cast_impl(size_t type_hash, types<>) const { - return nullptr; + private: + virtual const void* down_cast(TypeRef to) const override { + return down_cast_impl(to, types{}); } + TypeRef storage_type() const override { return TypeRef::Get(); } + + const void* down_cast_impl(TypeRef to, types<>) const { return nullptr; } template - const void* down_cast_impl(size_t type_hash, types) const { - if (type_hash == tool::GetTypeHash()) return static_cast(this); - return down_cast_impl(type_hash, types{}); + const void* down_cast_impl(TypeRef to, types) const { + if (to == TypeRef::Get()) return static_cast(this); + return down_cast_impl(to, types{}); } + + inline static auto registration = + GpuBufferStorageRegistry::Get().Register(); + using 
RequireStatics = ForceStaticInstantiation<®istration>; }; +// This function can be overridden to enable construction of a GpuBuffer from +// platform-specific types without having to expose that type in the GpuBuffer +// definition. It is only needed for backward compatibility reasons; do not add +// overrides for new types. +std::shared_ptr AsGpuBufferStorage(); + } // namespace internal } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index 9370acbc1..d68ac0db0 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -1,6 +1,9 @@ #include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h" +#include + #include "mediapipe/gpu/gl_context.h" +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/objc/util.h" namespace mediapipe { @@ -11,9 +14,20 @@ typedef CVOpenGLTextureRef CVTextureType; typedef CVOpenGLESTextureRef CVTextureType; #endif // TARGET_OS_OSX -GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( - mediapipe::internal::types, - std::shared_ptr gpu_buffer, int plane) const { +GpuBufferStorageCvPixelBuffer::GpuBufferStorageCvPixelBuffer( + int width, int height, GpuBufferFormat format) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + CVPixelBufferRef buffer; + CVReturn err = + CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + adopt(buffer); +} + +GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( + std::shared_ptr gpu_buffer, int plane, + GlTextureView::DoneWritingFn done_writing) const { CVReturn err; auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); @@ -29,8 +43,8 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( return GlTextureView( gl_context.get(), 
CVOpenGLTextureGetTarget(*cv_texture), CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane, - [cv_texture]( - mediapipe::GlTextureView&) { /* only retains cv_texture */ }); + [cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ }, + done_writing); #else const GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format(), plane, gl_context->GetGlVersion()); @@ -49,23 +63,31 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( CVOpenGLESTextureGetName(*cv_texture), width(), height(), std::move(gpu_buffer), plane, [cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ }, - // TODO: make GetGlTextureView for write view non-const, remove cast - // Note: we have to copy *this here because this storage is currently - // stored in GpuBuffer by value, and so the this pointer becomes invalid - // if the GpuBuffer is moved/copied. TODO: fix this. - [me = *this](const mediapipe::GlTextureView& view) { - const_cast(&me)->ViewDoneWriting(view); - }); + done_writing); #endif // TARGET_OS_OSX } +GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types, std::shared_ptr gpu_buffer, + int plane) const { + return GetTexture(std::move(gpu_buffer), plane, nullptr); +} + GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - mediapipe::internal::types, - std::shared_ptr gpu_buffer, int plane) { - // For this storage there is currently no difference between read and write - // views, so we delegate to the read method. 
- return GetReadView(mediapipe::internal::types{}, - std::move(gpu_buffer), plane); + internal::types, std::shared_ptr gpu_buffer, + int plane) { + return GetTexture( + std::move(gpu_buffer), plane, + [this](const mediapipe::GlTextureView& view) { ViewDoneWriting(view); }); +} + +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types, std::shared_ptr gpu_buffer) const { + return CreateImageFrameForCVPixelBuffer(**this); +} +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types, std::shared_ptr gpu_buffer) { + return CreateImageFrameForCVPixelBuffer(**this); } void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { @@ -111,9 +133,32 @@ void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { #endif } -std::unique_ptr GpuBufferStorageCvPixelBuffer::AsImageFrame() - const { - return CreateImageFrameForCVPixelBuffer(**this); +static std::shared_ptr ConvertFromImageFrame( + std::shared_ptr frame) { + auto status_or_buffer = + CreateCVPixelBufferForImageFrame(frame->image_frame()); + CHECK(status_or_buffer.ok()); + return std::make_shared( + std::move(status_or_buffer).value()); } +static auto kConverterFromImageFrameRegistration = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertFromImageFrame); + +namespace internal { +std::shared_ptr AsGpuBufferStorage( + CFHolder pixel_buffer) { + return std::make_shared( + std::move(pixel_buffer)); +} + +std::shared_ptr AsGpuBufferStorage( + CVPixelBufferRef pixel_buffer) { + return std::make_shared(pixel_buffer); +} +} // namespace internal + } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index 237138f77..017771dc7 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -5,6 +5,7 @@ #include "mediapipe/gpu/gl_texture_view.h" #include 
"mediapipe/gpu/gpu_buffer_storage.h" +#include "mediapipe/gpu/image_frame_view.h" #include "mediapipe/objc/CFHolder.h" namespace mediapipe { @@ -12,12 +13,13 @@ namespace mediapipe { class GlContext; class GpuBufferStorageCvPixelBuffer - : public mediapipe::internal::GpuBufferStorageImpl< - GpuBufferStorageCvPixelBuffer, - mediapipe::internal::ViewProvider>, + : public internal::GpuBufferStorageImpl< + GpuBufferStorageCvPixelBuffer, internal::ViewProvider, + internal::ViewProvider>, public CFHolder { public: using CFHolder::CFHolder; + GpuBufferStorageCvPixelBuffer(int width, int height, GpuBufferFormat format); GpuBufferStorageCvPixelBuffer(const CFHolder& other) : CFHolder(other) {} GpuBufferStorageCvPixelBuffer(CFHolder&& other) @@ -30,18 +32,35 @@ class GpuBufferStorageCvPixelBuffer return GpuBufferFormatForCVPixelFormat( CVPixelBufferGetPixelFormatType(**this)); } - GlTextureView GetReadView(mediapipe::internal::types, + GlTextureView GetReadView(internal::types, std::shared_ptr gpu_buffer, int plane) const override; - GlTextureView GetWriteView(mediapipe::internal::types, + GlTextureView GetWriteView(internal::types, std::shared_ptr gpu_buffer, int plane) override; - std::unique_ptr AsImageFrame() const override; + std::shared_ptr GetReadView( + internal::types, + std::shared_ptr gpu_buffer) const override; + std::shared_ptr GetWriteView( + internal::types, + std::shared_ptr gpu_buffer) override; private: + GlTextureView GetTexture(std::shared_ptr gpu_buffer, int plane, + GlTextureView::DoneWritingFn done_writing) const; void ViewDoneWriting(const GlTextureView& view); }; +namespace internal { +// These functions enable backward-compatible construction of a GpuBuffer from +// CVPixelBufferRef without having to expose that type in the main GpuBuffer +// header. 
+std::shared_ptr AsGpuBufferStorage( + CFHolder pixel_buffer); +std::shared_ptr AsGpuBufferStorage( + CVPixelBufferRef pixel_buffer); +} // namespace internal + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_ diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h new file mode 100644 index 000000000..2cea3445e --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -0,0 +1,48 @@ +#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ +#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ + +#include + +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" +#include "mediapipe/gpu/image_frame_view.h" + +namespace mediapipe { + +// Implements support for ImageFrame as a backing storage of GpuBuffer. +class GpuBufferStorageImageFrame + : public internal::GpuBufferStorageImpl< + GpuBufferStorageImageFrame, internal::ViewProvider> { + public: + explicit GpuBufferStorageImageFrame(std::shared_ptr image_frame) + : image_frame_(image_frame) {} + GpuBufferStorageImageFrame(int width, int height, GpuBufferFormat format) { + image_frame_ = std::make_shared( + ImageFormatForGpuBufferFormat(format), width, height); + } + int width() const override { return image_frame_->Width(); } + int height() const override { return image_frame_->Height(); } + GpuBufferFormat format() const override { + return GpuBufferFormatForImageFormat(image_frame_->Format()); + } + std::shared_ptr image_frame() const { return image_frame_; } + std::shared_ptr image_frame() { return image_frame_; } + std::shared_ptr GetReadView( + internal::types, + std::shared_ptr gpu_buffer) const override { + return image_frame_; + } + std::shared_ptr GetWriteView( + internal::types, + std::shared_ptr gpu_buffer) override { + return image_frame_; + } + + private: + std::shared_ptr image_frame_; +}; + +} // namespace 
mediapipe + +#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index a4bd93a7a..daf64d9c5 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -14,13 +14,89 @@ #include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" +#include "stb_image.h" +#include "stb_image_write.h" namespace mediapipe { namespace { +// Write an ImageFrame as PNG to the test undeclared outputs directory. +// The image's name will contain the given prefix and a timestamp. +// Returns the path to the output if successful. +std::string SavePngImage(const mediapipe::ImageFrame& image, + absl::string_view prefix) { + std::string output_dir = mediapipe::GetTestOutputsDir(); + std::string now_string = absl::FormatTime(absl::Now()); + std::string out_file_path = + absl::StrCat(output_dir, "/", prefix, "_", now_string, ".png"); + EXPECT_TRUE(stbi_write_png(out_file_path.c_str(), image.Width(), + image.Height(), image.NumberOfChannels(), + image.PixelData(), image.WidthStep())) + << " path: " << out_file_path; + return out_file_path; +} + +void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) { + auto* data = image.MutablePixelData(); + for (int y = 0; y < image.Height(); ++y) { + auto* row = data + image.WidthStep() * y; + for (int x = 0; x < image.Width(); ++x) { + auto* pixel = row + x * image.NumberOfChannels(); + pixel[0] = r; + pixel[1] = g; + pixel[2] = b; + pixel[3] = a; + } + } +} + +// Assumes a framebuffer is already set up +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { + glViewport(0, 0, src.width(), src.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, 
GL_COLOR_ATTACHMENT0, src.target(), + src.name(), 0); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(dst.target(), dst.name()); + glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); + + glBindTexture(dst.target(), 0); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, + 0); +} + +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, + float a) { + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glClearColor(r, g, b, a); + glClear(GL_COLOR_BUFFER_BIT); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, + 0); +} + +class TempGlFramebuffer { + public: + TempGlFramebuffer() { + glGenFramebuffers(1, &framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + } + ~TempGlFramebuffer() { + glBindFramebuffer(GL_FRAMEBUFFER, 0); + glDeleteFramebuffers(1, &framebuffer_); + } + + private: + GLuint framebuffer_; +}; + class GpuBufferTest : public GpuTestBase {}; TEST_F(GpuBufferTest, BasicTest) { @@ -46,5 +122,144 @@ TEST_F(GpuBufferTest, BasicTest) { }); } +TEST_F(GpuBufferTest, GlTextureView) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + EXPECT_EQ(buffer.width(), 300); + EXPECT_EQ(buffer.height(), 200); + EXPECT_TRUE(buffer); + EXPECT_FALSE(buffer == nullptr); + + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetWriteView(0); + FillGlTextureRgba(view, 1.0, 0.0, 0.0, 1.0); + glFlush(); + }); + std::shared_ptr view = buffer.GetReadView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + + ImageFrame red(ImageFormat::SRGBA, 300, 200); + FillImageFrameRGBA(red, 255, 0, 0, 255); + + EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + SavePngImage(red, "gltv_red_gold"); + SavePngImage(*view, "gltv_red_view"); +} + +TEST_F(GpuBufferTest, ImageFrame) { + GpuBuffer buffer(300, 200, 
GpuBufferFormat::kBGRA32); + EXPECT_EQ(buffer.width(), 300); + EXPECT_EQ(buffer.height(), 200); + EXPECT_TRUE(buffer); + EXPECT_FALSE(buffer == nullptr); + + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + GpuBuffer buffer2(300, 200, GpuBufferFormat::kBGRA32); + RunInGlContext([&buffer, &buffer2] { + TempGlFramebuffer fb; + auto src = buffer.GetReadView(0); + auto dst = buffer2.GetWriteView(0); + CopyGlTexture(src, dst); + glFlush(); + }); + { + std::shared_ptr view = buffer2.GetReadView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + + ImageFrame red(ImageFormat::SRGBA, 300, 200); + FillImageFrameRGBA(red, 255, 0, 0, 255); + + EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + SavePngImage(red, "if_red_gold"); + SavePngImage(*view, "if_red_view"); + } +} + +TEST_F(GpuBufferTest, Overwrite) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + EXPECT_EQ(buffer.width(), 300); + EXPECT_EQ(buffer.height(), 200); + EXPECT_TRUE(buffer); + EXPECT_FALSE(buffer == nullptr); + + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + GpuBuffer red_copy(300, 200, GpuBufferFormat::kBGRA32); + RunInGlContext([&buffer, &red_copy] { + TempGlFramebuffer fb; + auto src = buffer.GetReadView(0); + auto dst = red_copy.GetWriteView(0); + CopyGlTexture(src, dst); + glFlush(); + }); + + { + std::shared_ptr view = red_copy.GetReadView(); + ImageFrame red(ImageFormat::SRGBA, 300, 200); + FillImageFrameRGBA(red, 255, 0, 0, 255); + + EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + SavePngImage(red, "ow_red_gold"); + SavePngImage(*view, "ow_red_view"); + } + + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + 
FillImageFrameRGBA(*view, 0, 255, 0, 255); + } + + GpuBuffer green_copy(300, 200, GpuBufferFormat::kBGRA32); + RunInGlContext([&buffer, &green_copy] { + TempGlFramebuffer fb; + auto src = buffer.GetReadView(0); + auto dst = green_copy.GetWriteView(0); + CopyGlTexture(src, dst); + glFlush(); + }); + + RunInGlContext([&buffer] { + TempGlFramebuffer fb; + auto view = buffer.GetWriteView(0); + FillGlTextureRgba(view, 0.0, 0.0, 1.0, 1.0); + glFlush(); + }); + + { + std::shared_ptr view = + green_copy.GetReadView(); + ImageFrame green(ImageFormat::SRGBA, 300, 200); + FillImageFrameRGBA(green, 0, 255, 0, 255); + + EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0)); + SavePngImage(green, "ow_green_gold"); + SavePngImage(*view, "ow_green_view"); + } + + { + std::shared_ptr view = buffer.GetReadView(); + ImageFrame blue(ImageFormat::SRGBA, 300, 200); + FillImageFrameRGBA(blue, 0, 0, 255, 255); + + EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0)); + SavePngImage(blue, "ow_blue_gold"); + SavePngImage(*view, "ow_blue_view"); + } +} + } // anonymous namespace } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc index 8bca0a27d..c9527880a 100644 --- a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc +++ b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc @@ -76,7 +76,7 @@ absl::Status GpuBufferToImageFrameCalculator::Process(CalculatorContext* cc) { const auto& input = cc->Inputs().Index(0).Get(); #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER std::unique_ptr frame = - CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef()); + CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef(input)); cc->Outputs().Index(0).Add(frame.release(), cc->InputTimestamp()); #else helper_.RunInGlContext([this, &input, &cc]() { diff --git a/mediapipe/gpu/gpu_service.cc b/mediapipe/gpu/gpu_service.cc index ed9c8f62f..f280a58d0 100644 --- 
a/mediapipe/gpu/gpu_service.cc +++ b/mediapipe/gpu/gpu_service.cc @@ -16,6 +16,6 @@ namespace mediapipe { -const GraphService<::mediapipe::GpuResources> kGpuService("kGpuService"); +const GraphService kGpuService("kGpuService"); } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_service.h b/mediapipe/gpu/gpu_service.h index 52be9d720..a610a275f 100644 --- a/mediapipe/gpu/gpu_service.h +++ b/mediapipe/gpu/gpu_service.h @@ -18,12 +18,9 @@ #include "mediapipe/framework/graph_service.h" namespace mediapipe { + class GpuResources; -} // namespace mediapipe - -namespace mediapipe { - -extern const GraphService<::mediapipe::GpuResources> kGpuService; +extern const GraphService kGpuService; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 10edb601a..7aa622d24 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -117,7 +117,8 @@ absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { (node_type == "GpuBufferToImageFrameCalculator") || (node_type == "GlSurfaceSinkCalculator"); - const auto& options = node->GetCalculatorState().Options(); + const auto& options = + node->GetCalculatorState().Options(); if (options.has_gl_context_name() && !options.gl_context_name().empty()) { context_key = absl::StrCat("user:", options.gl_context_name()); } else if (gets_own_context) { diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 2a8331db8..11f3ae58c 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -16,10 +16,7 @@ #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" - -#ifdef __APPLE__ -#include "mediapipe/objc/util.h" -#endif +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" namespace 
mediapipe { @@ -34,9 +31,7 @@ class ImageFrameToGpuBufferCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override; private: -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlCalculatorHelper helper_; -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); @@ -56,28 +51,25 @@ absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(TimestampDiff(0)); -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - CFHolder buffer; - MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( - cc->Inputs().Index(0).Value(), &buffer)); - cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp()); -#else - const auto& input = cc->Inputs().Index(0).Get(); - helper_.RunInGlContext([this, &input, &cc]() { - auto src = helper_.CreateSourceTexture(input); - auto output = src.GetFrame(); - glFlush(); - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - src.Release(); + auto image_frame = std::const_pointer_cast( + mediapipe::SharedPtrWithPacket( + cc->Inputs().Index(0).Value())); + auto gpu_buffer = MakePacket( + std::make_shared( + std::move(image_frame))) + .At(cc->InputTimestamp()); + // Request GPU access to ensure the data is available to the GPU. + // TODO: have a better way to do this, or defer until later. 
+ helper_.RunInGlContext([&gpu_buffer] { + auto view = gpu_buffer.Get().GetReadView(0); }); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + cc->Outputs().Index(0).AddPacket(std::move(gpu_buffer)); + return absl::OkStatus(); } diff --git a/mediapipe/gpu/image_frame_view.h b/mediapipe/gpu/image_frame_view.h new file mode 100644 index 000000000..2fc6f2495 --- /dev/null +++ b/mediapipe/gpu/image_frame_view.h @@ -0,0 +1,23 @@ +#ifndef MEDIAPIPE_GPU_IMAGE_FRAME_VIEW_H_ +#define MEDIAPIPE_GPU_IMAGE_FRAME_VIEW_H_ + +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" + +namespace mediapipe { +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual std::shared_ptr GetReadView( + types, std::shared_ptr gpu_buffer) const = 0; + virtual std::shared_ptr GetWriteView( + types, std::shared_ptr gpu_buffer) = 0; +}; + +} // namespace internal +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_IMAGE_FRAME_VIEW_H_ diff --git a/mediapipe/gpu/metal.bzl b/mediapipe/gpu/metal.bzl index d623f4c3e..dcfb268fc 100644 --- a/mediapipe/gpu/metal.bzl +++ b/mediapipe/gpu/metal.bzl @@ -134,16 +134,8 @@ def _metal_library_impl(ctx): ), ) - # This circumlocution is needed because new_objc_provider rejects - # an empty depset, with the error: - # "Value for key header must be a set of File, instead found set of unknown." - # It also rejects an explicit "None". 
- additional_params = {} - if ctx.files.hdrs: - additional_params["header"] = depset([f for f in ctx.files.hdrs]) objc_provider = apple_common.new_objc_provider( providers = [x[apple_common.Objc] for x in ctx.attr.deps if apple_common.Objc in x], - **additional_params ) cc_infos = [dep[CcInfo] for dep in ctx.attr.deps if CcInfo in dep] diff --git a/mediapipe/graphs/face_mesh/subgraphs/BUILD b/mediapipe/graphs/face_mesh/subgraphs/BUILD index fbb946dfb..872ebba13 100644 --- a/mediapipe/graphs/face_mesh/subgraphs/BUILD +++ b/mediapipe/graphs/face_mesh/subgraphs/BUILD @@ -24,7 +24,7 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "renderer_calculators", deps = [ - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", "//mediapipe/calculators/util:landmarks_to_render_data_calculator", diff --git a/mediapipe/graphs/holistic_tracking/BUILD b/mediapipe/graphs/holistic_tracking/BUILD index 986cf9f3a..dec521de3 100644 --- a/mediapipe/graphs/holistic_tracking/BUILD +++ b/mediapipe/graphs/holistic_tracking/BUILD @@ -27,10 +27,10 @@ mediapipe_simple_subgraph( graph = "holistic_tracking_to_render_data.pbtxt", register_as = "HolisticTrackingToRenderData", deps = [ - "//mediapipe/calculators/core:concatenate_normalized_landmark_list_calculator", + "//mediapipe/calculators/core:concatenate_proto_list_calculator", "//mediapipe/calculators/core:concatenate_vector_calculator", "//mediapipe/calculators/core:merge_calculator", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", "//mediapipe/calculators/util:landmarks_to_render_data_calculator", diff --git 
a/mediapipe/graphs/iris_tracking/subgraphs/BUILD b/mediapipe/graphs/iris_tracking/subgraphs/BUILD index d37c55095..8374fb79c 100644 --- a/mediapipe/graphs/iris_tracking/subgraphs/BUILD +++ b/mediapipe/graphs/iris_tracking/subgraphs/BUILD @@ -24,9 +24,9 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "renderer_calculators", deps = [ - "//mediapipe/calculators/core:concatenate_normalized_landmark_list_calculator", + "//mediapipe/calculators/core:concatenate_proto_list_calculator", "//mediapipe/calculators/core:concatenate_vector_calculator", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:detection_label_id_to_text_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", diff --git a/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc b/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc index 9bc43ba03..3a77e983a 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc +++ b/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc @@ -663,7 +663,7 @@ absl::Status GlAnimationOverlayCalculator::Process(CalculatorContext *cc) { if (result.ok()) { input_frame = std::move(result).value(); #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - input_frame->GetGlTextureBufferSharedPtr()->Reuse(); + input_frame->internal_storage()->Reuse(); #endif width = input_frame->width(); height = input_frame->height(); diff --git a/mediapipe/graphs/pose_tracking/subgraphs/BUILD b/mediapipe/graphs/pose_tracking/subgraphs/BUILD index fa3464062..8831692cb 100644 --- a/mediapipe/graphs/pose_tracking/subgraphs/BUILD +++ b/mediapipe/graphs/pose_tracking/subgraphs/BUILD @@ -26,7 +26,7 @@ mediapipe_simple_subgraph( graph = "pose_renderer_gpu.pbtxt", 
register_as = "PoseRendererGpu", deps = [ - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", @@ -41,7 +41,7 @@ mediapipe_simple_subgraph( graph = "pose_renderer_cpu.pbtxt", register_as = "PoseRendererCpu", deps = [ - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/image:recolor_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:detections_to_render_data_calculator", diff --git a/mediapipe/java/com/google/mediapipe/BUILD b/mediapipe/java/com/google/mediapipe/BUILD index 6995a7636..1c320b161 100644 --- a/mediapipe/java/com/google/mediapipe/BUILD +++ b/mediapipe/java/com/google/mediapipe/BUILD @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http:#www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java index 6910d4d7f..f4aa330dd 100644 --- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java @@ -103,6 +103,24 @@ public class ExternalTextureConverter implements TextureFrameProducer { } } + /** + * Sets the new buffer pool size. This is safe to set at any time. + * + * This doesn't adjust the buffer pool right way. 
Instead, it behaves as follows: + * + * If the new size is smaller: Excess frames in pool are not de-allocated, but rather when frames + * are released, they wouldn't be added back to the pool until size restriction is met. + * + * If the new size is greater: New frames won't be created immediately. ETC anyway creates new + * frames when all frames in the pool are in-use, but they are only added back to the pool upon + * release if the size allows so. + * + * @param bufferPoolSize the number of camera frames that can enter processing simultaneously. + */ + public void setBufferPoolSize(int bufferPoolSize) { + thread.setBufferPoolSize(bufferPoolSize); + } + /** * Sets vertical flipping of the texture, useful for conversion between coordinate systems with * top-left v.s. bottom-left origins. This should be called before {@link @@ -242,7 +260,7 @@ public class ExternalTextureConverter implements TextureFrameProducer { private final Queue framesAvailable = new ArrayDeque<>(); private int framesInUse = 0; - private final int framesToKeep; + private int framesToKeep; private ExternalTextureRenderer renderer = null; private long nextFrameTimestampOffset = 0; @@ -278,6 +296,10 @@ public class ExternalTextureConverter implements TextureFrameProducer { consumers = new ArrayList<>(); } + public void setBufferPoolSize(int bufferPoolSize) { + this.framesToKeep = bufferPoolSize; + } + public void setFlipY(boolean flip) { renderer.setFlipY(flip); } @@ -408,6 +430,9 @@ public class ExternalTextureConverter implements TextureFrameProducer { // TODO: Switch to ref-counted single copy instead of making additional // copies blitting to separate textures each time. updateOutputFrame(outputFrame); + // Release immediately as this is not sent to a consumer so no release() would be + // called otherwise. 
+ outputFrame.release(); } } } finally { diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index ed1a42c20..7b1a89166 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -49,6 +49,7 @@ android_library( "MediaPipeRunner.java", ], visibility = [ + "//java/com/google/android/libraries/camera/effects:__subpackages__", "//mediapipe/java/com/google/mediapipe:__subpackages__", ], exports = [ diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 29211b224..1e4d74fab 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -29,7 +29,6 @@ #include "mediapipe/framework/port/threadpool.h" #include "mediapipe/framework/tool/executor_util.h" #include "mediapipe/framework/tool/name_util.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/gpu/graph_support.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" @@ -41,6 +40,7 @@ #endif // __ANDROID__ #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/egl_surface_holder.h" +#include "mediapipe/gpu/gpu_shared_data_internal.h" #endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -572,19 +572,21 @@ void Graph::SetGraphInputStreamAddMode( graph_input_stream_add_mode_ = mode; } +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GpuResources* Graph::GetGpuResources() const { return gpu_resources_.get(); } +#endif // !MEDIAPIPE_DISABLE_GPU absl::Status Graph::SetParentGlContext(int64 java_gl_context) { +#if MEDIAPIPE_DISABLE_GPU + LOG(FATAL) << "GPU support has been disabled in this build!"; +#else if (gpu_resources_) { return absl::AlreadyExistsError( "trying to set the parent GL context, but the gpu shared " "data has 
already been set up."); } -#if MEDIAPIPE_DISABLE_GPU - LOG(FATAL) << "GPU support has been disabled in this build!"; -#else ASSIGN_OR_RETURN(gpu_resources_, mediapipe::GpuResources::Create( reinterpret_cast(java_gl_context))); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h index 87ac516bd..a5a635e29 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h @@ -27,9 +27,9 @@ #include "mediapipe/framework/calculator_framework.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_shared_data_internal.h" #endif // !MEDIAPIPE_DISABLE_GPU #include "absl/synchronization/mutex.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" namespace mediapipe { namespace android { @@ -119,11 +119,11 @@ class Graph { // Puts a mediapipe packet into the context for management. // Returns the handle to the internal PacketWithContext object. int64_t WrapPacketIntoContext(const Packet& packet); - +#if !MEDIAPIPE_DISABLE_GPU // Gets the shared mediapipe::GpuResources. Only valid once the graph is // running. mediapipe::GpuResources* GetGpuResources() const; - +#endif // !MEDIAPIPE_DISABLE_GPU // Adds a surface output for a given stream name. // Multiple outputs can be attached to the same stream. // Returns a native packet handle for the mediapipe::EglSurfaceHolder, or 0 in @@ -212,12 +212,13 @@ class Graph { // All callback handlers managed by the context. std::vector> callback_handlers_; +#if !MEDIAPIPE_DISABLE_GPU // mediapipe::GpuResources used by the graph. // Note: this class does not create a CalculatorGraph until StartRunningGraph // is called, and we may have to create the mediapipe::GpuResources before // that time, e.g. before a SurfaceOutput is associated with a Surface. 
std::shared_ptr gpu_resources_; - +#endif // !MEDIAPIPE_DISABLE_GPU // Maps surface output names to the side packet used for the associated // surface. std::unordered_map output_surface_side_packets_; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc index 2b761bf60..41646a12c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc @@ -96,7 +96,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraph)(JNIEnv* env, mediapipe::android::Graph* mediapipe_graph = reinterpret_cast(context); const char* path_ref = env->GetStringUTFChars(path, nullptr); - // Make a copy of the std::string and release the jni reference. + // Make a copy of the string and release the jni reference. std::string path_to_graph(path_ref); env->ReleaseStringUTFChars(path, path_ref); ThrowIfError(env, mediapipe_graph->LoadBinaryGraph(path_to_graph)); @@ -133,7 +133,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeSetGraphType)(JNIEnv* env, mediapipe::android::Graph* mediapipe_graph = reinterpret_cast(context); const char* graph_type_ref = env->GetStringUTFChars(graph_type, nullptr); - // Make a copy of the std::string and release the jni reference. + // Make a copy of the string and release the jni reference. 
std::string graph_type_string(graph_type_ref); env->ReleaseStringUTFChars(graph_type, graph_type_ref); ThrowIfError(env, mediapipe_graph->SetGraphType(graph_type_string)); @@ -200,7 +200,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeAddMultiStreamCallback)( if (s.empty()) { ThrowIfError(env, absl::InternalError("streamNames is not correctly parsed or " - "it contains empty std::string.")); + "it contains empty string.")); return; } } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index c9c8553fd..2701c7a5e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -27,12 +27,12 @@ #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" #endif // !MEDIAPIPE_DISABLE_GPU namespace { @@ -57,14 +57,15 @@ int64_t CreatePacketWithContext(jlong context, } #if !MEDIAPIPE_DISABLE_GPU -mediapipe::GpuBuffer CreateGpuBuffer(JNIEnv* env, jobject thiz, jlong context, - jint name, jint width, jint height, - jobject texture_release_callback) { +absl::StatusOr CreateGpuBuffer( + JNIEnv* env, jobject thiz, jlong context, jint name, jint width, + jint height, jobject texture_release_callback) { mediapipe::android::Graph* mediapipe_graph = reinterpret_cast(context); auto* gpu_resources = mediapipe_graph->GetGpuResources(); - CHECK(gpu_resources) << "Cannot create a mediapipe::GpuBuffer packet on a " - 
"graph without GPU support"; + RET_CHECK(gpu_resources) + << "Cannot create a mediapipe::GpuBuffer packet on a " + "graph without GPU support"; mediapipe::GlTextureBuffer::DeletionCallback cc_callback; if (texture_release_callback) { @@ -78,7 +79,7 @@ mediapipe::GpuBuffer CreateGpuBuffer(JNIEnv* env, jobject thiz, jlong context, "(JL" "com/google/mediapipe/framework/TextureReleaseCallback" ";)V"); - CHECK(release_method); + RET_CHECK(release_method); env->DeleteLocalRef(my_class); jobject java_callback = env->NewGlobalRef(texture_release_callback); @@ -400,18 +401,22 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, jint height, jobject texture_release_callback) { - mediapipe::Packet image_packet = - mediapipe::MakePacket(CreateGpuBuffer( - env, thiz, context, name, width, height, texture_release_callback)); - return CreatePacketWithContext(context, image_packet); + auto buffer_or = CreateGpuBuffer(env, thiz, context, name, width, height, + texture_release_callback); + if (ThrowIfError(env, buffer_or.status())) return 0L; + mediapipe::Packet packet = + mediapipe::MakePacket(std::move(buffer_or).value()); + return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuBuffer)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, jint height, jobject texture_release_callback) { + auto buffer_or = CreateGpuBuffer(env, thiz, context, name, width, height, + texture_release_callback); + if (ThrowIfError(env, buffer_or.status())) return 0L; mediapipe::Packet packet = - mediapipe::MakePacket(CreateGpuBuffer( - env, thiz, context, name, width, height, texture_release_callback)); + mediapipe::MakePacket(std::move(buffer_or).value()); return CreatePacketWithContext(context, packet); } diff --git 
a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index 664b19b76..0aca74b51 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -15,18 +15,19 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/proto_ns.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" #endif // !MEDIAPIPE_DISABLE_GPU namespace { @@ -437,7 +438,8 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetGpuBufferName)( // gpu_buffer.name() returns a GLuint. Make sure the cast to jint is safe. 
static_assert(sizeof(GLuint) <= sizeof(jint), "The cast to jint may truncate GLuint"); - return static_cast(gpu_buffer.GetGlTextureBufferSharedPtr()->name()); + return static_cast( + gpu_buffer.internal_storage()->name()); } JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)( @@ -459,7 +461,7 @@ JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)( } else { const mediapipe::GpuBuffer& buffer = mediapipe_packet.Get(); - ptr = buffer.GetGlTextureBufferSharedPtr(); + ptr = buffer.internal_storage(); } if (wait_on_cpu) { ptr->WaitUntilComplete(); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index 2ddc42337..8ea37d9c5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -130,12 +130,17 @@ void RegisterGraphNatives(JNIEnv *env) { AddJNINativeMethod(&graph_methods, graph, "nativeCloseAllPacketSources", "(J)V", (void *)&GRAPH_METHOD(nativeCloseAllPacketSources)); + AddJNINativeMethod(&graph_methods, graph, "nativeWaitUntilGraphIdle", "(J)V", + (void *)&GRAPH_METHOD(nativeWaitUntilGraphIdle)); AddJNINativeMethod(&graph_methods, graph, "nativeWaitUntilGraphDone", "(J)V", (void *)&GRAPH_METHOD(nativeWaitUntilGraphDone)); AddJNINativeMethod(&graph_methods, graph, "nativeReleaseGraph", "(J)V", (void *)&GRAPH_METHOD(nativeReleaseGraph)); AddJNINativeMethod(&graph_methods, graph, "nativeGetProfiler", "(J)J", (void *)&GRAPH_METHOD(nativeGetProfiler)); + AddJNINativeMethod(&graph_methods, graph, "nativeAddPacketToInputStream", + "(JLjava/lang/String;JJ)V", + (void *)&GRAPH_METHOD(nativeAddPacketToInputStream)); RegisterNativesVector(env, graph_class, graph_methods); env->DeleteLocalRef(graph_class); } @@ -229,6 +234,10 @@ void RegisterPacketCreatorNatives(JNIEnv *env) { AddJNINativeMethod(&packet_creator_methods, packet_creator, 
"nativeCreateString", "(JLjava/lang/String;)J", (void *)&PACKET_CREATOR_METHOD(nativeCreateString)); + AddJNINativeMethod( + &packet_creator_methods, packet_creator, + "nativeCreateStringFromByteArray", "(J[B)J", + (void *)&PACKET_CREATOR_METHOD(nativeCreateStringFromByteArray)); std::string serialized_message_name = class_registry.GetClassName( mediapipe::android::ClassRegistry::kProtoUtilSerializedMessageClassName); AddJNINativeMethod(&packet_creator_methods, packet_creator, diff --git a/mediapipe/models/object_detection_saved_model/README.md b/mediapipe/models/object_detection_saved_model/README.md index 15e206f95..6acac0a1b 100644 --- a/mediapipe/models/object_detection_saved_model/README.md +++ b/mediapipe/models/object_detection_saved_model/README.md @@ -14,7 +14,7 @@ The TFLite model is converted from the TensorFlow above. The steps needed to con * `model.ckpt.data-00000-of-00001` * `pipeline.config` -Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command: +Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command: ```bash $ PATH_TO_MODEL=path/to/the/model @@ -44,7 +44,7 @@ You should be able to see the input image size of the model is 320x320 and the o * `raw_outputs/box_encodings` * `raw_outputs/class_predictions` -The last step is to convert the model to TFLite. You can look at [this guide](https://www.tensorflow.org/lite/convert/cmdline_examples) for more detail. For this example, you just need to run: +The last step is to convert the model to TFLite. 
You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run: ```bash $ tflite_convert -- \ diff --git a/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt index 1bd08e932..c35331e0e 100644 --- a/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_short_range_by_roi_gpu.pbtxt @@ -1,5 +1,5 @@ -# MediaPipe graph to detect faces. (CPU input, and inference is executed on -# CPU.) +# MediaPipe graph to detect faces. (GPU input, and inference is executed on +# GPU.) # # It is required that "face_detection_short_range.tflite" is available at # "mediapipe/modules/face_detection/face_detection_short_range.tflite" diff --git a/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt index d30644b19..ce0d25b13 100644 --- a/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_short_range_gpu.pbtxt @@ -1,5 +1,5 @@ -# MediaPipe graph to detect faces. (CPU input, and inference is executed on -# CPU.) +# MediaPipe graph to detect faces. (GPU input, and inference is executed on +# GPU.) 
# # It is required that "face_detection_short_range.tflite" is available at # "mediapipe/modules/face_detection/face_detection_short_range.tflite" diff --git a/mediapipe/modules/hand_landmark/BUILD b/mediapipe/modules/hand_landmark/BUILD index b28dc785e..6e5c49390 100644 --- a/mediapipe/modules/hand_landmark/BUILD +++ b/mediapipe/modules/hand_landmark/BUILD @@ -164,7 +164,6 @@ mediapipe_simple_subgraph( graph = "hand_landmark_landmarks_to_roi.pbtxt", register_as = "HandLandmarkLandmarksToRoi", deps = [ - "//mediapipe/calculators/core:split_landmarks_calculator", "//mediapipe/calculators/util:rect_transformation_calculator", "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", ], diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index 3e3f5c8fa..6f2c49d64 100644 --- a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -24,13 +24,21 @@ namespace mediapipe { namespace { +// NORM_LANDMARKS is either the full set of landmarks for the hand, or +// a subset of the hand landmarks (indices 0, 1, 2, 3, 5, 6, 9, 10, 13, 14, +// 17 and 18). The latter is the legacy behavior, please just pass in +// the full set of hand landmarks. +// +// TODO: update clients to just pass all the landmarks in. constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +// Indices within the partial landmarks. 
constexpr int kWristJoint = 0; constexpr int kMiddleFingerPIPJoint = 6; constexpr int kIndexFingerPIPJoint = 4; constexpr int kRingFingerPIPJoint = 8; +constexpr int kNumLandmarks = 21; constexpr float kTargetAngle = M_PI * 0.5f; inline float NormalizeRadians(float angle) { @@ -150,8 +158,7 @@ class HandLandmarksToRectCalculator : public CalculatorBase { std::pair image_size = cc->Inputs().Tag(kImageSizeTag).Get>(); - const auto& landmarks = - cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); + const auto landmarks = GetPartialLandmarks(cc); auto output_rect = absl::make_unique(); MP_RETURN_IF_ERROR( NormalizedLandmarkListToRect(landmarks, image_size, output_rect.get())); @@ -161,6 +168,25 @@ class HandLandmarksToRectCalculator : public CalculatorBase { return absl::OkStatus(); } + + private: + NormalizedLandmarkList GetPartialLandmarks(CalculatorContext* cc) { + const auto& landmarks = + cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); + if (landmarks.landmark_size() == kNumLandmarks) { + static constexpr int kPartialLandmarkIndices[]{0, 1, 2, 3, 5, 6, + 9, 10, 13, 14, 17, 18}; + NormalizedLandmarkList partial_landmarks; + for (int i : kPartialLandmarkIndices) { + *partial_landmarks.add_landmark() = landmarks.landmark(i); + } + return partial_landmarks; + } else { + // Assume the calculator is receiving the partial landmarks directly. + // This is the legacy behavior. + return landmarks; + } + } }; REGISTER_CALCULATOR(HandLandmarksToRectCalculator); diff --git a/mediapipe/modules/hand_landmark/hand_landmark_landmarks_to_roi.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_landmarks_to_roi.pbtxt index 1d82d7672..46000c193 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_landmarks_to_roi.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_landmarks_to_roi.pbtxt @@ -11,35 +11,11 @@ input_stream: "IMAGE_SIZE:image_size" # ROI according to landmarks. 
(NormalizedRect) output_stream: "ROI:roi" -# Extracts a subset of the hand landmarks that are relatively more stable across -# frames (e.g. comparing to finger tips) for computing the bounding box. The box -# will later be expanded to contain the entire hand. In this approach, it is -# more robust to drastically changing hand size. -# The landmarks extracted are: wrist, MCP/PIP of five fingers. -node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "landmarks" - output_stream: "partial_landmarks" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 4 } - ranges: { begin: 5 end: 7 } - ranges: { begin: 9 end: 11 } - ranges: { begin: 13 end: 15 } - ranges: { begin: 17 end: 19 } - combine_outputs: true - } - } -} - # Converts the hand landmarks into a rectangle (normalized by image size) -# that encloses the hand. The calculator uses a subset of all hand landmarks -# extracted from SplitNormalizedLandmarkListCalculator above to -# calculate the bounding box and the rotation of the output rectangle. Please -# see the comments in the calculator for more detail. +# that encloses the hand. 
node { calculator: "HandLandmarksToRectCalculator" - input_stream: "NORM_LANDMARKS:partial_landmarks" + input_stream: "NORM_LANDMARKS:landmarks" input_stream: "IMAGE_SIZE:image_size" output_stream: "NORM_RECT:hand_rect_from_landmarks" } diff --git a/mediapipe/modules/holistic_landmark/BUILD b/mediapipe/modules/holistic_landmark/BUILD index 44854c0d9..6c09eb0d4 100644 --- a/mediapipe/modules/holistic_landmark/BUILD +++ b/mediapipe/modules/holistic_landmark/BUILD @@ -31,7 +31,7 @@ mediapipe_simple_subgraph( ":face_detection_front_detections_to_roi", ":face_landmarks_from_pose_to_recrop_roi", ":face_tracking", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/modules/face_detection:face_detection_short_range_by_roi_gpu", "//mediapipe/modules/face_landmark:face_landmark_gpu", @@ -46,7 +46,7 @@ mediapipe_simple_subgraph( ":face_detection_front_detections_to_roi", ":face_landmarks_from_pose_to_recrop_roi", ":face_tracking", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/modules/face_detection:face_detection_short_range_by_roi_cpu", "//mediapipe/modules/face_landmark:face_landmark_cpu", @@ -131,7 +131,6 @@ mediapipe_simple_subgraph( graph = "hand_landmarks_to_roi.pbtxt", register_as = "HandLandmarksToRoi", deps = [ - "//mediapipe/calculators/core:split_landmarks_calculator", "//mediapipe/calculators/util:rect_transformation_calculator", "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", ], @@ -191,7 +190,7 @@ mediapipe_simple_subgraph( deps = [ "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:side_packet_to_stream_calculator", - "//mediapipe/calculators/core:split_landmarks_calculator", + 
"//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/util:set_landmark_visibility_calculator", ], ) @@ -202,7 +201,7 @@ mediapipe_simple_subgraph( register_as = "HandLandmarksLeftAndRightGpu", deps = [ ":hand_landmarks_from_pose_gpu", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", ], ) @@ -212,7 +211,7 @@ mediapipe_simple_subgraph( register_as = "HandLandmarksLeftAndRightCpu", deps = [ ":hand_landmarks_from_pose_cpu", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", ], ) @@ -232,7 +231,7 @@ mediapipe_simple_subgraph( graph = "hand_visibility_from_hand_landmarks_from_pose.pbtxt", register_as = "HandVisibilityFromHandLandmarksFromPose", deps = [ - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/util:landmark_visibility_calculator", "//mediapipe/calculators/util:thresholding_calculator", ], diff --git a/mediapipe/modules/holistic_landmark/hand_landmarks_to_roi.pbtxt b/mediapipe/modules/holistic_landmark/hand_landmarks_to_roi.pbtxt index b874c1d40..71c272a9d 100644 --- a/mediapipe/modules/holistic_landmark/hand_landmarks_to_roi.pbtxt +++ b/mediapipe/modules/holistic_landmark/hand_landmarks_to_roi.pbtxt @@ -10,31 +10,11 @@ input_stream: "IMAGE_SIZE:image_size" # ROI according to the hand landmarks. (NormalizedRect) output_stream: "ROI:roi" -# Gets hand palm landmarks. 
-node { - calculator: "SplitNormalizedLandmarkListCalculator" - input_stream: "hand_landmarks" - output_stream: "palm_landmarks" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 4 } - ranges: { begin: 5 end: 7 } - ranges: { begin: 9 end: 11 } - ranges: { begin: 13 end: 15 } - ranges: { begin: 17 end: 19 } - combine_outputs: true - } - } -} - # Converts the hand landmarks into a rectangle (normalized by image size) -# that encloses the hand. The calculator uses a subset of all hand landmarks -# extracted from SplitNormalizedLandmarkListCalculator above to -# calculate the bounding box and the rotation of the output rectangle. Please -# see the comments in the calculator for more detail. +# that encloses the hand. node { calculator: "HandLandmarksToRectCalculator" - input_stream: "NORM_LANDMARKS:palm_landmarks" + input_stream: "NORM_LANDMARKS:hand_landmarks" input_stream: "IMAGE_SIZE:image_size" output_stream: "NORM_RECT:palm_landmarks_rect" } diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index fb75eb3a7..eeeaee5f4 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -382,6 +382,23 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "filter_detection_calculator_test", + srcs = ["filter_detection_calculator_test.cc"], + deps = [ + ":filter_detection_calculator", # build_cleaner: keep + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/util:packet_test_util", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "box_util_test", srcs = ["box_util_test.cc"], diff --git 
a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc index 0f29f9ca8..db0f27484 100644 --- a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc +++ b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc @@ -42,33 +42,6 @@ using mediapipe::RE2; using Detections = std::vector; using Strings = std::vector; -} // namespace - -// Filters the entries in a Detection to only those with valid scores -// for the specified allowed labels. Allowed labels are provided as a -// vector in an optional input side packet. Allowed labels can -// contain simple strings or regular expressions. The valid score range -// can be set in the options.The allowed labels can be provided as -// vector (LABELS) or CSV std::string (LABELS_CSV) containing class -// names of allowed labels. Note: Providing an empty vector in the input side -// packet Packet causes this calculator to act as a sink if -// empty_allowed_labels_means_allow_everything is set to false (default value). -// To allow all labels, use the calculator with no input side packet stream, or -// set empty_allowed_labels_means_allow_everything to true. -// -// Example config: -// node { -// calculator: "FilterDetectionCalculator" -// input_stream: "DETECTIONS:detections" -// output_stream: "DETECTIONS:filtered_detections" -// input_side_packet: "LABELS:allowed_labels" -// options: { -// [mediapipe.FilterDetectionCalculatorOptions.ext]: { -// min_score: 0.5 -// } -// } -// } - struct FirstGreaterComparator { bool operator()(const std::pair& a, const std::pair& b) const { @@ -112,6 +85,33 @@ absl::Status SortLabelsByDecreasingScore(const Detection& detection, return absl::OkStatus(); } +} // namespace + +// Filters the entries in a Detection to only those with valid scores +// for the specified allowed labels. Allowed labels are provided as a +// std::vector in an optional input side packet. 
Allowed labels can +// contain simple strings or regular expressions. The valid score range +// can be set in the options.The allowed labels can be provided as +// std::vector (LABELS) or CSV string (LABELS_CSV) containing class +// names of allowed labels. Note: Providing an empty vector in the input side +// packet Packet causes this calculator to act as a sink if +// empty_allowed_labels_means_allow_everything is set to false (default value). +// To allow all labels, use the calculator with no input side packet stream, or +// set empty_allowed_labels_means_allow_everything to true. +// +// Example config: +// node { +// calculator: "FilterDetectionCalculator" +// input_stream: "DETECTIONS:detections" +// output_stream: "DETECTIONS:filtered_detections" +// input_side_packet: "LABELS:allowed_labels" +// options: { +// [mediapipe.FilterDetectionCalculatorOptions.ext]: { +// min_score: 0.5 +// } +// } +// } + class FilterDetectionCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc); @@ -196,7 +196,7 @@ absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kDetectionsTag)) { detections = cc->Inputs().Tag(kDetectionsTag).Get(); } else if (cc->Inputs().HasTag(kDetectionTag)) { - detections.emplace_back(cc->Inputs().Tag(kDetectionsTag).Get()); + detections.emplace_back(cc->Inputs().Tag(kDetectionTag).Get()); } std::unique_ptr outputs(new Detections); for (const auto& input : detections) { @@ -229,7 +229,7 @@ absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) { .Add(outputs.release(), cc->InputTimestamp()); } else if (!outputs->empty()) { cc->Outputs() - .Tag(kDetectionsTag) + .Tag(kDetectionTag) .Add(new Detection((*outputs)[0]), cc->InputTimestamp()); } return absl::OkStatus(); diff --git a/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc b/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc new file mode 
100644 index 000000000..958fe4c54 --- /dev/null +++ b/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc @@ -0,0 +1,71 @@ +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/packet_test_util.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +namespace mediapipe { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::EqualsProto; + +TEST(FilterDetectionCalculatorTest, DetectionFilterTest) { + auto runner = std::make_unique( + ParseTextProtoOrDie(R"pb( + calculator: "FilterDetectionCalculator" + input_stream: "DETECTION:input" + output_stream: "DETECTION:output" + options { + [mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 } + } + )pb")); + + runner->MutableInputs()->Tag("DETECTION").packets = { + MakePacket(ParseTextProtoOrDie(R"pb( + label: "a" + label: "b" + label: "c" + score: 1 + score: 0.8 + score: 0.3 + )pb")) + .At(Timestamp(20)), + MakePacket(ParseTextProtoOrDie(R"pb( + label: "a" + label: "b" + label: "c" + score: 0.6 + score: 0.4 + score: 0.2 + )pb")) + .At(Timestamp(40)), + }; + + // Run graph. + MP_ASSERT_OK(runner->Run()); + + // Check output. + EXPECT_THAT( + runner->Outputs().Tag("DETECTION").packets, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(Timestamp(20)), + EqualsProto(R"pb( + label: "a" label: "b" score: 1 score: 0.8 + )pb")), // Packet 1 at timestamp 20. + PacketContainsTimestampAndPayload( + Eq(Timestamp(40)), + EqualsProto(R"pb( + label: "a" score: 0.6 + )pb")) // Packet 2 at timestamp 40. 
+ )); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/modules/pose_landmark/BUILD b/mediapipe/modules/pose_landmark/BUILD index 787f0e2a1..424579a46 100644 --- a/mediapipe/modules/pose_landmark/BUILD +++ b/mediapipe/modules/pose_landmark/BUILD @@ -67,7 +67,7 @@ mediapipe_simple_subgraph( register_as = "TensorsToPoseLandmarksAndSegmentation", deps = [ "//mediapipe/calculators/core:gate_calculator", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/tensor:tensors_to_floats_calculator", "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", diff --git a/mediapipe/modules/selfie_segmentation/selfie_segmentation_model_loader.pbtxt b/mediapipe/modules/selfie_segmentation/selfie_segmentation_model_loader.pbtxt index 39495f80d..e4b4e7cc2 100644 --- a/mediapipe/modules/selfie_segmentation/selfie_segmentation_model_loader.pbtxt +++ b/mediapipe/modules/selfie_segmentation/selfie_segmentation_model_loader.pbtxt @@ -2,9 +2,12 @@ type: "SelfieSegmentationModelLoader" -# An integer 0 or 1. Use 0 to select a general-purpose model (operating on a -# 256x256 tensor), and 1 to select a model (operating on a 256x144 tensor) more -# optimized for landscape images. If unspecified, functions as set to 0. (int) +# model_selection is an integer. +# Use 0 to select a general-purpose model (operating on a 256x256 tensor). +# Use 1 to select a model (operating on a 256x144 tensor) more optimized for +# landscape images. +# +# If unspecified, 0 is selected by default. input_side_packet: "MODEL_SELECTION:model_selection" # TF Lite model represented as a FlatBuffer. 
@@ -39,6 +42,7 @@ node { } } } + # } } } diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index f5a3b8e1a..75d74f06f 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -11,6 +11,7 @@ cc_library( "//mediapipe:apple": ["CFHolder.h"], "//conditions:default": [], }), + features = ["layering_check"], visibility = ["//mediapipe/framework:mediapipe_internal"], ) @@ -18,7 +19,10 @@ cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], - features = ["-parse_headers"], + features = [ + "-parse_headers", + "layering_check", + ], linkopts = [ "-framework Accelerate", "-framework CoreFoundation", @@ -123,6 +127,13 @@ objc_library( "CoreVideo", ], visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ] + select({ + "//mediapipe:ios": [ + ], + "//mediapipe:macos": [ + ], + }), ) objc_library( diff --git a/mediapipe/objc/MPPGraph.h b/mediapipe/objc/MPPGraph.h index c9c06cd36..31381c4b5 100644 --- a/mediapipe/objc/MPPGraph.h +++ b/mediapipe/objc/MPPGraph.h @@ -24,10 +24,6 @@ @class MPPGraph; -namespace mediapipe { -struct GpuSharedData; -} // namespace mediapipe - /// A delegate that can receive frames from a MediaPipe graph. 
@protocol MPPGraphDelegate diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index bc9eff69f..67d71720e 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -169,7 +169,7 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, wrapper->_framesInFlight--; CVPixelBufferRef pixelBuffer; if (packetType == MPPPacketTypePixelBuffer) - pixelBuffer = packet.Get().GetCVPixelBufferRef(); + pixelBuffer = mediapipe::GetCVPixelBufferRef(packet.Get()); else pixelBuffer = packet.Get().GetCVPixelBufferRef(); if ([wrapper.delegate diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc index 15d84cd07..895463060 100644 --- a/mediapipe/objc/util.cc +++ b/mediapipe/objc/util.cc @@ -279,9 +279,12 @@ absl::StatusOr> CreateCVPixelBufferWithoutPool( return MakeCFHolderAdopting(buffer); } -void ReleaseMediaPipePacket(void* refcon, const void* base_address) { - auto packet = (mediapipe::Packet*)refcon; - delete packet; +/// When storing a shared_ptr in a CVPixelBuffer's refcon, this can be +/// used as a CVPixelBufferReleaseBytesCallback. This keeps the data +/// alive while the CVPixelBuffer is in use. 
+static void ReleaseSharedPtr(void* refcon, const void* base_address) { + auto ptr = (std::shared_ptr*)refcon; + delete ptr; } CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( @@ -307,10 +310,18 @@ absl::Status CreateCVPixelBufferForImageFramePacket( return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "out_buffer cannot be NULL"; } - CFHolder pixel_buffer; + auto image_frame = std::const_pointer_cast( + mediapipe::SharedPtrWithPacket( + image_frame_packet)); + ASSIGN_OR_RETURN(*out_buffer, CreateCVPixelBufferForImageFrame( + image_frame, can_overwrite)); + return absl::OkStatus(); +} - auto packet_copy = absl::make_unique(image_frame_packet); - const auto& frame = packet_copy->Get(); +absl::StatusOr> CreateCVPixelBufferForImageFrame( + std::shared_ptr image_frame, bool can_overwrite) { + CFHolder pixel_buffer; + const auto& frame = *image_frame; void* frame_data = const_cast(reinterpret_cast(frame.PixelData())); @@ -366,17 +377,18 @@ absl::Status CreateCVPixelBufferForImageFramePacket( << "CVPixelBufferUnlockBaseAddress failed: " << status; } else { CVPixelBufferRef pixel_buffer_temp; + auto holder = absl::make_unique>(image_frame); status = CVPixelBufferCreateWithBytes( NULL, frame.Width(), frame.Height(), pixel_format, frame_data, - frame.WidthStep(), ReleaseMediaPipePacket, packet_copy.release(), + frame.WidthStep(), ReleaseSharedPtr, holder.get(), GetCVPixelBufferAttributesForGlCompatibility(), &pixel_buffer_temp); RET_CHECK(status == kCVReturnSuccess) << "failed to create pixel buffer: " << status; + holder.release(); // will be deleted by ReleaseSharedPtr pixel_buffer.adopt(pixel_buffer_temp); } - *out_buffer = pixel_buffer; - return absl::OkStatus(); + return pixel_buffer; } absl::StatusOr> CreateCVPixelBufferCopyingImageFrame( diff --git a/mediapipe/objc/util.h b/mediapipe/objc/util.h index 55d60da58..446b57c8c 100644 --- a/mediapipe/objc/util.h +++ b/mediapipe/objc/util.h @@ -59,11 +59,6 @@ vImage_Error vImageRGBAToGray(const 
vImage_Buffer* src, vImage_Buffer* dst); vImage_Error vImageConvertCVPixelBuffers(CVPixelBufferRef src, CVPixelBufferRef dst); -/// When storing a mediapipe::Packet* in a CVPixelBuffer's refcon, this can be -/// used as a CVPixelBufferReleaseBytesCallback. This keeps the packet's data -/// alive while the CVPixelBuffer is in use. -void ReleaseMediaPipePacket(void* refcon, const void* base_address); - // Create a CVPixelBuffer without using a pool. See pixel_buffer_pool_util.h // for creation functions that use pools. CVReturn CreateCVPixelBufferWithoutPool(int width, int height, OSType cv_format, @@ -88,6 +83,9 @@ absl::Status CreateCVPixelBufferForImageFramePacket( CFHolder* out_buffer); absl::StatusOr> CreateCVPixelBufferCopyingImageFrame( const mediapipe::ImageFrame& image_frame); +absl::StatusOr> CreateCVPixelBufferForImageFrame( + std::shared_ptr image_frame, + bool can_overwrite = false); /// Creates a CVPixelBuffer with a copy of the contents of the CGImage. absl::Status CreateCVPixelBufferFromCGImage( diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 42ce07f63..b1b96c31f 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -59,7 +59,7 @@ cc_library( "//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:side_packet_to_stream_calculator", - "//mediapipe/calculators/core:split_landmarks_calculator", + "//mediapipe/calculators/core:split_proto_list_calculator", "//mediapipe/calculators/core:string_to_int_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:detection_unique_id_calculator", diff --git a/mediapipe/python/pybind/calculator_graph.cc b/mediapipe/python/pybind/calculator_graph.cc index 431c477cf..b017ca38b 100644 --- a/mediapipe/python/pybind/calculator_graph.cc +++ b/mediapipe/python/pybind/calculator_graph.cc @@ -392,7 +392,7 @@ void CalculatorGraphSubmodule(pybind11::module* module) { 
} return std::string(); }, - R"doc(Combines error messages as a single std::string. + R"doc(Combines error messages as a single string. Examples: if graph.has_error(): diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index 5cc66a310..ef7b70194 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -80,13 +80,13 @@ void PublicPacketCreators(pybind11::module* m) { m->def( "create_string", [](const std::string& data) { return MakePacket(data); }, - R"doc(Create a MediaPipe std::string Packet from a str. + R"doc(Create a MediaPipe string Packet from a str. Args: data: A str. Returns: - A MediaPipe std::string Packet. + A MediaPipe string Packet. Raises: TypeError: If the input is not a str. @@ -100,13 +100,13 @@ void PublicPacketCreators(pybind11::module* m) { m->def( "create_string", [](const py::bytes& data) { return MakePacket(data); }, - R"doc(Create a MediaPipe std::string Packet from a bytes object. + R"doc(Create a MediaPipe string Packet from a bytes object. Args: data: A bytes object. Returns: - A MediaPipe std::string Packet. + A MediaPipe string Packet. Raises: TypeError: If the input is not a bytes object. @@ -498,13 +498,13 @@ void PublicPacketCreators(pybind11::module* m) { [](const std::vector& data) { return MakePacket>(data); }, - R"doc(Create a MediaPipe std::string vector Packet from a list of str. + R"doc(Create a MediaPipe string vector Packet from a list of str. Args: data: A list of str. Returns: - A MediaPipe std::string vector Packet. + A MediaPipe string vector Packet. Raises: TypeError: If the input is not a list of str. @@ -546,7 +546,7 @@ void PublicPacketCreators(pybind11::module* m) { [](const std::map& data) { return MakePacket>(data); }, - R"doc(Create a MediaPipe std::string to packet map Packet from a dictionary. + R"doc(Create a MediaPipe string to packet map Packet from a dictionary. 
Args: data: A dictionary that has (str, Packet) pairs. @@ -561,7 +561,7 @@ void PublicPacketCreators(pybind11::module* m) { dict_packet = mp.packet_creator.create_string_to_packet_map({ 'float': mp.packet_creator.create_float(0.1), 'int': mp.packet_creator.create_int(1), - 'std::string': mp.packet_creator.create_string('1') + 'string': mp.packet_creator.create_string('1') data = mp.packet_getter.get_str_to_packet_dict(dict_packet) )doc", py::arg().noconvert(), py::return_value_policy::move); diff --git a/mediapipe/python/pybind/packet_getter.cc b/mediapipe/python/pybind/packet_getter.cc index 8a4c98d64..0abe928a5 100644 --- a/mediapipe/python/pybind/packet_getter.cc +++ b/mediapipe/python/pybind/packet_getter.cc @@ -42,16 +42,16 @@ namespace py = pybind11; void PublicPacketGetters(pybind11::module* m) { m->def("get_str", &GetContent, - R"doc(Get the content of a MediaPipe std::string Packet as a str. + R"doc(Get the content of a MediaPipe string Packet as a str. Args: - packet: A MediaPipe std::string Packet. + packet: A MediaPipe string Packet. Returns: A str. Raises: - ValueError: If the Packet doesn't contain std::string data. + ValueError: If the Packet doesn't contain string data. Examples: packet = mp.packet_creator.create_string('abc') @@ -63,16 +63,16 @@ void PublicPacketGetters(pybind11::module* m) { [](const Packet& packet) { return py::bytes(GetContent(packet)); }, - R"doc(Get the content of a MediaPipe std::string Packet as a bytes object. + R"doc(Get the content of a MediaPipe string Packet as a bytes object. Args: - packet: A MediaPipe std::string Packet. + packet: A MediaPipe string Packet. Returns: A bytes object. Raises: - ValueError: If the Packet doesn't contain std::string data. + ValueError: If the Packet doesn't contain string data. 
Examples: packet = mp.packet_creator.create_string(b'\xd0\xd0\xd0') @@ -266,7 +266,7 @@ void PublicPacketGetters(pybind11::module* m) { m->def( "get_str_list", &GetContent>, - R"doc(Get the content of a MediaPipe std::string vector Packet as a str list. + R"doc(Get the content of a MediaPipe string vector Packet as a str list. Args: packet: A MediaPipe Packet that holds std:vector. @@ -322,7 +322,7 @@ void PublicPacketGetters(pybind11::module* m) { dict_packet = mp.packet_creator.create_string_to_packet_map({ 'float': packet_creator.create_float(0.1), 'int': packet_creator.create_int(1), - 'std::string': packet_creator.create_string('1') + 'string': packet_creator.create_string('1') data = mp.packet_getter.get_str_to_packet_dict(dict_packet) )doc"); @@ -418,11 +418,11 @@ void InternalPacketGetters(pybind11::module* m) { m->def( "_get_serialized_proto", [](const Packet& packet) { - // By default, py::bytes is an extra copy of the original std::string - // object: https://github.com/pybind/pybind11/issues/1236 However, when - // Pybind11 performs the C++ to Python transition, it only increases the - // py::bytes object's ref count. See the implmentation at line 1583 in - // "pybind11/cast.h". + // By default, py::bytes is an extra copy of the original string object: + // https://github.com/pybind/pybind11/issues/1236 + // However, when Pybind11 performs the C++ to Python transition, it + // only increases the py::bytes object's ref count. See the + // implmentation at line 1583 in "pybind11/cast.h". 
return py::bytes(packet.GetProtoMessageLite().SerializeAsString()); }, py::return_value_policy::move); diff --git a/mediapipe/python/solution_base.py b/mediapipe/python/solution_base.py index b46a13209..d4b9a943a 100644 --- a/mediapipe/python/solution_base.py +++ b/mediapipe/python/solution_base.py @@ -79,9 +79,16 @@ CALCULATOR_TO_OPTIONS = { } +def type_names_from_oneof(oneof_type_name: str) -> Optional[List[str]]: + if oneof_type_name.startswith('OneOf<') and oneof_type_name.endswith('>'): + comma_separated_types = oneof_type_name[len('OneOf<'):-len('>')] + return [n.strip() for n in comma_separated_types.split(',')] + return None + + # TODO: Support more packet data types, such as "Any" type. @enum.unique -class _PacketDataType(enum.Enum): +class PacketDataType(enum.Enum): """The packet data types supported by the SolutionBase class.""" STRING = 'string' BOOL = 'bool' @@ -96,79 +103,86 @@ class _PacketDataType(enum.Enum): PROTO_LIST = 'proto_list' @staticmethod - def from_registered_name(registered_name: str) -> '_PacketDataType': - return NAME_TO_TYPE[registered_name] + def from_registered_name(registered_name: str) -> 'PacketDataType': + try: + return NAME_TO_TYPE[registered_name] + except KeyError as e: + names = type_names_from_oneof(registered_name) + if names: + for n in names: + if n in NAME_TO_TYPE.keys(): + return NAME_TO_TYPE[n] + raise e - -NAME_TO_TYPE: Mapping[str, '_PacketDataType'] = { +NAME_TO_TYPE: Mapping[str, 'PacketDataType'] = { 'string': - _PacketDataType.STRING, + PacketDataType.STRING, 'bool': - _PacketDataType.BOOL, + PacketDataType.BOOL, '::std::vector': - _PacketDataType.BOOL_LIST, + PacketDataType.BOOL_LIST, 'int': - _PacketDataType.INT, + PacketDataType.INT, 'float': - _PacketDataType.FLOAT, + PacketDataType.FLOAT, '::std::vector': - _PacketDataType.FLOAT_LIST, + PacketDataType.FLOAT_LIST, '::mediapipe::Matrix': - _PacketDataType.AUDIO, + PacketDataType.AUDIO, '::mediapipe::ImageFrame': - _PacketDataType.IMAGE_FRAME, + 
PacketDataType.IMAGE_FRAME, '::mediapipe::Classification': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::ClassificationList': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::ClassificationListCollection': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::Detection': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::DetectionList': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::Landmark': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::LandmarkList': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::LandmarkListCollection': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::NormalizedLandmark': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::FrameAnnotation': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::Trigger': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::Rect': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::NormalizedRect': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::NormalizedLandmarkList': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::NormalizedLandmarkListCollection': - _PacketDataType.PROTO, + PacketDataType.PROTO, '::mediapipe::Image': - _PacketDataType.IMAGE, + PacketDataType.IMAGE, '::std::vector<::mediapipe::Classification>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::ClassificationList>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::Detection>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::DetectionList>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::Landmark>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::LandmarkList>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::NormalizedLandmark>': - 
_PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::NormalizedLandmarkList>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::Rect>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, '::std::vector<::mediapipe::NormalizedRect>': - _PacketDataType.PROTO_LIST, + PacketDataType.PROTO_LIST, } @@ -196,7 +210,8 @@ class SolutionBase: graph_config: Optional[calculator_pb2.CalculatorGraphConfig] = None, calculator_params: Optional[Mapping[str, Any]] = None, side_inputs: Optional[Mapping[str, Any]] = None, - outputs: Optional[List[str]] = None): + outputs: Optional[List[str]] = None, + stream_type_hints: Optional[Mapping[str, PacketDataType]] = None): """Initializes the SolutionBase object. Args: @@ -209,6 +224,7 @@ class SolutionBase: outputs: A list of the graph output stream names to observe. If the list is empty, all the output streams listed in the graph config will be automatically observed by default. + stream_type_hints: A mapping from the stream name to its packet type hint. Raises: FileNotFoundError: If the binary graph file can't be found. @@ -240,7 +256,7 @@ class SolutionBase: validated_graph.initialize(graph_config=graph_config) canonical_graph_config_proto = self._initialize_graph_interface( - validated_graph, side_inputs, outputs) + validated_graph, side_inputs, outputs, stream_type_hints) if calculator_params: self._modify_calculator_options(canonical_graph_config_proto, calculator_params) @@ -310,15 +326,15 @@ class SolutionBase: self._simulated_timestamp += 33333 for stream_name, data in input_dict.items(): input_stream_type = self._input_stream_type_info[stream_name] - if (input_stream_type == _PacketDataType.PROTO_LIST or - input_stream_type == _PacketDataType.AUDIO): + if (input_stream_type == PacketDataType.PROTO_LIST or + input_stream_type == PacketDataType.AUDIO): # TODO: Support audio data. 
raise NotImplementedError( f'SolutionBase can only process non-audio and non-proto-list data. ' f'{self._input_stream_type_info[stream_name].name} ' f'type is not supported yet.') - elif (input_stream_type == _PacketDataType.IMAGE_FRAME or - input_stream_type == _PacketDataType.IMAGE): + elif (input_stream_type == PacketDataType.IMAGE_FRAME or + input_stream_type == PacketDataType.IMAGE): if data.shape[2] != RGB_CHANNELS: raise ValueError('Input image must contain three channel rgb data.') self._graph.add_packet_to_input_stream( @@ -364,7 +380,8 @@ class SolutionBase: self, validated_graph: validated_graph_config.ValidatedGraphConfig, side_inputs: Optional[Mapping[str, Any]] = None, - outputs: Optional[List[str]] = None): + outputs: Optional[List[str]] = None, + stream_type_hints: Optional[Mapping[str, PacketDataType]] = None): """Gets graph interface type information and returns the canonical graph config proto.""" canonical_graph_config_proto = calculator_pb2.CalculatorGraphConfig() @@ -375,13 +392,16 @@ class SolutionBase: return tag_index_name.split(':')[-1] # Gets the packet type information of the input streams and output streams - # from the validated calculator graph. The mappings from the stream names to - # the packet data types is for deciding which packet creator and getter - # methods to call in the process() method. + # from the user provided stream_type_hints field or validated calculator + # graph. The mappings from the stream names to the packet data types is + # for deciding which packet creator and getter methods to call in the + # process() method. 
def get_stream_packet_type(packet_tag_index_name): - return _PacketDataType.from_registered_name( - validated_graph.registered_stream_type_name( - get_name(packet_tag_index_name))) + stream_name = get_name(packet_tag_index_name) + if stream_type_hints and stream_name in stream_type_hints.keys(): + return stream_type_hints[stream_name] + return PacketDataType.from_registered_name( + validated_graph.registered_stream_type_name(stream_name)) self._input_stream_type_info = { get_name(tag_index_name): get_stream_packet_type(tag_index_name) @@ -402,7 +422,7 @@ class SolutionBase: # packet data types is for making the input_side_packets dict for graph # start_run(). def get_side_packet_type(packet_tag_index_name): - return _PacketDataType.from_registered_name( + return PacketDataType.from_registered_name( validated_graph.registered_side_packet_type_name( get_name(packet_tag_index_name))) @@ -503,16 +523,16 @@ class SolutionBase: if num_modified < len(nested_calculator_params): raise ValueError('Not all calculator params are valid.') - def _make_packet(self, packet_data_type: _PacketDataType, + def _make_packet(self, packet_data_type: PacketDataType, data: Any) -> packet.Packet: - if (packet_data_type == _PacketDataType.IMAGE_FRAME or - packet_data_type == _PacketDataType.IMAGE): + if (packet_data_type == PacketDataType.IMAGE_FRAME or + packet_data_type == PacketDataType.IMAGE): return getattr(packet_creator, 'create_' + packet_data_type.value)( data, image_format=image_frame.ImageFormat.SRGB) else: return getattr(packet_creator, 'create_' + packet_data_type.value)(data) - def _get_packet_content(self, packet_data_type: _PacketDataType, + def _get_packet_content(self, packet_data_type: PacketDataType, output_packet: packet.Packet) -> Any: """Gets packet content from a packet by type. 
@@ -527,10 +547,10 @@ class SolutionBase: if output_packet.is_empty(): return None - if packet_data_type == _PacketDataType.STRING: + if packet_data_type == PacketDataType.STRING: return packet_getter.get_str(output_packet) - elif (packet_data_type == _PacketDataType.IMAGE_FRAME or - packet_data_type == _PacketDataType.IMAGE): + elif (packet_data_type == PacketDataType.IMAGE_FRAME or + packet_data_type == PacketDataType.IMAGE): return getattr(packet_getter, 'get_' + packet_data_type.value)(output_packet).numpy_view() else: diff --git a/mediapipe/python/solution_base_test.py b/mediapipe/python/solution_base_test.py index 0acd25cb0..6d04d94d7 100644 --- a/mediapipe/python/solution_base_test.py +++ b/mediapipe/python/solution_base_test.py @@ -22,6 +22,7 @@ from google.protobuf import text_format from mediapipe.framework import calculator_pb2 from mediapipe.framework.formats import detection_pb2 from mediapipe.python import solution_base +from mediapipe.python.solution_base import PacketDataType CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG = """ input_stream: 'image_in' @@ -348,6 +349,34 @@ class SolutionBaseTest(parameterized.TestCase): self.assertTrue(np.array_equal(input_image, outputs.image_out)) solution.reset() + def test_solution_stream_type_hints(self): + text_config = """ + input_stream: 'union_type_image_in' + output_stream: 'image_type_out' + node { + calculator: 'ToImageCalculator' + input_stream: 'IMAGE:union_type_image_in' + output_stream: 'IMAGE:image_type_out' + } + """ + config_proto = text_format.Parse(text_config, + calculator_pb2.CalculatorGraphConfig()) + input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3) + with solution_base.SolutionBase( + graph_config=config_proto, + stream_type_hints={'union_type_image_in': PacketDataType.IMAGE + }) as solution: + for _ in range(20): + outputs = solution.process(input_image) + self.assertTrue(np.array_equal(input_image, outputs.image_type_out)) + with solution_base.SolutionBase( + graph_config=config_proto, + 
stream_type_hints={'union_type_image_in': PacketDataType.IMAGE_FRAME + }) as solution2: + for _ in range(20): + outputs = solution2.process(input_image) + self.assertTrue(np.array_equal(input_image, outputs.image_type_out)) + def _process_and_verify(self, config_proto, side_inputs=None, diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index a1b179a3a..b5f52450f 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -158,6 +158,7 @@ cc_library( ] + select({ "//conditions:default": ["resource_util_default.cc"], "//mediapipe:android": ["resource_util_android.cc"], + "//mediapipe/framework:android_no_jni": ["resource_util_loonix.cc"], "//mediapipe:ios": ["resource_util_apple.cc"], "//mediapipe:macos": ["resource_util_default.cc"], }), @@ -192,6 +193,7 @@ cc_library( "//mediapipe/util/android:asset_manager_util", "//mediapipe/util/android/file/base", ], + "//mediapipe/framework:android_no_jni": [], "//mediapipe:ios": [], "//mediapipe:macos": [ "@com_google_absl//absl/flags:flag", diff --git a/mediapipe/util/android/file/base/file.cc b/mediapipe/util/android/file/base/file.cc index 6a216124e..83a34f15f 100644 --- a/mediapipe/util/android/file/base/file.cc +++ b/mediapipe/util/android/file/base/file.cc @@ -109,8 +109,7 @@ void StringReplace(absl::string_view s, absl::string_view oldsub, absl::string_view newsub, bool replace_all, std::string* res) { if (oldsub.empty()) { - res->append(s.data(), - s.length()); // If empty, append the given std::string. + res->append(s.data(), s.length()); // If empty, append the given string. return; } diff --git a/mediapipe/util/android/file/base/helpers.cc b/mediapipe/util/android/file/base/helpers.cc index add4c9b36..3de6b9b49 100644 --- a/mediapipe/util/android/file/base/helpers.cc +++ b/mediapipe/util/android/file/base/helpers.cc @@ -42,7 +42,7 @@ class FdCloser { } // namespace -// Read contents of a file to a std::string. +// Read contents of a file to a string. 
absl::Status GetContents(int fd, std::string* output) { // Determine the length of the file. struct stat buf; @@ -69,7 +69,7 @@ absl::Status GetContents(int fd, std::string* output) { return absl::OkStatus(); } -// Read contents of a file to a std::string. +// Read contents of a file to a string. absl::Status GetContents(absl::string_view file_name, std::string* output, const file::Options& /*options*/) { int fd = open(std::string(file_name).c_str(), O_RDONLY); diff --git a/mediapipe/util/android/file/base/helpers.h b/mediapipe/util/android/file/base/helpers.h index df61d423e..26d9f0f0e 100644 --- a/mediapipe/util/android/file/base/helpers.h +++ b/mediapipe/util/android/file/base/helpers.h @@ -24,21 +24,21 @@ namespace mediapipe { namespace file { -// Read contents of a file to a std::string. +// Read contents of a file to a string. absl::Status GetContents(absl::string_view file_name, std::string* output, const file::Options& options); -// Read contents of a file to a std::string with default file options. +// Read contents of a file to a string with default file options. absl::Status GetContents(absl::string_view file_name, std::string* output); -// Read contents of a file to a std::string from an open file descriptor. +// Read contents of a file to a string from an open file descriptor. absl::Status GetContents(int fd, std::string* output); -// Write std::string to file. +// Write string to file. absl::Status SetContents(absl::string_view file_name, absl::string_view content, const file::Options& options); -// Write std::string to file with default file options. +// Write string to file with default file options. 
absl::Status SetContents(absl::string_view file_name, absl::string_view content); diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index d464995e2..19fbbc14d 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -174,6 +174,18 @@ void AnnotationRenderer::DrawRectangle(const RenderAnnotation& annotation) { cv::Rect rect(left, top, right - left, bottom - top); cv::rectangle(mat_image_, rect, color, thickness); } + if (rectangle.has_top_left_thickness()) { + const auto& rect = RectangleToOpenCVRotatedRect(left, top, right, bottom, + rectangle.rotation()); + const int kNumVertices = 4; + cv::Point2f vertices[kNumVertices]; + rect.points(vertices); + const int top_left_thickness = + ClampThickness(round(rectangle.top_left_thickness() * scale_factor_)); + cv::ellipse(mat_image_, vertices[1], + cv::Size(top_left_thickness, top_left_thickness), 0.0, 0, 360, + color, -1); + } } void AnnotationRenderer::DrawFilledRectangle( diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index 54a2d65b8..0ff6b3409 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -46,6 +46,8 @@ message RenderAnnotation { optional double bottom = 4; optional bool normalized = 5 [default = false]; optional double rotation = 6; // Rotation in radians. + // Radius of top left corner circle. 
+ optional double top_left_thickness = 7; } message FilledRectangle { diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index 5fb66c24c..c812dcb57 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -49,12 +49,7 @@ absl::Status DefaultGetResourceContents(const std::string& path, LOG(WARNING) << "Setting \"read_as_binary\" to false is a no-op on ios."; } ASSIGN_OR_RETURN(std::string full_path, PathToResourceAsFile(path)); - - std::ifstream input_file(full_path); - std::stringstream buffer; - buffer << input_file.rdbuf(); - buffer.str().swap(*output); - return absl::OkStatus(); + return file::GetContents(full_path, output, read_as_binary); } } // namespace internal diff --git a/mediapipe/util/sequence/BUILD b/mediapipe/util/sequence/BUILD index 18aedf8e6..ac7c2ba51 100644 --- a/mediapipe/util/sequence/BUILD +++ b/mediapipe/util/sequence/BUILD @@ -21,6 +21,7 @@ cc_library( name = "media_sequence_util", hdrs = ["media_sequence_util.h"], visibility = [ + "//home/interaction:__subpackages__", "//mediapipe:__subpackages__", ], deps = [ diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h index c8be33e37..2f71024b4 100644 --- a/mediapipe/util/sequence/media_sequence.h +++ b/mediapipe/util/sequence/media_sequence.h @@ -37,8 +37,8 @@ // The clip label group adds labels that apply to the entire media clip. To // annotate that a video clip has a particular label, set the clip metadata // above and also set the SetClipLabelIndex and SetClipLabelString. Most -// training pipelines will only use the label index or std::string, but we -// recommend storing both to improve readability while maintaining ease of use. +// training pipelines will only use the label index or string, but we recommend +// storing both to improve readability while maintaining ease of use. 
// Example: // SetClipLabelString({"run", "jump"}, &sequence); // SetClipLabelIndex({35, 47}, &sequence); @@ -49,10 +49,11 @@ // are called segments. To annotate that a video clip has time spans with labels // set the clip metadata above and use the functions SetSegmentStartTimestamp, // SetSegmentEndTimestamp, SetSegmentLabelIndex, and SetSegmentLabelString. Most -// training pipelines will only use the label index or std::string, but we -// recommend storing both to improve readability while maintaining ease of use. -// By listing segments as times, the frame rate or other properties can change -// without affecting the labels. Example: +// training pipelines will only use the label index or string, but we recommend +// storing both to improve readability while maintaining ease of use. By listing +// segments as times, the frame rate or other properties can change without +// affecting the labels. +// Example: // SetSegmentStartTimestamp({500000, 1000000}, &sequence); // in microseconds // SetSegmentEndTimestamp({2000000, 6000000}, &sequence); // SetSegmentLabelIndex({35, 47}, &sequence); @@ -63,7 +64,7 @@ // needed can vary by task, but to annotate a video clip for detection set the // clip metadata above and use repeatedly call AddBBox, AddBBoxTimestamp, // AddBBoxLabelIndex, and AddBBoxLabelString. Most training pipelines will only -// use the label index or std::string, but we recommend storing both to improve +// use the label index or string, but we recommend storing both to improve // readability while maintaining ease of use. Because bounding boxes are // assigned to timepoints in a video, changing the image frame rate can can // change the alignment. 
The ReconcileMetadata function can align bounding boxes @@ -100,7 +101,7 @@ // tensorflow::SequenceExample example; // SetDataPath("data_path", &example); // if (HasDataPath(example)) { -// std::string data_path = GetDataPath(example); +// string data_path = GetDataPath(example); // ClearDataPath(&example); // } // @@ -146,13 +147,12 @@ // } // // As described in media_sequence_util.h, each of these functions can take an -// additional std::string prefix argument as their first argument. The prefix -// can be fixed with a new NAME by calling a FIXED_PREFIX_... macro. Prefixes -// are used to identify common storage patterns (e.g. storing an image along -// with the height and width) under different names (e.g. storing a left and -// right image in a stereo pair.) An example creating functions such as -// AddLeftImageEncoded that adds a std::string under the key -// "LEFT/image/encoded": +// additional string prefix argument as their first argument. The prefix can +// be fixed with a new NAME by calling a FIXED_PREFIX_... macro. Prefixes are +// used to identify common storage patterns (e.g. storing an image along with +// the height and width) under different names (e.g. storing a left and right +// image in a stereo pair.) An example creating functions such as +// AddLeftImageEncoded that adds a string under the key "LEFT/image/encoded": // FIXED_PREFIX_STRING_FEATURE_LIST("LEFT", LeftImageEncoded, "image/encoded"); #ifndef MEDIAPIPE_TENSORFLOW_SEQUENCE_MEDIA_SEQUENCE_H_ @@ -230,7 +230,7 @@ const char kSegmentEndIndexKey[] = "segment/end/index"; // A list with the label index for each segment. // Multiple labels for the same segment are encoded as repeated segments. const char kSegmentLabelIndexKey[] = "segment/label/index"; -// A list with the label std::string for each segment. +// A list with the label string for each segment. // Multiple labels for the same segment are encoded as repeated segments. 
const char kSegmentLabelStringKey[] = "segment/label/string"; // A list with the label confidence for each segment. @@ -301,7 +301,7 @@ const char kRegionTimestampKey[] = "region/timestamp"; // An embedding for each region. The length of each list must be the product of // the number of regions and the product of the embedding dimensions. const char kRegionEmbeddingFloatKey[] = "region/embedding/float"; -// A std::string encoded embedding for each region. +// A string encoded embedding for each region. const char kRegionEmbeddingEncodedKey[] = "region/embedding/encoded"; // The confidence of the embedding. const char kRegionEmbeddingConfidenceKey[] = "region/embedding/confidence"; diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index c4e482600..ca3021ea4 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -15,6 +15,7 @@ #include "mediapipe/util/sequence/media_sequence.h" #include +#include #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/port/gmock.h" diff --git a/mediapipe/util/sequence/media_sequence_util.h b/mediapipe/util/sequence/media_sequence_util.h index 531605aba..1737f91a0 100644 --- a/mediapipe/util/sequence/media_sequence_util.h +++ b/mediapipe/util/sequence/media_sequence_util.h @@ -29,14 +29,14 @@ // the most basic function prototypes for the name='MyFeature' are similar to: // // {BYTES,INT64,FLOAT}_CONTEXT_FEATURE: -// std::string GetMyFeatureKey(sequence) +// string GetMyFeatureKey(sequence) // bool HasMyFeature(sequence) // void ClearMyFeature(*sequence) // void SetMyFeature(value, *sequence) // TYPE GetMyFeature(sequence) // // VECTOR_{BYTES,INT64,FLOAT}_CONTEXT_FEATURE: -// std::string GetMyFeatureKey(sequence) +// string GetMyFeatureKey(sequence) // bool HasMyFeature(sequence) // void ClearMyFeature(*sequence) // void SetMyFeature(repeated_value, *sequence) @@ -46,7 +46,7 @@ // TYPE 
GetMyFeatureAt(sequence) // // {BYTES,INT64,FLOAT}_FEATURE_LIST: -// std::string GetMyFeatureKey(sequence) +// string GetMyFeatureKey(sequence) // bool HasMyFeature(sequence) // void ClearMyFeature(*sequence) // void AddMyFeature(value, *sequence) @@ -54,7 +54,7 @@ // TYPE GetMyFeatureAt(sequence) // // VECTOR_{BYTES,INT64,FLOAT}_FEATURE_LIST: -// std::string GetMyFeatureKey(sequence) +// string GetMyFeatureKey(sequence) // bool HasMyFeature(sequence) // void ClearMyFeature(*sequence) // void AddMyFeature(repeated_value, *sequence) @@ -236,8 +236,8 @@ inline const proto_ns::RepeatedField& GetInt64sAt( return fl.feature().Get(index).int64_list().value(); } -// Returns a refrerence to the std::string values for the feature list indicated -// by key at the provided sequence index. +// Returns a reference to the string values for the feature list indicated by +// key at the provided sequence index. inline const proto_ns::RepeatedPtrField& GetBytesAt( const tensorflow::SequenceExample& sequence, const std::string& key, const int index) { @@ -279,17 +279,17 @@ void AddBytesContainer(const std::string& key, const TContainer& bytes_list, // The macros provided below are useful for creating getters and setters for // keys and values in a tf::SequenceExample. You only need to specify the C++ -// name to use in the functions and the std::string key used in the -// SequenceExample proto maps. Macro versions exist for {strings, int64s, and -// floats} for creating singular or repeated context features and singular or -// repeated feature_list features. +// name to use in the functions and the string key used in the SequenceExample +// proto maps. Macro versions exist for {strings, int64s, and floats} for +// creating singular or repeated context features and singular or repeated +// feature_list features. // Helpers to create functions names in the macros below.
#define CONCAT_STR2(a, b) a##b #define CONCAT_STR3(a, b, c) a##b##c // This macro creates functions for HasX, GetX, ClearX, and SetX where X is a -// name and the value stored is a std::string in the context. +// name and the value stored is a string in the context. #define PREFIXED_BYTES_CONTEXT_FEATURE(name, key) \ inline const bool CONCAT_STR2(Has, name)( \ const std::string& prefix, \ @@ -766,7 +766,7 @@ void AddBytesContainer(const std::string& key, const TContainer& bytes_list, FIXED_PREFIX_VECTOR_FLOAT_CONTEXT_FEATURE(name, key, ""); // This macro creates functions for HasX, GetXSize, GetXAt, ClearX, and AddX -// where X is a name and the value stored is a std::string in a feature_list. +// where X is a name and the value stored is a string in a feature_list. #define PREFIXED_BYTES_FEATURE_LIST(name, key) \ inline const bool CONCAT_STR2(Has, name)( \ const std::string& prefix, \ diff --git a/mediapipe/util/tensor_to_detection.cc b/mediapipe/util/tensor_to_detection.cc index 4326067bc..91fc31696 100644 --- a/mediapipe/util/tensor_to_detection.cc +++ b/mediapipe/util/tensor_to_detection.cc @@ -35,7 +35,7 @@ Detection TensorToDetection( detection.add_score(score); // According to mediapipe/framework/formats/detection.proto - // "Either std::string or integer labels must be used but not both at the + // "Either string or integer labels must be used but not both at the // same time." 
if (absl::holds_alternative(class_label)) { detection.add_label_id(absl::get(class_label)); diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 3c0f1f35c..a4852b804 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -98,7 +98,10 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", ], - }) + ["@org_tensorflow//tensorflow/lite/core/api"], + }) + [ + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite/core/api", + ], ) cc_library( diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index 14a12db86..4c422835a 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -252,8 +252,10 @@ absl::Status TFLiteGPURunner::InitializeOpenCL( MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder( cl_options, std::move(graph_cl), builder)); -#endif // __ANDROID__ return absl::OkStatus(); +#else + return mediapipe::UnimplementedError("Currently only Android is supported"); +#endif // __ANDROID__ } #ifdef __ANDROID__ diff --git a/mediapipe/util/time_series_test_util.h b/mediapipe/util/time_series_test_util.h index 81c6d61a5..7e31aeff5 100644 --- a/mediapipe/util/time_series_test_util.h +++ b/mediapipe/util/time_series_test_util.h @@ -152,7 +152,7 @@ class TimeSeriesCalculatorTest : public ::testing::Test { } // Makes the CalculatorGraphConfig used to initialize CalculatorRunner - // runner_. If no options are needed, pass the empty std::string for options. + // runner_. If no options are needed, pass the empty string for options. 
CalculatorGraphConfig::Node MakeNodeConfig(const std::string& calculator_name, const int num_side_packets, const CalculatorOptions& options) { diff --git a/mediapipe/util/tracking/box_detector.cc b/mediapipe/util/tracking/box_detector.cc index 02ac6321a..e3a0eb476 100644 --- a/mediapipe/util/tracking/box_detector.cc +++ b/mediapipe/util/tracking/box_detector.cc @@ -154,7 +154,7 @@ void BoxDetectorInterface::DetectAndAddBoxFromFeatures( } for (int idx = 0; idx < size_before_add; ++idx) { - if ((options_.has_detect_every_n_frame() > 0 && + if ((options_.detect_every_n_frame() > 0 && cnt_detect_called_ % options_.detect_every_n_frame() == 0) || !tracked[idx] || (options_.detect_out_of_fov() && has_been_out_of_fov_[idx])) { diff --git a/mediapipe/util/tracking/camera_motion.h b/mediapipe/util/tracking/camera_motion.h index 98b28abbf..cadee78cb 100644 --- a/mediapipe/util/tracking/camera_motion.h +++ b/mediapipe/util/tracking/camera_motion.h @@ -93,11 +93,11 @@ float ForegroundMotion(const CameraMotion& camera_motion, void InitCameraMotionFromFeatureList(const RegionFlowFeatureList& feature_list, CameraMotion* camera_motion); -// Converts Camera motion flag to std::string. +// Converts Camera motion flag to string. std::string CameraMotionFlagToString(const CameraMotion& motion); -// Converts Camera motion type to std::string. Used instead of builtin proto -// function for mobile support. +// Converts Camera motion type to string. Used instead of builtin proto function +// for mobile support. 
std::string CameraMotionTypeToString(const CameraMotion& motion); // Returns inlier coverage either based on mixture (if present, in this case diff --git a/mediapipe/util/tracking/flow_packager.h b/mediapipe/util/tracking/flow_packager.h index 03b67facf..b0b6c9c15 100644 --- a/mediapipe/util/tracking/flow_packager.h +++ b/mediapipe/util/tracking/flow_packager.h @@ -64,7 +64,7 @@ namespace mediapipe { // flow_packager.FinalizeTrackingContainerFormat(&container); // flow_packager.FinalizeTrackingProto(&proto); // -// // Convert to binary std::string to stream out. +// // Convert to binary string to stream out. // std::string output; // flow_packager.TrackingContainerFormatToBinary(container, &output); // // OR: @@ -143,9 +143,8 @@ class FlowPackager { void SortRegionFlowFeatureList(float scale_x, float scale_y, RegionFlowFeatureList* feature_list) const; - // Removes binary encoded container from std::string and parses it to - // container. Returns header std::string of the parsed container. Useful for - // random seek. + // Removes binary encoded container from string and parses it to container. + // Returns header string of the parsed container. Useful for random seek. std::string SplitContainerFromString(absl::string_view* binary_data, TrackingContainer* container); @@ -155,7 +154,7 @@ class FlowPackager { const std::vector& data_sizes, MetaData* meta_data) const; - // Serializes container to binary std::string and adds it to binary_data. + // Serializes container to binary string and adds it to binary_data. void AddContainerToString(const TrackingContainer& container, std::string* binary_data); diff --git a/mediapipe/util/tracking/streaming_buffer.h b/mediapipe/util/tracking/streaming_buffer.h index c2be024f4..2d7945222 100644 --- a/mediapipe/util/tracking/streaming_buffer.h +++ b/mediapipe/util/tracking/streaming_buffer.h @@ -120,7 +120,7 @@ namespace mediapipe { // Stores pair (tag, TypeId of type). 
typedef std::pair TaggedType; -// Returns TaggedType for type T* tagged with passed std::string. +// Returns TaggedType for type T* tagged with passed string. template TaggedType TaggedPointerType(const std::string& tag); diff --git a/third_party/org_tensorflow_objc_cxx17.diff b/third_party/org_tensorflow_objc_cxx17.diff deleted file mode 100644 index a9da53fdf..000000000 --- a/third_party/org_tensorflow_objc_cxx17.diff +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD -index 069230ebcf6..3924d7cced7 100644 ---- a/tensorflow/lite/delegates/gpu/BUILD -+++ b/tensorflow/lite/delegates/gpu/BUILD -@@ -83,6 +83,7 @@ objc_library( - hdrs = ["metal_delegate.h"], - module_name = "TensorFlowLiteCMetal", - sdk_frameworks = ["Metal"], -+ copts = ["-std=c++17"], - deps = [ - "//tensorflow/lite:kernel_api", - "//tensorflow/lite:minimal_logging", -diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD -index 6dcde34a62f..1adfc28aad9 100644 ---- a/tensorflow/lite/delegates/gpu/metal/BUILD -+++ b/tensorflow/lite/delegates/gpu/metal/BUILD -@@ -17,6 +17,7 @@ package( - - DEFAULT_COPTS = [ - "-Wno-shorten-64-to-32", -+ "-std=c++17", - ] - - objc_library(