Project import generated by Copybara.

GitOrigin-RevId: bb059a0721c92e8154d33ce8057b3915a25b3d7d
This commit is contained in:
MediaPipe Team 2021-12-10 14:03:51 -08:00 committed by jqtang
parent cf101e62a9
commit e6c19885c6
96 changed files with 554 additions and 486 deletions

View File

@ -58,6 +58,7 @@ build:android_arm64 --fat_apk_cpu=arm64-v8a
# iOS configs. # iOS configs.
build:ios --apple_platform_type=ios build:ios --apple_platform_type=ios
build:ios --copt=-fno-aligned-allocation
build:ios_i386 --config=ios build:ios_i386 --config=ios
build:ios_i386 --cpu=ios_i386 build:ios_i386 --cpu=ios_i386

View File

@ -1 +1 @@
3.7.2 4.2.1

View File

@ -5,7 +5,7 @@
* Bug fixes * Bug fixes
* Documentation fixes * Documentation fixes
For new feature additions (e.g., new graphs and calculators), we are currently not planning to accept new feature pull requests into the MediaPipe repository. Instead, we like to get contributors to create their own repositories of the new feature and list it at [Awesome MediaPipe](https://mediapipe.org). This will allow contributors to more quickly get their code out to the community. For new feature additions (e.g., new graphs and calculators), we are currently not planning to accept new feature pull requests into the MediaPipe repository. Instead, we like to get contributors to create their own repositories of the new feature and list it at [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe). This will allow contributors to more quickly get their code out to the community.
Before sending your pull requests, make sure you followed this list. Before sending your pull requests, make sure you followed this list.

View File

@ -56,7 +56,7 @@ RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python RUN ln -s /usr/bin/python3 /usr/bin/python
# Install bazel # Install bazel
ARG BAZEL_VERSION=3.7.2 ARG BAZEL_VERSION=4.2.1
RUN mkdir /bazel && \ RUN mkdir /bazel && \
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\
azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \

View File

@ -122,16 +122,16 @@ http_archive(
# ...but the Java download is currently broken, so we use the "source" download. # ...but the Java download is currently broken, so we use the "source" download.
http_archive( http_archive(
name = "com_google_protobuf_javalite", name = "com_google_protobuf_javalite",
sha256 = "a79d19dcdf9139fa4b81206e318e33d245c4c9da1ffed21c87288ed4380426f9", sha256 = "87407cd28e7a9c95d9f61a098a53cf031109d451a7763e7dd1253abf8b4df422",
strip_prefix = "protobuf-3.11.4", strip_prefix = "protobuf-3.19.1",
urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.11.4.tar.gz"], urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"],
) )
http_archive( http_archive(
name = "com_google_protobuf", name = "com_google_protobuf",
sha256 = "a79d19dcdf9139fa4b81206e318e33d245c4c9da1ffed21c87288ed4380426f9", sha256 = "87407cd28e7a9c95d9f61a098a53cf031109d451a7763e7dd1253abf8b4df422",
strip_prefix = "protobuf-3.11.4", strip_prefix = "protobuf-3.19.1",
urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.11.4.tar.gz"], urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"],
patches = [ patches = [
"@//third_party:com_google_protobuf_fixes.diff" "@//third_party:com_google_protobuf_fixes.diff"
], ],
@ -154,28 +154,29 @@ http_archive(
sha256 = "75922da3a1bdb417d820398eb03d4e9bd067c4905a4246d35a44c01d62154d91", sha256 = "75922da3a1bdb417d820398eb03d4e9bd067c4905a4246d35a44c01d62154d91",
) )
# Point to the commit that deprecates the usage of Eigen::MappedSparseMatrix.
http_archive( http_archive(
name = "pybind11", name = "pybind11",
urls = [ urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.7.1.tar.gz", "https://github.com/pybind/pybind11/archive/70a58c577eaf067748c2ec31bfd0b0a614cffba6.zip",
"https://github.com/pybind/pybind11/archive/v2.7.1.tar.gz",
], ],
sha256 = "616d1c42e4cf14fa27b2a4ff759d7d7b33006fdc5ad8fd603bb2c22622f27020", sha256 = "b971842fab1b5b8f3815a2302331782b7d137fef0e06502422bc4bc360f4956c",
strip_prefix = "pybind11-2.7.1", strip_prefix = "pybind11-70a58c577eaf067748c2ec31bfd0b0a614cffba6",
build_file = "@pybind11_bazel//:pybind11.BUILD", build_file = "@pybind11_bazel//:pybind11.BUILD",
) )
# Point to the commit that deprecates the usage of Eigen::MappedSparseMatrix.
http_archive( http_archive(
name = "ceres_solver", name = "ceres_solver",
url = "https://github.com/ceres-solver/ceres-solver/archive/2.0.0.zip", url = "https://github.com/ceres-solver/ceres-solver/archive/123fba61cf2611a3c8bddc9d91416db26b10b558.zip",
patches = [ patches = [
"@//third_party:ceres_solver_compatibility_fixes.diff" "@//third_party:ceres_solver_compatibility_fixes.diff"
], ],
patch_args = [ patch_args = [
"-p1", "-p1",
], ],
strip_prefix = "ceres-solver-2.0.0", strip_prefix = "ceres-solver-123fba61cf2611a3c8bddc9d91416db26b10b558",
sha256 = "db12d37b4cebb26353ae5b7746c7985e00877baa8e7b12dc4d3a1512252fff3b" sha256 = "8b7b16ceb363420e0fd499576daf73fa338adb0b1449f58bea7862766baa1ac7"
) )
http_archive( http_archive(
@ -249,21 +250,12 @@ http_archive(
], ],
) )
# You may run setup_android.sh to install Android SDK and NDK.
android_ndk_repository(
name = "androidndk",
)
android_sdk_repository(
name = "androidsdk",
)
# iOS basic build deps. # iOS basic build deps.
http_archive( http_archive(
name = "build_bazel_rules_apple", name = "build_bazel_rules_apple",
sha256 = "7a7afdd4869bb201c9352eed2daf37294d42b093579b70423490c1b4d4f6ce42", sha256 = "77e8bf6fda706f420a55874ae6ee4df0c9d95da6c7838228b26910fc82eea5a2",
url = "https://github.com/bazelbuild/rules_apple/releases/download/0.19.0/rules_apple.0.19.0.tar.gz", url = "https://github.com/bazelbuild/rules_apple/releases/download/0.32.0/rules_apple.0.32.0.tar.gz",
patches = [ patches = [
# Bypass checking ios unit test runner when building MP ios applications. # Bypass checking ios unit test runner when building MP ios applications.
"@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff"
@ -289,10 +281,9 @@ swift_rules_dependencies()
http_archive( http_archive(
name = "build_bazel_apple_support", name = "build_bazel_apple_support",
sha256 = "122ebf7fe7d1c8e938af6aeaee0efe788a3a2449ece5a8d6a428cb18d6f88033", sha256 = "741366f79d900c11e11d8efd6cc6c66a31bfb2451178b58e0b5edc6f1db17b35",
urls = [ urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz", "https://github.com/bazelbuild/apple_support/releases/download/0.10.0/apple_support.0.10.0.tar.gz"
"https://github.com/bazelbuild/apple_support/releases/download/0.7.1/apple_support.0.7.1.tar.gz",
], ],
) )
@ -382,9 +373,9 @@ http_archive(
) )
# Tensorflow repo should always go after the other external dependencies. # Tensorflow repo should always go after the other external dependencies.
# 2021-07-29 # 2021-12-02
_TENSORFLOW_GIT_COMMIT = "52a2905cbc21034766c08041933053178c5d10e3" _TENSORFLOW_GIT_COMMIT = "18a1dc0ba806dc023808531f0373d9ec068e64bf"
_TENSORFLOW_SHA256 = "06d4691bcdb700f3275fa0971a1585221c2b9f3dffe867963be565a6643d7f56" _TENSORFLOW_SHA256 = "85b90416f7a11339327777bccd634de00ca0de2cf334f5f0727edcb11ff9289a"
http_archive( http_archive(
name = "org_tensorflow", name = "org_tensorflow",
urls = [ urls = [

View File

@ -29,8 +29,8 @@ APIs (currently in alpha) that are now available in
* Install MediaPipe following these [instructions](./install.md). * Install MediaPipe following these [instructions](./install.md).
* Setup Java Runtime. * Setup Java Runtime.
* Setup Android SDK release 28.0.3 and above. * Setup Android SDK release 30.0.0 and above.
* Setup Android NDK version between 18 and 21. * Setup Android NDK version 18 and above.
MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see MediaPipe recommends setting up Android SDK and NDK via Android Studio (and see
below for Android Studio setup). However, if you prefer using MediaPipe without below for Android Studio setup). However, if you prefer using MediaPipe without
@ -47,6 +47,15 @@ export ANDROID_HOME=<path to the Android SDK>
export ANDROID_NDK_HOME=<path to the Android NDK> export ANDROID_NDK_HOME=<path to the Android NDK>
``` ```
and add android_ndk_repository() and android_sdk_repository() rules into the
[`WORKSPACE`](https://github.com/google/mediapipe/blob/master/WORKSPACE) file as
the following:
```bash
$ echo "android_sdk_repository(name = \"androidsdk\")" >> WORKSPACE
$ echo "android_ndk_repository(name = \"androidndk\")" >> WORKSPACE
```
In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch
to a lower Android API level. You can achieve this by specifying `api_level = to a lower Android API level. You can achieve this by specifying `api_level =
$YOUR_INTENDED_API_LEVEL` in android_ndk_repository() and/or $YOUR_INTENDED_API_LEVEL` in android_ndk_repository() and/or

View File

@ -117,7 +117,7 @@ each project.
implementation 'com.google.flogger:flogger-system-backend:latest.release' implementation 'com.google.flogger:flogger-system-backend:latest.release'
implementation 'com.google.code.findbugs:jsr305:latest.release' implementation 'com.google.code.findbugs:jsr305:latest.release'
implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4' implementation 'com.google.protobuf:protobuf-javalite:3.19.1'
// CameraX core library // CameraX core library
def camerax_version = "1.0.0-beta10" def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version" implementation "androidx.camera:camera-core:$camerax_version"

View File

@ -569,7 +569,7 @@ next section.
Option 1. Follow Option 1. Follow
[the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) [the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html)
to install Bazel 3.7.2 or higher. to install Bazel 4.2.1 or higher.
Option 2. Follow the official Option 2. Follow the official
[Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html)
@ -657,7 +657,7 @@ cameras. Alternatively, you use a video file as input.
Note: Windows' and WSLs adb versions must be the same version, e.g., if WSL Note: Windows' and WSLs adb versions must be the same version, e.g., if WSL
has ADB 1.0.39, you need to download the corresponding Windows ADB from has ADB 1.0.39, you need to download the corresponding Windows ADB from
[here](https://dl.google.com/android/repository/platform-tools_r26.0.1-windows.zip). [here](https://dl.google.com/android/repository/platform-tools_r30.0.3-windows.zip).
3. Launch WSL. 3. Launch WSL.
@ -796,7 +796,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
```bash ```bash
$ docker run -it --name mediapipe mediapipe:latest $ docker run -it --name mediapipe mediapipe:latest
root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazelisk run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world
# Should print: # Should print:
# Hello World! # Hello World!

View File

@ -249,12 +249,12 @@ three stages: initialization and setup, graph run, and graph shutdown.
graph.start_run() graph.start_run()
graph.add_packet_to_input_stream( graph.add_packet_to_input_stream(
'in_stream', mp.packet_creator.create_str('abc').at(0)) 'in_stream', mp.packet_creator.create_string('abc').at(0))
rgb_img = cv2.cvtColor(cv2.imread('/path/to/your/image.png'), cv2.COLOR_BGR2RGB) rgb_img = cv2.cvtColor(cv2.imread('/path/to/your/image.png'), cv2.COLOR_BGR2RGB)
graph.add_packet_to_input_stream( graph.add_packet_to_input_stream(
'in_stream', 'in_stream',
mp.packet_creator.create_image_frame(format=mp.ImageFormat.SRGB, mp.packet_creator.create_image_frame(image_format=mp.ImageFormat.SRGB,
data=rgb_img).at(1)) data=rgb_img).at(1))
``` ```

View File

@ -108,14 +108,14 @@ ERROR: No matching distribution found for mediapipe
after running `pip install mediapipe` usually indicates that there is no qualified MediaPipe Python for your system. after running `pip install mediapipe` usually indicates that there is no qualified MediaPipe Python for your system.
Please note that MediaPipe Python PyPI officially supports the **64-bit** Please note that MediaPipe Python PyPI officially supports the **64-bit**
version of Python 3.7 and above on the following OS: version of Python 3.7 to 3.10 on the following OS:
- x86_64 Linux - x86_64 Linux
- x86_64 macOS 10.15+ - x86_64 macOS 10.15+
- amd64 Windows - amd64 Windows
If the OS is currently supported and you still see this error, please make sure If the OS is currently supported and you still see this error, please make sure
that both the Python and pip binary are for Python 3.7 and above. Otherwise, that both the Python and pip binary are for Python 3.7 to 3.10. Otherwise,
please consider building the MediaPipe Python package locally by following the please consider building the MediaPipe Python package locally by following the
instructions [here](python.md#building-mediapipe-python-package). instructions [here](python.md#building-mediapipe-python-package).

View File

@ -200,13 +200,9 @@ magnitude of `z` uses roughly the same scale as `x`.
#### multi_hand_world_landmarks #### multi_hand_world_landmarks
Collection of detected/tracked hands, where each hand is represented as a list Collection of detected/tracked hands, where each hand is represented as a list
of 21 hand landmarks in world coordinates. Each landmark consists of the of 21 hand landmarks in world coordinates. Each landmark is composed of `x`, `y`
following: and `z`: real-world 3D coordinates in meters with the origin at the hand's
approximate geometric center.
* `x`, `y` and `z`: Real-world 3D coordinates in meters with the origin at the
hand's approximate geometric center.
* `visibility`: Identical to that defined in the corresponding
[multi_hand_landmarks](#multi_hand_landmarks).
#### multi_handedness #### multi_handedness

View File

@ -1242,7 +1242,6 @@ cc_test(
"//mediapipe/framework:calculator_profile_cc_proto", "//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:test_calculators", "//mediapipe/framework:test_calculators",
"//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:clock",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",

View File

@ -23,7 +23,6 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_profile.pb.h" #include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"

View File

@ -45,6 +45,9 @@ namespace mediapipe {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(MEDIAPIPE_IOS)
#endif // defined(MEDIAPIPE_IOS)
namespace { namespace {
constexpr char kImageFrameTag[] = "IMAGE"; constexpr char kImageFrameTag[] = "IMAGE";

View File

@ -39,8 +39,6 @@
namespace mediapipe { namespace mediapipe {
using ::tflite::Interpreter;
void DoSmokeTest(const std::string& graph_proto) { void DoSmokeTest(const std::string& graph_proto) {
const int width = 8; const int width = 8;
const int height = 8; const int height = 8;

View File

@ -26,7 +26,6 @@
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kTransposeOptionsString[] = constexpr char kTransposeOptionsString[] =

View File

@ -176,7 +176,6 @@ cc_test(
":filter_detections_calculator", ":filter_detections_calculator",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
@ -215,7 +214,6 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:clock",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler", "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
@ -488,7 +486,6 @@ cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
@ -772,7 +769,6 @@ cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
@ -869,7 +865,6 @@ cc_test(
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
@ -1067,7 +1062,6 @@ cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -1217,7 +1211,6 @@ cc_test(
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
@ -1285,6 +1278,7 @@ cc_library(
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
@ -1344,7 +1338,6 @@ cc_test(
":detection_classifications_merger_calculator", ":detection_classifications_merger_calculator",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",

View File

@ -16,7 +16,6 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"

View File

@ -14,9 +14,9 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"

View File

@ -19,7 +19,6 @@
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"

View File

@ -17,7 +17,6 @@
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
@ -34,8 +33,6 @@ constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kRenderDataTag[] = "RENDER_DATA"; constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kDetectionListTag[] = "DETECTION_LIST"; constexpr char kDetectionListTag[] = "DETECTION_LIST";
using ::testing::DoubleNear;
// Error tolerance for pixels, distances, etc. // Error tolerance for pixels, distances, etc.
static constexpr double kErrorTolerance = 1e-5; static constexpr double kErrorTolerance = 1e-5;

View File

@ -16,7 +16,6 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"

View File

@ -33,9 +33,17 @@ class InverseMatrixCalculatorImpl : public NodeImpl<InverseMatrixCalculator> {
kInputMatrix(cc).Get().data()); kInputMatrix(cc).Get().data());
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> inverse_matrix; Eigen::Matrix<float, 4, 4, Eigen::RowMajor> inverse_matrix;
bool inverse_check; bool inverse_check = false;
matrix.computeInverseWithCheck(inverse_matrix, inverse_check); // The matrix is invertible if the absolute value of its determinant is
RET_CHECK(inverse_check) << "Inverse matrix cannot be calculated."; // greater than this threshold. Quite small threshold is selected to enable
// inverting valid matrices containing relatively small values resulting in
// a small determinant.
constexpr double kAbsDeterminantThreshold =
Eigen::NumTraits<double>::epsilon();
matrix.computeInverseWithCheck(inverse_matrix, inverse_check,
kAbsDeterminantThreshold);
RET_CHECK(inverse_check)
<< "Inverse matrix cannot be calculated for: " << matrix;
std::array<float, 16> output; std::array<float, 16> output;
Eigen::Map<Eigen::Matrix<float, 4, 4, Eigen::RowMajor>>( Eigen::Map<Eigen::Matrix<float, 4, 4, Eigen::RowMajor>>(

View File

@ -42,7 +42,11 @@ void RunTest(const std::array<float, 16>& matrix,
const auto& inverse_matrix = output_packets[0].Get<std::array<float, 16>>(); const auto& inverse_matrix = output_packets[0].Get<std::array<float, 16>>();
EXPECT_THAT(inverse_matrix, testing::Eq(expected_inverse_matrix)); EXPECT_THAT(
inverse_matrix,
testing::Pointwise(testing::FloatEq(),
absl::MakeSpan(expected_inverse_matrix.data(),
expected_inverse_matrix.size())));
// Fully close graph at end, otherwise calculator+tensors are destroyed // Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone(). // after calling WaitUntilDone().
@ -122,5 +126,25 @@ TEST(InverseMatrixCalculatorTest, Rotation90) {
RunTest(matrix, expected_inverse_matrix); RunTest(matrix, expected_inverse_matrix);
} }
TEST(InverseMatrixCalculatorTest, CheckPrecision) {
// clang-format off
std::array<float, 16> matrix = {
0.00001f, 0.0f, 0.0f, 0.0f,
0.0f, 0.00001f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
std::array<float, 16> expected_inverse_matrix = {
100000.0f, 0.0f, 0.0f, 0.0f,
0.0f, 100000.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 0.0f, 1.0f,
};
// clang-format on
RunTest(matrix, expected_inverse_matrix);
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -6,9 +6,9 @@
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"

View File

@ -16,7 +16,6 @@
#include "mediapipe/calculators/util/latency.pb.h" #include "mediapipe/calculators/util/latency.pb.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -14,6 +14,7 @@
#include <memory> #include <memory>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
@ -23,23 +24,23 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gl_calculator_helper.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
namespace api2 {
namespace { #if MEDIAPIPE_DISABLE_GPU
constexpr char kImageFrameTag[] = "IMAGE_CPU"; // Just a placeholder to not have to depend on mediapipe::GpuBuffer.
constexpr char kGpuBufferTag[] = "IMAGE_GPU"; class Nothing {};
constexpr char kImageTag[] = "IMAGE"; using GpuBuffer = Nothing;
} // namespace #else
using GpuBuffer = mediapipe::GpuBuffer;
#endif // MEDIAPIPE_DISABLE_GPU
// A calculator for converting from legacy MediaPipe datatypes into a // A calculator for converting from legacy MediaPipe datatypes into a
// unified image container. // unified image container.
// //
// Inputs: // Inputs:
// One of the following two tags: // One of the following two tags:
// IMAGE: An Image, ImageFrame, or GpuBuffer containing input image.
// IMAGE_CPU: An ImageFrame containing input image. // IMAGE_CPU: An ImageFrame containing input image.
// IMAGE_GPU: A GpuBuffer containing input image. // IMAGE_GPU: A GpuBuffer containing input image.
// //
@ -49,107 +50,44 @@ constexpr char kImageTag[] = "IMAGE";
// Note: // Note:
// No CPU/GPU conversion is done. // No CPU/GPU conversion is done.
// //
class ToImageCalculator : public CalculatorBase { class ToImageCalculator : public Node {
public: public:
ToImageCalculator() = default; ToImageCalculator() = default;
~ToImageCalculator() override = default; ~ToImageCalculator() override = default;
static absl::Status GetContract(CalculatorContract* cc); static constexpr Input<
OneOf<mediapipe::Image, mediapipe::ImageFrame, GpuBuffer>>::Optional kIn{
"IMAGE"};
static constexpr Input<mediapipe::ImageFrame>::Optional kInCpu{"IMAGE_CPU"};
static constexpr Input<GpuBuffer>::Optional kInGpu{"IMAGE_GPU"};
static constexpr Output<mediapipe::Image> kOut{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kIn, kInCpu, kInGpu, kOut);
static absl::Status UpdateContract(CalculatorContract* cc);
// From Calculator. // From Calculator.
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
absl::Status RenderGpu(CalculatorContext* cc); absl::StatusOr<Packet<Image>> GetInputImage(CalculatorContext* cc);
absl::Status RenderCpu(CalculatorContext* cc);
bool gpu_input_ = false;
bool gpu_initialized_ = false;
#if !MEDIAPIPE_DISABLE_GPU
mediapipe::GlCalculatorHelper gpu_helper_;
#endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(ToImageCalculator); MEDIAPIPE_REGISTER_NODE(ToImageCalculator);
absl::Status ToImageCalculator::GetContract(CalculatorContract* cc) { absl::Status ToImageCalculator::UpdateContract(CalculatorContract* cc) {
cc->Outputs().Tag(kImageTag).Set<mediapipe::Image>(); int num_inputs = static_cast<int>(kIn(cc).IsConnected()) +
static_cast<int>(kInCpu(cc).IsConnected()) +
bool gpu_input = false; static_cast<int>(kInGpu(cc).IsConnected());
if (num_inputs != 1) {
if (cc->Inputs().HasTag(kImageFrameTag) &&
cc->Inputs().HasTag(kGpuBufferTag)) {
return absl::InternalError("Cannot have multiple inputs."); return absl::InternalError("Cannot have multiple inputs.");
} }
if (cc->Inputs().HasTag(kGpuBufferTag)) {
#if !MEDIAPIPE_DISABLE_GPU
cc->Inputs().Tag(kGpuBufferTag).Set<mediapipe::GpuBuffer>();
gpu_input = true;
#else
RET_CHECK_FAIL() << "GPU is disabled. Cannot use IMAGE_GPU stream.";
#endif // !MEDIAPIPE_DISABLE_GPU
}
if (cc->Inputs().HasTag(kImageFrameTag)) {
cc->Inputs().Tag(kImageFrameTag).Set<mediapipe::ImageFrame>();
}
if (gpu_input) {
#if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return absl::OkStatus();
}
absl::Status ToImageCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag(kGpuBufferTag)) {
gpu_input_ = true;
}
if (gpu_input_) {
#if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif
} // !MEDIAPIPE_DISABLE_GPU
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status ToImageCalculator::Process(CalculatorContext* cc) { absl::Status ToImageCalculator::Process(CalculatorContext* cc) {
if (gpu_input_) { ASSIGN_OR_RETURN(auto output, GetInputImage(cc));
#if !MEDIAPIPE_DISABLE_GPU kOut(cc).Send(output.At(cc->InputTimestamp()));
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status {
auto& input = cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
// Wrap texture pointer; shallow copy.
auto output = std::make_unique<mediapipe::Image>(input);
cc->Outputs().Tag(kImageTag).Add(output.release(), cc->InputTimestamp());
return absl::OkStatus();
}));
#endif // !MEDIAPIPE_DISABLE_GPU
} else {
// The input ImageFrame.
auto& input = cc->Inputs().Tag(kImageFrameTag).Get<mediapipe::ImageFrame>();
// Make a copy of the input packet to co-own the input ImageFrame.
Packet* packet_copy_ptr =
new Packet(cc->Inputs().Tag(kImageFrameTag).Value());
// Create an output Image that (co-)owns a new ImageFrame that points to
// the same pixel data as the input ImageFrame and also owns the packet
// copy. As a result, the output Image indirectly co-owns the input
// ImageFrame. This ensures a correct life span of the shared pixel data.
std::unique_ptr<mediapipe::Image> output =
std::make_unique<mediapipe::Image>(
std::make_shared<mediapipe::ImageFrame>(
input.Format(), input.Width(), input.Height(),
input.WidthStep(), const_cast<uint8*>(input.PixelData()),
[packet_copy_ptr](uint8*) { delete packet_copy_ptr; }));
cc->Outputs().Tag(kImageTag).Add(output.release(), cc->InputTimestamp());
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -157,4 +95,43 @@ absl::Status ToImageCalculator::Close(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
// Wrap ImageFrameSharedPtr; shallow copy.
absl::StatusOr<Packet<Image>> FromImageFrame(Packet<ImageFrame> packet) {
return MakePacket<Image, std::shared_ptr<mediapipe::ImageFrame>>(
std::const_pointer_cast<mediapipe::ImageFrame>(
SharedPtrWithPacket<mediapipe::ImageFrame>(packet)));
}
// Wrap texture pointer; shallow copy.
absl::StatusOr<Packet<Image>> FromGpuBuffer(Packet<GpuBuffer> packet) {
#if !MEDIAPIPE_DISABLE_GPU
const GpuBuffer& buffer = *packet;
return MakePacket<Image, const GpuBuffer&>(buffer);
#else
return absl::UnimplementedError("GPU processing is disabled in build flags");
#endif // !MEDIAPIPE_DISABLE_GPU
}
absl::StatusOr<Packet<Image>> ToImageCalculator::GetInputImage(
CalculatorContext* cc) {
if (kIn(cc).IsConnected()) {
return kIn(cc).Visit(
[&](const mediapipe::Image&) {
return absl::StatusOr<Packet<Image>>(kIn(cc).As<Image>());
},
[&](const mediapipe::ImageFrame&) {
return FromImageFrame(kIn(cc).As<ImageFrame>());
},
[&](const GpuBuffer&) {
return FromGpuBuffer(kIn(cc).As<GpuBuffer>());
});
} else if (kInCpu(cc).IsConnected()) {
return FromImageFrame(kInCpu(cc).As<ImageFrame>());
} else if (kInGpu(cc).IsConnected()) {
return FromGpuBuffer(kInGpu(cc).As<GpuBuffer>());
}
return absl::InvalidArgumentError("No input found.");
}
} // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -291,6 +291,7 @@ cc_library(
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/container:btree",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -390,6 +391,7 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],

View File

@ -657,7 +657,10 @@ absl::Status ContentZoomingCalculator::Process(
} }
const bool camera_active = const bool camera_active =
is_animating || pan_state || tilt_state || zoom_state; is_animating || pan_state || tilt_state || zoom_state;
if (cc->Outputs().HasTag(kCameraActive)) { // Waiting for first rect before setting any value of the camera active flag
// so we avoid setting it to false during initialization.
if (cc->Outputs().HasTag(kCameraActive) &&
first_rect_timestamp_ != Timestamp::Unset()) {
cc->Outputs() cc->Outputs()
.Tag(kCameraActive) .Tag(kCameraActive)
.AddPacket(MakePacket<bool>(camera_active).At(cc->InputTimestamp())); .AddPacket(MakePacket<bool>(camera_active).At(cc->InputTimestamp()));

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/container/btree_set.h"
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h"
@ -54,7 +55,8 @@ const char kConfig[] = R"(
const int kTestFrameWidth = 640; const int kTestFrameWidth = 640;
const int kTestFrameHeight = 480; const int kTestFrameHeight = 480;
void AddFrames(const int number_of_frames, const std::set<int>& skip_frames, void AddFrames(const int number_of_frames,
const absl::btree_set<int>& skip_frames,
CalculatorRunner* runner) { CalculatorRunner* runner) {
cv::Mat image = cv::Mat image =
cv::imread(file::JoinPath("./", cv::imread(file::JoinPath("./",
@ -78,7 +80,8 @@ void AddFrames(const int number_of_frames, const std::set<int>& skip_frames,
} }
} }
void CheckOutput(const int number_of_frames, const std::set<int>& shot_frames, void CheckOutput(const int number_of_frames,
const absl::btree_set<int>& shot_frames,
const std::vector<Packet>& output_packets) { const std::vector<Packet>& output_packets) {
ASSERT_EQ(number_of_frames, output_packets.size()); ASSERT_EQ(number_of_frames, output_packets.size());
for (int i = 0; i < number_of_frames; i++) { for (int i = 0; i < number_of_frames; i++) {

View File

@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/container/btree_map.h"
#include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h"
#include "mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.pb.h" #include "mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -178,8 +179,8 @@ absl::Status SignalFusingCalculator::Close(mediapipe::CalculatorContext* cc) {
absl::Status SignalFusingCalculator::ProcessScene( absl::Status SignalFusingCalculator::ProcessScene(
mediapipe::CalculatorContext* cc) { mediapipe::CalculatorContext* cc) {
std::map<std::string, int> detection_count; absl::btree_map<std::string, int> detection_count;
std::map<std::string, float> multiframe_score; absl::btree_map<std::string, float> multiframe_score;
// Create a unified score for all items with temporal ids. // Create a unified score for all items with temporal ids.
for (const Frame& frame : scene_frames_) { for (const Frame& frame : scene_frames_) {
for (const auto& detection : frame.input_detections) { for (const auto& detection : frame.input_detections) {

View File

@ -1187,7 +1187,6 @@ cc_test(
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
@ -1580,7 +1579,6 @@ cc_test(
":packet", ":packet",
":packet_test_cc_proto", ":packet_test_cc_proto",
":type_map", ":type_map",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
@ -1640,7 +1638,6 @@ cc_test(
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:template_parser", "//mediapipe/framework/tool:template_parser",

View File

@ -36,7 +36,6 @@ cc_test(
":tag", ":tag",
":test_contracts", ":test_contracts",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
@ -175,7 +174,6 @@ cc_test(
":port", ":port",
":test_contracts", ":test_contracts",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:subgraph_expansion", "//mediapipe/framework/tool:subgraph_expansion",

View File

@ -9,7 +9,6 @@
#include "mediapipe/framework/api2/tag.h" #include "mediapipe/framework/api2/tag.h"
#include "mediapipe/framework/api2/test_contracts.h" #include "mediapipe/framework/api2/test_contracts.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -565,6 +565,19 @@ TEST(NodeTest, ConsumeInputs) {
MP_EXPECT_OK(graph.WaitUntilDone()); MP_EXPECT_OK(graph.WaitUntilDone());
} }
// Just to test that single-port contracts work.
struct LogSinkNode : public Node {
static constexpr Input<int> kIn{"IN"};
MEDIAPIPE_NODE_CONTRACT(kIn);
absl::Status Process(CalculatorContext* cc) override {
LOG(INFO) << "LogSinkNode received: " << kIn(cc).Get();
return {};
}
};
MEDIAPIPE_REGISTER_NODE(LogSinkNode);
} // namespace test } // namespace test
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -4,7 +4,6 @@
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/api2/test_contracts.h" #include "mediapipe/framework/api2/test_contracts.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -18,11 +18,10 @@ template <class Tuple>
using tuple_index_sequence = using tuple_index_sequence =
std::make_index_sequence<std::tuple_size_v<std::decay_t<Tuple>>>; std::make_index_sequence<std::tuple_size_v<std::decay_t<Tuple>>>;
// Concatenates two std::index_sequences. // Concatenates multiple std::index_sequences.
template <std::size_t... I, std::size_t... J> template <std::size_t... I>
constexpr auto index_sequence_cat(std::index_sequence<I...>, constexpr auto index_sequence_cat(std::index_sequence<I...>)
std::index_sequence<J...>) -> std::index_sequence<I...> {
-> std::index_sequence<I..., J...> {
return {}; return {};
} }

View File

@ -174,7 +174,7 @@ TEST(CalculatorTest, CreateByName) {
TEST(CalculatorTest, CreateByNameWhitelisted) { TEST(CalculatorTest, CreateByNameWhitelisted) {
// Reset the registration namespace whitelist. // Reset the registration namespace whitelist.
*const_cast<absl::flat_hash_set<std::string>*>( *const_cast<absl::flat_hash_set<std::string>*>(
&NamespaceWhitelist::TopNamespaces()) = absl::flat_hash_set<std::string>{ &NamespaceAllowlist::TopNamespaces()) = absl::flat_hash_set<std::string>{
"mediapipe::test_ns::whitelisted_ns", "mediapipe::test_ns::whitelisted_ns",
"mediapipe", "mediapipe",
}; };

View File

@ -236,21 +236,21 @@ class CalculatorNode {
} }
private: private:
// Sets up the output side packets from the master flat array. // Sets up the output side packets from the main flat array.
absl::Status InitializeOutputSidePackets( absl::Status InitializeOutputSidePackets(
const PacketTypeSet& output_side_packet_types, const PacketTypeSet& output_side_packet_types,
OutputSidePacketImpl* output_side_packets); OutputSidePacketImpl* output_side_packets);
// Connects the input side packets as mirrors on the output side packets. // Connects the input side packets as mirrors on the output side packets.
// Output side packets are looked up in the master flat array which is // Output side packets are looked up in the main flat array which is
// provided. // provided.
absl::Status InitializeInputSidePackets( absl::Status InitializeInputSidePackets(
OutputSidePacketImpl* output_side_packets); OutputSidePacketImpl* output_side_packets);
// Sets up the output streams from the master flat array. // Sets up the output streams from the main flat array.
absl::Status InitializeOutputStreams( absl::Status InitializeOutputStreams(
OutputStreamManager* output_stream_managers); OutputStreamManager* output_stream_managers);
// Sets up the input streams and connects them as mirrors on the // Sets up the input streams and connects them as mirrors on the
// output streams. Both input streams and output streams are looked // output streams. Both input streams and output streams are looked
// up in the master flat arrays which are provided. // up in the main flat arrays which are provided.
absl::Status InitializeInputStreams( absl::Status InitializeInputStreams(
InputStreamManager* input_stream_managers, InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers); OutputStreamManager* output_stream_managers);

View File

@ -26,6 +26,7 @@
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/tool/tag_map.h" #include "mediapipe/framework/tool/tag_map.h"
@ -50,7 +51,7 @@ struct CollectionErrorHandlerFatal {
// Since there isn't any state and we're not returning anything, we // Since there isn't any state and we're not returning anything, we
// get away with only one version of this function (which is const // get away with only one version of this function (which is const
// but returns a non-const reference). // but returns a non-const reference).
T& GetFallback(const std::string& tag, int index) const { T& GetFallback(const absl::string_view tag, int index) const {
LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index; LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index;
std::abort(); std::abort();
} }
@ -131,16 +132,16 @@ class Collection {
const value_type& Get(CollectionItemId id) const; const value_type& Get(CollectionItemId id) const;
// Convenience functions. // Convenience functions.
value_type& Get(const std::string& tag, int index); value_type& Get(absl::string_view tag, int index);
const value_type& Get(const std::string& tag, int index) const; const value_type& Get(absl::string_view tag, int index) const;
// Equivalent to Get("", index); // Equivalent to Get("", index);
value_type& Index(int index); value_type& Index(int index);
const value_type& Index(int index) const; const value_type& Index(int index) const;
// Equivalent to Get(tag, 0); // Equivalent to Get(tag, 0);
value_type& Tag(const std::string& tag); value_type& Tag(absl::string_view tag);
const value_type& Tag(const std::string& tag) const; const value_type& Tag(absl::string_view tag) const;
// These functions only exist for collections with storage == // These functions only exist for collections with storage ==
// kStorePointer. GetPtr returns the stored ptr value rather than // kStorePointer. GetPtr returns the stored ptr value rather than
@ -179,13 +180,15 @@ class Collection {
//////////////////////////////////////// ////////////////////////////////////////
// Returns true if the provided tag is available (not necessarily set yet). // Returns true if the provided tag is available (not necessarily set yet).
bool HasTag(const std::string& tag) const { return tag_map_->HasTag(tag); } bool HasTag(const absl::string_view tag) const {
return tag_map_->HasTag(tag);
}
// Returns the number of entries in this collection. // Returns the number of entries in this collection.
int NumEntries() const { return tag_map_->NumEntries(); } int NumEntries() const { return tag_map_->NumEntries(); }
// Returns the number of entries with the provided tag. // Returns the number of entries with the provided tag.
int NumEntries(const std::string& tag) const { int NumEntries(const absl::string_view tag) const {
return tag_map_->NumEntries(tag); return tag_map_->NumEntries(tag);
} }
@ -200,7 +203,7 @@ class Collection {
// However, be careful in using this fact, as it circumvents the // However, be careful in using this fact, as it circumvents the
// validity checks in GetId() (i.e. ++GetId("BLAH", 2) looks like it // validity checks in GetId() (i.e. ++GetId("BLAH", 2) looks like it
// is valid, while GetId("BLAH", 3) is not valid). // is valid, while GetId("BLAH", 3) is not valid).
CollectionItemId GetId(const std::string& tag, int index) const { CollectionItemId GetId(const absl::string_view tag, int index) const {
return tag_map_->GetId(tag, index); return tag_map_->GetId(tag, index);
} }
@ -234,10 +237,10 @@ class Collection {
// for (CollectionItemId id = collection.BeginId(tag); // for (CollectionItemId id = collection.BeginId(tag);
// id < collection.EndId(tag); ++id) { // id < collection.EndId(tag); ++id) {
// } // }
CollectionItemId BeginId(const std::string& tag) const { CollectionItemId BeginId(const absl::string_view tag) const {
return tag_map_->BeginId(tag); return tag_map_->BeginId(tag);
} }
CollectionItemId EndId(const std::string& tag) const { CollectionItemId EndId(const absl::string_view tag) const {
return tag_map_->EndId(tag); return tag_map_->EndId(tag);
} }
@ -404,7 +407,7 @@ bool Collection<T, storage, ErrorHandler>::UsesTags() const {
return false; return false;
} }
// If the one tag present is non-empty then we are using tags. // If the one tag present is non-empty then we are using tags.
return mapping.begin()->first != ""; return !mapping.begin()->first.empty();
} }
template <typename T, CollectionStorage storage, typename ErrorHandler> template <typename T, CollectionStorage storage, typename ErrorHandler>
@ -449,7 +452,8 @@ Collection<T, storage, ErrorHandler>::GetPtr(CollectionItemId id) const {
template <typename T, CollectionStorage storage, typename ErrorHandler> template <typename T, CollectionStorage storage, typename ErrorHandler>
typename Collection<T, storage, ErrorHandler>::value_type& typename Collection<T, storage, ErrorHandler>::value_type&
Collection<T, storage, ErrorHandler>::Get(const std::string& tag, int index) { Collection<T, storage, ErrorHandler>::Get(const absl::string_view tag,
int index) {
CollectionItemId id = GetId(tag, index); CollectionItemId id = GetId(tag, index);
if (!id.IsValid()) { if (!id.IsValid()) {
return error_handler_.GetFallback(tag, index); return error_handler_.GetFallback(tag, index);
@ -459,7 +463,7 @@ Collection<T, storage, ErrorHandler>::Get(const std::string& tag, int index) {
template <typename T, CollectionStorage storage, typename ErrorHandler> template <typename T, CollectionStorage storage, typename ErrorHandler>
const typename Collection<T, storage, ErrorHandler>::value_type& const typename Collection<T, storage, ErrorHandler>::value_type&
Collection<T, storage, ErrorHandler>::Get(const std::string& tag, Collection<T, storage, ErrorHandler>::Get(const absl::string_view tag,
int index) const { int index) const {
CollectionItemId id = GetId(tag, index); CollectionItemId id = GetId(tag, index);
if (!id.IsValid()) { if (!id.IsValid()) {
@ -482,13 +486,13 @@ Collection<T, storage, ErrorHandler>::Index(int index) const {
template <typename T, CollectionStorage storage, typename ErrorHandler> template <typename T, CollectionStorage storage, typename ErrorHandler>
typename Collection<T, storage, ErrorHandler>::value_type& typename Collection<T, storage, ErrorHandler>::value_type&
Collection<T, storage, ErrorHandler>::Tag(const std::string& tag) { Collection<T, storage, ErrorHandler>::Tag(const absl::string_view tag) {
return Get(tag, 0); return Get(tag, 0);
} }
template <typename T, CollectionStorage storage, typename ErrorHandler> template <typename T, CollectionStorage storage, typename ErrorHandler>
const typename Collection<T, storage, ErrorHandler>::value_type& const typename Collection<T, storage, ErrorHandler>::value_type&
Collection<T, storage, ErrorHandler>::Tag(const std::string& tag) const { Collection<T, storage, ErrorHandler>::Tag(const absl::string_view tag) const {
return Get(tag, 0); return Get(tag, 0);
} }
@ -535,21 +539,23 @@ Collection<T, storage, ErrorHandler>::end() const {
// Returns c.HasTag(tag) && !Tag(tag)->IsEmpty() (just for convenience). // Returns c.HasTag(tag) && !Tag(tag)->IsEmpty() (just for convenience).
// This version is used with Calculator. // This version is used with Calculator.
template <class S> template <class S>
bool HasTagValue(const internal::Collection<S*>& c, const std::string& tag) { bool HasTagValue(const internal::Collection<S*>& c,
const absl::string_view tag) {
return c.HasTag(tag) && !c.Tag(tag)->IsEmpty(); return c.HasTag(tag) && !c.Tag(tag)->IsEmpty();
} }
// Returns c.HasTag(tag) && !Tag(tag).IsEmpty() (just for convenience). // Returns c.HasTag(tag) && !Tag(tag).IsEmpty() (just for convenience).
// This version is used with CalculatorBase. // This version is used with CalculatorBase.
template <class S> template <class S>
bool HasTagValue(const internal::Collection<S>& c, const std::string& tag) { bool HasTagValue(const internal::Collection<S>& c,
const absl::string_view tag) {
return c.HasTag(tag) && !c.Tag(tag).IsEmpty(); return c.HasTag(tag) && !c.Tag(tag).IsEmpty();
} }
// Returns c.HasTag(tag) && !Tag(tag).IsEmpty() (just for convenience). // Returns c.HasTag(tag) && !Tag(tag).IsEmpty() (just for convenience).
// This version is used with Calculator or CalculatorBase. // This version is used with Calculator or CalculatorBase.
template <class C> template <class C>
bool HasTagValue(const C& c, const std::string& tag) { bool HasTagValue(const C& c, const absl::string_view tag) {
return HasTagValue(c->Inputs(), tag); return HasTagValue(c->Inputs(), tag);
} }

View File

@ -87,10 +87,11 @@ cc_library(
name = "message_matchers", name = "message_matchers",
testonly = True, testonly = True,
hdrs = ["message_matchers.h"], hdrs = ["message_matchers.h"],
visibility = ["//visibility:public"], # Use this library through "mediapipe/framework/port:gtest_main".
visibility = ["//mediapipe/framework/port:__pkg__"],
deps = [ deps = [
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main", "@com_google_googletest//:gtest",
], ],
) )

View File

@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include "gmock/gmock.h"
#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
namespace mediapipe { namespace mediapipe {

View File

@ -21,7 +21,7 @@ namespace mediapipe {
namespace { namespace {
// List of namespaces that can register calculators inside the namespace // List of namespaces that can register calculators inside the namespace
// and still refer to them using an unqualified name. This whitelist // and still refer to them using an unqualified name. This allowlist
// is meant to facilitate migration from unqualified to fully qualified // is meant to facilitate migration from unqualified to fully qualified
// calculator names. // calculator names.
constexpr char const* kTopNamespaces[] = { constexpr char const* kTopNamespaces[] = {
@ -36,7 +36,7 @@ inline size_t array_size(T (&arr)[SIZE]) {
} // namespace } // namespace
/*static*/ /*static*/
const absl::flat_hash_set<std::string>& NamespaceWhitelist::TopNamespaces() { const absl::flat_hash_set<std::string>& NamespaceAllowlist::TopNamespaces() {
static absl::flat_hash_set<std::string>* result = static absl::flat_hash_set<std::string>* result =
new absl::flat_hash_set<std::string>( new absl::flat_hash_set<std::string>(
kTopNamespaces, kTopNamespaces + array_size(kTopNamespaces)); kTopNamespaces, kTopNamespaces + array_size(kTopNamespaces));

View File

@ -144,7 +144,7 @@ struct WrapStatusOr<absl::StatusOr<T>> {
}; };
} // namespace registration_internal } // namespace registration_internal
class NamespaceWhitelist { class NamespaceAllowlist {
public: public:
static const absl::flat_hash_set<std::string>& TopNamespaces(); static const absl::flat_hash_set<std::string>& TopNamespaces();
}; };
@ -289,14 +289,14 @@ class FunctionRegistry {
mutable absl::Mutex lock_; mutable absl::Mutex lock_;
std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
// For names included in NamespaceWhitelist, strips the namespace. // For names included in NamespaceAllowlist, strips the namespace.
std::string GetAdjustedName(const std::string& name) { std::string GetAdjustedName(const std::string& name) {
constexpr auto kCxxSep = registration_internal::kCxxSep; constexpr auto kCxxSep = registration_internal::kCxxSep;
std::vector<std::string> names = absl::StrSplit(name, kCxxSep); std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
std::string base_name = names.back(); std::string base_name = names.back();
names.pop_back(); names.pop_back();
std::string ns = absl::StrJoin(names, kCxxSep); std::string ns = absl::StrJoin(names, kCxxSep);
if (NamespaceWhitelist::TopNamespaces().count(ns)) { if (NamespaceAllowlist::TopNamespaces().count(ns)) {
return base_name; return base_name;
} }
return name; return name;

View File

@ -242,7 +242,7 @@ bool ImageFrame::IsValidAlignmentNumber(uint32 alignment_boundary) {
// static // static
std::string ImageFrame::InvalidFormatString(ImageFormat::Format format) { std::string ImageFrame::InvalidFormatString(ImageFormat::Format format) {
#ifdef MEDIAPIPE_MOBILE #ifdef MEDIAPIPE_PROTO_LITE
return "Invalid format."; return "Invalid format.";
#else #else
const proto_ns::EnumValueDescriptor* enum_value_descriptor = const proto_ns::EnumValueDescriptor* enum_value_descriptor =

View File

@ -75,7 +75,6 @@ int GetMatType(const mediapipe::ImageFormat::Format format) {
} // namespace } // namespace
namespace mediapipe { namespace mediapipe {
namespace formats { namespace formats {
cv::Mat MatView(const ImageFrame* image) { cv::Mat MatView(const ImageFrame* image) {

View File

@ -75,7 +75,6 @@ int GetMatType(const mediapipe::ImageFormat::Format format) {
} }
} // namespace } // namespace
namespace mediapipe { namespace mediapipe {
namespace formats { namespace formats {
cv::Mat MatView(const mediapipe::Image* image) { cv::Mat MatView(const mediapipe::Image* image) {

View File

@ -1,3 +1,17 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_ #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_ #define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_
@ -6,7 +20,7 @@
namespace mediapipe { namespace mediapipe {
// Generates unique view id at compile-time using FILE and LINE. // Generates unique view id at compile-time using FILE and LINE.
#define TENSOR_UNIQUE_VIEW_ID() \ #define TENSOR_UNIQUE_VIEW_TYPE_ID() \
static constexpr uint64_t kId = tensor_internal::FnvHash64( \ static constexpr uint64_t kId = tensor_internal::FnvHash64( \
__FILE__, tensor_internal::FnvHash64(TENSOR_INT_TO_STRING(__LINE__))) __FILE__, tensor_internal::FnvHash64(TENSOR_INT_TO_STRING(__LINE__)))

View File

@ -20,7 +20,6 @@
#include <functional> #include <functional>
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -21,7 +21,6 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
namespace mediapipe { namespace mediapipe {
using SyncSet = InputStreamHandler::SyncSet; using SyncSet = InputStreamHandler::SyncSet;
absl::Status InputStreamHandler::InitializeInputStreamManagers( absl::Status InputStreamHandler::InitializeInputStreamManagers(

View File

@ -82,7 +82,7 @@ class InputStreamHandler {
// flat_input_stream_managers is expected to point to a contiguous // flat_input_stream_managers is expected to point to a contiguous
// flat array with InputStreamManagers corresponding to the id's in // flat array with InputStreamManagers corresponding to the id's in
// InputStreamHandler::input_stream_managers_ (meaning it should point // InputStreamHandler::input_stream_managers_ (meaning it should point
// to somewhere in the middle of the master flat array of all input // to somewhere in the middle of the main flat array of all input
// stream managers). // stream managers).
absl::Status InitializeInputStreamManagers( absl::Status InitializeInputStreamManagers(
InputStreamManager* flat_input_stream_managers); InputStreamManager* flat_input_stream_managers);

View File

@ -74,7 +74,7 @@ class OutputStreamHandler {
// flat_output_stream_managers is expected to point to a contiguous // flat_output_stream_managers is expected to point to a contiguous
// flat array with OutputStreamManagers corresponding to the id's in // flat array with OutputStreamManagers corresponding to the id's in
// OutputStreamHandler::output_stream_managers_ (meaning it should // OutputStreamHandler::output_stream_managers_ (meaning it should
// point to somewhere in the middle of the master flat array of all // point to somewhere in the middle of the main flat array of all
// output stream managers). // output stream managers).
absl::Status InitializeOutputStreamManagers( absl::Status InitializeOutputStreamManagers(
OutputStreamManager* flat_output_stream_managers); OutputStreamManager* flat_output_stream_managers);

View File

@ -363,13 +363,8 @@ class HolderBase {
HolderBase& operator=(const HolderBase&) = delete; HolderBase& operator=(const HolderBase&) = delete;
virtual ~HolderBase(); virtual ~HolderBase();
template <typename T> template <typename T>
void SetHolderTypeId() { bool PayloadIsOfType() const {
type_id_ = tool::GetTypeHash<T>(); return GetTypeId() == tool::GetTypeHash<T>();
}
size_t GetHolderTypeId() const { return type_id_; }
template <typename T>
bool HolderIsOfType() const {
return type_id_ == tool::GetTypeHash<T>();
} }
// Returns a printable std::string identifying the type stored in the holder. // Returns a printable std::string identifying the type stored in the holder.
virtual const std::string DebugTypeName() const = 0; virtual const std::string DebugTypeName() const = 0;
@ -397,8 +392,7 @@ class HolderBase {
virtual StatusOr<std::vector<const proto_ns::MessageLite*>> virtual StatusOr<std::vector<const proto_ns::MessageLite*>>
GetVectorOfProtoMessageLite() const = 0; GetVectorOfProtoMessageLite() const = 0;
private: virtual bool HasForeignOwner() const { return false; }
size_t type_id_;
}; };
// Two helper functions to get the proto base pointers. // Two helper functions to get the proto base pointers.
@ -505,7 +499,6 @@ class Holder : public HolderBase {
public: public:
explicit Holder(const T* ptr) : ptr_(ptr) { explicit Holder(const T* ptr) : ptr_(ptr) {
HolderSupport<T>::EnsureStaticInit(); HolderSupport<T>::EnsureStaticInit();
SetHolderTypeId<Holder>();
} }
~Holder() override { delete_helper(); } ~Holder() override { delete_helper(); }
const T& data() const { const T& data() const {
@ -521,9 +514,7 @@ class Holder : public HolderBase {
absl::StatusOr<std::unique_ptr<T>> Release( absl::StatusOr<std::unique_ptr<T>> Release(
typename std::enable_if<!std::is_array<U>::value || typename std::enable_if<!std::is_array<U>::value ||
std::extent<U>::value != 0>::type* = 0) { std::extent<U>::value != 0>::type* = 0) {
// Since C++ doesn't allow virtual, templated functions, check holder if (HasForeignOwner()) {
// type here to make sure it's not upcasted from a ForeignHolder.
if (!HolderIsOfType<Holder<T>>()) {
return InternalError( return InternalError(
"Foreign holder can't release data ptr without ownership."); "Foreign holder can't release data ptr without ownership.");
} }
@ -592,25 +583,19 @@ class Holder : public HolderBase {
template <typename T> template <typename T>
class ForeignHolder : public Holder<T> { class ForeignHolder : public Holder<T> {
public: public:
explicit ForeignHolder(const T* ptr) : Holder<T>(ptr) { using Holder<T>::Holder;
// Distinguishes between Holder and ForeignHolder since Consume() treats
// them differently.
this->template SetHolderTypeId<ForeignHolder>();
}
~ForeignHolder() override { ~ForeignHolder() override {
// Null out ptr_ so it doesn't get deleted by ~Holder. // Null out ptr_ so it doesn't get deleted by ~Holder.
// Note that ~Holder cannot call HasForeignOwner because the subclass's
// destructor runs first.
this->ptr_ = nullptr; this->ptr_ = nullptr;
} }
// Foreign holder can't release data pointer without ownership. bool HasForeignOwner() const final { return true; }
absl::StatusOr<std::unique_ptr<T>> Release() {
return absl::InternalError(
"Foreign holder can't release data ptr without ownership.");
}
}; };
template <typename T> template <typename T>
Holder<T>* HolderBase::As() { Holder<T>* HolderBase::As() {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) { if (PayloadIsOfType<T>()) {
return static_cast<Holder<T>*>(this); return static_cast<Holder<T>*>(this);
} }
// Does not hold a T. // Does not hold a T.
@ -619,7 +604,7 @@ Holder<T>* HolderBase::As() {
template <typename T> template <typename T>
const Holder<T>* HolderBase::As() const { const Holder<T>* HolderBase::As() const {
if (HolderIsOfType<Holder<T>>() || HolderIsOfType<ForeignHolder<T>>()) { if (PayloadIsOfType<T>()) {
return static_cast<const Holder<T>*>(this); return static_cast<const Holder<T>*>(this);
} }
// Does not hold a T. // Does not hold a T.
@ -648,7 +633,7 @@ inline absl::StatusOr<std::unique_ptr<T>> Packet::Consume() {
MP_RETURN_IF_ERROR(ValidateAsType<T>()); MP_RETURN_IF_ERROR(ValidateAsType<T>());
// Clients who use this function are responsible for ensuring that no // Clients who use this function are responsible for ensuring that no
// other thread is doing anything with this Packet. // other thread is doing anything with this Packet.
if (holder_.unique()) { if (!holder_->HasForeignOwner() && holder_.unique()) {
VLOG(2) << "Consuming the data of " << DebugString(); VLOG(2) << "Consuming the data of " << DebugString();
absl::StatusOr<std::unique_ptr<T>> release_result = absl::StatusOr<std::unique_ptr<T>> release_result =
holder_->As<T>()->Release(); holder_->As<T>()->Release();
@ -670,8 +655,7 @@ inline absl::StatusOr<std::unique_ptr<T>> Packet::ConsumeOrCopy(
typename std::enable_if<!std::is_array<T>::value>::type*) { typename std::enable_if<!std::is_array<T>::value>::type*) {
MP_RETURN_IF_ERROR(ValidateAsType<T>()); MP_RETURN_IF_ERROR(ValidateAsType<T>());
// If holder is the sole owner of the underlying data, consumes this packet. // If holder is the sole owner of the underlying data, consumes this packet.
if (!holder_->HolderIsOfType<packet_internal::ForeignHolder<T>>() && if (!holder_->HasForeignOwner() && holder_.unique()) {
holder_.unique()) {
VLOG(2) << "Consuming the data of " << DebugString(); VLOG(2) << "Consuming the data of " << DebugString();
absl::StatusOr<std::unique_ptr<T>> release_result = absl::StatusOr<std::unique_ptr<T>> release_result =
holder_->As<T>()->Release(); holder_->As<T>()->Release();
@ -701,8 +685,7 @@ inline absl::StatusOr<std::unique_ptr<T>> Packet::ConsumeOrCopy(
std::extent<T>::value != 0>::type*) { std::extent<T>::value != 0>::type*) {
MP_RETURN_IF_ERROR(ValidateAsType<T>()); MP_RETURN_IF_ERROR(ValidateAsType<T>());
// If holder is the sole owner of the underlying data, consumes this packet. // If holder is the sole owner of the underlying data, consumes this packet.
if (!holder_->HolderIsOfType<packet_internal::ForeignHolder<T>>() && if (!holder_->HasForeignOwner() && holder_.unique()) {
holder_.unique()) {
VLOG(2) << "Consuming the data of " << DebugString(); VLOG(2) << "Consuming the data of " << DebugString();
absl::StatusOr<std::unique_ptr<T>> release_result = absl::StatusOr<std::unique_ptr<T>> release_result =
holder_->As<T>()->Release(); holder_->As<T>()->Release();

View File

@ -21,7 +21,6 @@
#include <vector> #include <vector>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/packet_test.pb.h" #include "mediapipe/framework/packet_test.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
@ -374,9 +373,9 @@ TEST(PacketTest, TestConsumeForeignHolder) {
Packet packet = PointToForeign(data.get()); Packet packet = PointToForeign(data.get());
absl::StatusOr<std::unique_ptr<int>> result = packet.Consume<int>(); absl::StatusOr<std::unique_ptr<int>> result = packet.Consume<int>();
EXPECT_FALSE(result.ok()); EXPECT_FALSE(result.ok());
EXPECT_EQ(result.status().code(), absl::StatusCode::kInternal); EXPECT_EQ(result.status().code(), absl::StatusCode::kFailedPrecondition);
EXPECT_EQ(result.status().message(), EXPECT_EQ(result.status().message(),
"Foreign holder can't release data ptr without ownership."); "Packet isn't the sole owner of the holder.");
ASSERT_FALSE(packet.IsEmpty()); ASSERT_FALSE(packet.IsEmpty());
EXPECT_EQ(33, packet.Get<int>()); EXPECT_EQ(33, packet.Get<int>());
} }

View File

@ -24,6 +24,7 @@
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/collection.h" #include "mediapipe/framework/collection.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_set.h"
@ -133,7 +134,7 @@ class PacketTypeSetErrorHandler {
// Returns a usable PacketType. A different PacketType object is // Returns a usable PacketType. A different PacketType object is
// returned for each different invalid location and the same object // returned for each different invalid location and the same object
// is returned for multiple accesses to the same invalid location. // is returned for multiple accesses to the same invalid location.
PacketType& GetFallback(const std::string& tag, int index) { PacketType& GetFallback(const absl::string_view tag, int index) {
if (!missing_) { if (!missing_) {
missing_ = absl::make_unique<Missing>(); missing_ = absl::make_unique<Missing>();
} }
@ -143,7 +144,7 @@ class PacketTypeSetErrorHandler {
} }
// In the const setting produce a FATAL error. // In the const setting produce a FATAL error.
const PacketType& GetFallback(const std::string& tag, int index) const { const PacketType& GetFallback(const absl::string_view tag, int index) const {
LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index
<< ". Unable to defer error due to const specifier."; << ". Unable to defer error due to const specifier.";
std::abort(); std::abort();

View File

@ -176,6 +176,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":status_matchers", ":status_matchers",
"//mediapipe/framework/deps:message_matchers",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
], ],
) )
@ -192,6 +193,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":status_matchers", ":status_matchers",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/deps:status_matchers", "//mediapipe/framework/deps:status_matchers",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],

View File

@ -16,5 +16,6 @@
#define MEDIAPIPE_PORT_GMOCK_H_ #define MEDIAPIPE_PORT_GMOCK_H_
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "mediapipe/framework/deps/message_matchers.h"
#endif // MEDIAPIPE_PORT_GMOCK_H_ #endif // MEDIAPIPE_PORT_GMOCK_H_

View File

@ -234,7 +234,6 @@ cc_test(
"//mediapipe/framework:calculator_profile_cc_proto", "//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:test_calculators", "//mediapipe/framework:test_calculators",
"//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:clock",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",

View File

@ -28,7 +28,6 @@
#include "mediapipe/framework/tool/simulation_clock.h" #include "mediapipe/framework/tool/simulation_clock.h"
#include "mediapipe/framework/tool/tag_map_helper.h" #include "mediapipe/framework/tool/tag_map_helper.h"
using ::testing::EqualsProto;
using ::testing::proto::Partially; using ::testing::proto::Partially;
namespace mediapipe { namespace mediapipe {

View File

@ -27,7 +27,6 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_profile.pb.h" #include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"

View File

@ -85,7 +85,7 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.", {TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
true, true, false}, true, true, false},
}; };
for (TraceEventType t : basic_types) { for (const TraceEventType& t : basic_types) {
(*result)[t.event_type()] = t; (*result)[t.event_type()] = t;
} }
} }

View File

@ -162,7 +162,6 @@ cc_test(
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [ deps = [
":executor_util", ":executor_util",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
], ],
@ -302,7 +301,6 @@ mediapipe_cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:validated_graph_config", "//mediapipe/framework:validated_graph_config",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -449,6 +447,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -613,7 +612,6 @@ cc_test(
deps = [ deps = [
":validate_name", ":validate_name",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -736,7 +734,6 @@ cc_test(
"//mediapipe/framework:packet_type", "//mediapipe/framework:packet_type",
"//mediapipe/framework:status_handler", "//mediapipe/framework:status_handler",
"//mediapipe/framework:subgraph", "//mediapipe/framework:subgraph",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
@ -885,7 +882,6 @@ cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:subgraph", "//mediapipe/framework:subgraph",
"//mediapipe/framework:test_calculators", "//mediapipe/framework:test_calculators",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",

View File

@ -14,7 +14,6 @@
#include "mediapipe/framework/tool/executor_util.h" #include "mediapipe/framework/tool/executor_util.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -17,7 +17,7 @@
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"

View File

@ -18,7 +18,6 @@
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_set.h"

View File

@ -15,7 +15,6 @@
#include "absl/strings/str_replace.h" #include "absl/strings/str_replace.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"

View File

@ -215,19 +215,16 @@ std::string TagMap::ShortDebugString() const {
return output; return output;
} }
bool TagMap::HasTag(const std::string& tag) const { bool TagMap::HasTag(const absl::string_view tag) const {
return mapping_.find(tag) != mapping_.end(); return mapping_.contains(tag);
} }
int TagMap::NumEntries(const std::string& tag) const { int TagMap::NumEntries(const absl::string_view tag) const {
const auto it = mapping_.find(tag); const auto it = mapping_.find(tag);
if (it == mapping_.end()) { return it != mapping_.end() ? it->second.count : 0;
return 0;
}
return it->second.count;
} }
CollectionItemId TagMap::GetId(const std::string& tag, int index) const { CollectionItemId TagMap::GetId(const absl::string_view tag, int index) const {
const auto it = mapping_.find(tag); const auto it = mapping_.find(tag);
if (it == mapping_.end()) { if (it == mapping_.end()) {
return CollectionItemId::GetInvalid(); return CollectionItemId::GetInvalid();
@ -248,11 +245,11 @@ std::pair<std::string, int> TagMap::TagAndIndexFromId(
return {"", -1}; return {"", -1};
} }
CollectionItemId TagMap::BeginId(const std::string& tag) const { CollectionItemId TagMap::BeginId(const absl::string_view tag) const {
return GetId(tag, 0); return GetId(tag, 0);
} }
CollectionItemId TagMap::EndId(const std::string& tag) const { CollectionItemId TagMap::EndId(const absl::string_view tag) const {
const auto it = mapping_.find(tag); const auto it = mapping_.find(tag);
if (it == mapping_.end()) { if (it == mapping_.end()) {
return CollectionItemId::GetInvalid(); return CollectionItemId::GetInvalid();

View File

@ -20,6 +20,8 @@
#include <vector> #include <vector>
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/container/btree_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/core_proto_inc.h"
@ -72,7 +74,9 @@ class TagMap {
} }
// Returns a reference to the mapping from tag to tag data. // Returns a reference to the mapping from tag to tag data.
const std::map<std::string, TagData>& Mapping() const { return mapping_; } const absl::btree_map<std::string, TagData>& Mapping() const {
return mapping_;
}
// Returns the vector of names (indexed by CollectionItemId). // Returns the vector of names (indexed by CollectionItemId).
const std::vector<std::string>& Names() const { return names_; } const std::vector<std::string>& Names() const { return names_; }
@ -91,16 +95,16 @@ class TagMap {
// The following functions are directly utilized by collection.h see // The following functions are directly utilized by collection.h see
// that file for comments. // that file for comments.
bool HasTag(const std::string& tag) const; bool HasTag(absl::string_view tag) const;
int NumEntries() const { return num_entries_; } int NumEntries() const { return num_entries_; }
int NumEntries(const std::string& tag) const; int NumEntries(absl::string_view tag) const;
CollectionItemId GetId(const std::string& tag, int index) const; CollectionItemId GetId(absl::string_view tag, int index) const;
std::set<std::string> GetTags() const; std::set<std::string> GetTags() const;
std::pair<std::string, int> TagAndIndexFromId(CollectionItemId id) const; std::pair<std::string, int> TagAndIndexFromId(CollectionItemId id) const;
CollectionItemId BeginId() const { return CollectionItemId(0); } CollectionItemId BeginId() const { return CollectionItemId(0); }
CollectionItemId EndId() const { return CollectionItemId(num_entries_); } CollectionItemId EndId() const { return CollectionItemId(num_entries_); }
CollectionItemId BeginId(const std::string& tag) const; CollectionItemId BeginId(absl::string_view tag) const;
CollectionItemId EndId(const std::string& tag) const; CollectionItemId EndId(absl::string_view tag) const;
private: private:
// Use static factory function TagMap::Create(). // Use static factory function TagMap::Create().
@ -122,7 +126,7 @@ class TagMap {
// The total number of entries under all tags. // The total number of entries under all tags.
int num_entries_; int num_entries_;
// Mapping from tag to tag data. // Mapping from tag to tag data.
std::map<std::string, TagData> mapping_; absl::btree_map<std::string, TagData> mapping_;
// The names of the data (indexed by CollectionItemId). // The names of the data (indexed by CollectionItemId).
std::vector<std::string> names_; std::vector<std::string> names_;
}; };

View File

@ -37,7 +37,6 @@
#include "mediapipe/framework/tool/proto_util_lite.h" #include "mediapipe/framework/tool/proto_util_lite.h"
using mediapipe::proto_ns::Descriptor; using mediapipe::proto_ns::Descriptor;
using mediapipe::proto_ns::DescriptorPool;
using mediapipe::proto_ns::DynamicMessageFactory; using mediapipe::proto_ns::DynamicMessageFactory;
using mediapipe::proto_ns::EnumDescriptor; using mediapipe::proto_ns::EnumDescriptor;
using mediapipe::proto_ns::EnumValueDescriptor; using mediapipe::proto_ns::EnumValueDescriptor;
@ -1666,7 +1665,6 @@ TemplateParser::Parser::Parser()
allow_partial_(false), allow_partial_(false),
allow_case_insensitive_field_(false), allow_case_insensitive_field_(false),
allow_unknown_field_(false), allow_unknown_field_(false),
allow_unknown_extension_(true),
allow_unknown_enum_(false), allow_unknown_enum_(false),
allow_field_number_(false), allow_field_number_(false),
allow_relaxed_whitespace_(false), allow_relaxed_whitespace_(false),
@ -1685,10 +1683,11 @@ bool TemplateParser::Parser::Parse(io::ZeroCopyInputStream* input,
: ParserImpl::FORBID_SINGULAR_OVERWRITES; : ParserImpl::FORBID_SINGULAR_OVERWRITES;
int recursion_limit = std::numeric_limits<int>::max(); int recursion_limit = std::numeric_limits<int>::max();
bool allow_unknown_extension = false;
MediaPipeParserImpl parser( MediaPipeParserImpl parser(
output->GetDescriptor(), input, error_collector_, finder_, output->GetDescriptor(), input, error_collector_, finder_,
parse_info_tree_, overwrites_policy, allow_case_insensitive_field_, parse_info_tree_, overwrites_policy, allow_case_insensitive_field_,
allow_unknown_field_, allow_unknown_extension_, allow_unknown_enum_, allow_unknown_field_, allow_unknown_extension, allow_unknown_enum_,
allow_field_number_, allow_relaxed_whitespace_, allow_partial_, allow_field_number_, allow_relaxed_whitespace_, allow_partial_,
recursion_limit); recursion_limit);
return MergeUsingImpl(input, output, &parser); return MergeUsingImpl(input, output, &parser);
@ -1703,11 +1702,12 @@ bool TemplateParser::Parser::ParseFromString(const std::string& input,
bool TemplateParser::Parser::Merge(io::ZeroCopyInputStream* input, bool TemplateParser::Parser::Merge(io::ZeroCopyInputStream* input,
Message* output) { Message* output) {
int recursion_limit = std::numeric_limits<int>::max(); int recursion_limit = std::numeric_limits<int>::max();
bool allow_unknown_extension = false;
MediaPipeParserImpl parser( MediaPipeParserImpl parser(
output->GetDescriptor(), input, error_collector_, finder_, output->GetDescriptor(), input, error_collector_, finder_,
parse_info_tree_, ParserImpl::ALLOW_SINGULAR_OVERWRITES, parse_info_tree_, ParserImpl::ALLOW_SINGULAR_OVERWRITES,
allow_case_insensitive_field_, allow_unknown_field_, allow_case_insensitive_field_, allow_unknown_field_,
allow_unknown_extension_, allow_unknown_enum_, allow_field_number_, allow_unknown_extension, allow_unknown_enum_, allow_field_number_,
allow_relaxed_whitespace_, allow_partial_, recursion_limit); allow_relaxed_whitespace_, allow_partial_, recursion_limit);
return MergeUsingImpl(input, output, &parser); return MergeUsingImpl(input, output, &parser);
} }
@ -1737,11 +1737,12 @@ bool TemplateParser::Parser::ParseFieldValueFromString(
const std::string& input, const FieldDescriptor* field, Message* output) { const std::string& input, const FieldDescriptor* field, Message* output) {
io::ArrayInputStream input_stream(input.data(), input.size()); io::ArrayInputStream input_stream(input.data(), input.size());
int recursion_limit = std::numeric_limits<int>::max(); int recursion_limit = std::numeric_limits<int>::max();
bool allow_unknown_extension = false;
ParserImpl parser( ParserImpl parser(
output->GetDescriptor(), &input_stream, error_collector_, finder_, output->GetDescriptor(), &input_stream, error_collector_, finder_,
parse_info_tree_, ParserImpl::ALLOW_SINGULAR_OVERWRITES, parse_info_tree_, ParserImpl::ALLOW_SINGULAR_OVERWRITES,
allow_case_insensitive_field_, allow_unknown_field_, allow_case_insensitive_field_, allow_unknown_field_,
allow_unknown_extension_, allow_unknown_enum_, allow_field_number_, allow_unknown_extension, allow_unknown_enum_, allow_field_number_,
allow_relaxed_whitespace_, allow_partial_, recursion_limit); allow_relaxed_whitespace_, allow_partial_, recursion_limit);
return parser.ParseField(field, output); return parser.ParseField(field, output);
} }

View File

@ -37,10 +37,6 @@ class TemplateParser {
Parser(); Parser();
~Parser(); ~Parser();
void set_allow_unknown_extension(bool allow_unknown_extension) {
allow_unknown_extension_ = allow_unknown_extension;
}
// Like TextFormat::Parse(). // Like TextFormat::Parse().
bool Parse(proto_ns::io::ZeroCopyInputStream* input, bool Parse(proto_ns::io::ZeroCopyInputStream* input,
proto_ns::Message* output); proto_ns::Message* output);
@ -103,7 +99,6 @@ class TemplateParser {
bool allow_partial_; bool allow_partial_;
bool allow_case_insensitive_field_; bool allow_case_insensitive_field_;
bool allow_unknown_field_; bool allow_unknown_field_;
bool allow_unknown_extension_;
bool allow_unknown_enum_; bool allow_unknown_enum_;
bool allow_field_number_; bool allow_field_number_;
bool allow_relaxed_whitespace_; bool allow_relaxed_whitespace_;

View File

@ -56,7 +56,7 @@ absl::Status GetTagAndNameInfo(
} }
info->names.push_back(name); info->names.push_back(name);
} }
if (info->tags.size() > 0 && info->names.size() != info->tags.size()) { if (!info->tags.empty() && info->names.size() != info->tags.size()) {
info->tags.clear(); info->tags.clear();
info->names.clear(); info->names.clear();
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(

View File

@ -17,7 +17,6 @@
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"

View File

@ -96,7 +96,7 @@ class NodeTypeInfo {
// Get the input/output side packet/stream index that is the first // Get the input/output side packet/stream index that is the first
// for the PacketTypeSets. Subsequent id's in the collection are // for the PacketTypeSets. Subsequent id's in the collection are
// guaranteed to be contiguous in the master flat array. // guaranteed to be contiguous in the main flat array.
int InputSidePacketBaseIndex() const { return input_side_packet_base_index_; } int InputSidePacketBaseIndex() const { return input_side_packet_base_index_; }
int OutputSidePacketBaseIndex() const { int OutputSidePacketBaseIndex() const {
return output_side_packet_base_index_; return output_side_packet_base_index_;
@ -154,7 +154,7 @@ class NodeTypeInfo {
CalculatorContract contract_; CalculatorContract contract_;
// The base indexes of the first entry belonging to this node in // The base indexes of the first entry belonging to this node in
// the master flat arrays of ValidatedGraphConfig. Subsequent // the main flat arrays of ValidatedGraphConfig. Subsequent
// entries are guaranteed to be sequential and in the order of the // entries are guaranteed to be sequential and in the order of the
// CollectionItemIds. // CollectionItemIds.
// Example: // Example:

View File

@ -9,8 +9,8 @@
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/graph_service.h" #include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"

View File

@ -161,15 +161,18 @@ void GlTextureBuffer::Reuse() {
// sync fences; with a single-threaded executor, that means switching to // sync fences; with a single-threaded executor, that means switching to
// each of those contexts, grabbing its mutex. Let's do that after releasing // each of those contexts, grabbing its mutex. Let's do that after releasing
// our own mutex. // our own mutex.
// Likewise, if we don't have sync fences and are simulating them, WaitOnGpu
// will also require invoking the consumer context, so we should not call it
// while holding the mutex.
std::unique_ptr<GlMultiSyncPoint> old_consumer_sync; std::unique_ptr<GlMultiSyncPoint> old_consumer_sync;
{ {
absl::MutexLock lock(&consumer_sync_mutex_); absl::MutexLock lock(&consumer_sync_mutex_);
consumer_multi_sync_->WaitOnGpu();
// Reset the sync points. // Reset the sync points.
old_consumer_sync = std::move(consumer_multi_sync_); old_consumer_sync = std::move(consumer_multi_sync_);
consumer_multi_sync_ = absl::make_unique<GlMultiSyncPoint>(); consumer_multi_sync_ = absl::make_unique<GlMultiSyncPoint>();
producer_sync_ = nullptr; producer_sync_ = nullptr;
} }
old_consumer_sync->WaitOnGpu();
} }
void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) { void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {

View File

@ -112,7 +112,7 @@ absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) {
#ifndef __EMSCRIPTEN__ #ifndef __EMSCRIPTEN__
// TODO Allow calculators to request a separate context. // TODO Allow calculators to request a separate context.
// For now, white-list a few calculators to run in their own context. // For now, allow a few calculators to run in their own context.
bool gets_own_context = (node_type == "ImageFrameToGpuBufferCalculator") || bool gets_own_context = (node_type == "ImageFrameToGpuBufferCalculator") ||
(node_type == "GpuBufferToImageFrameCalculator") || (node_type == "GpuBufferToImageFrameCalculator") ||
(node_type == "GlSurfaceSinkCalculator"); (node_type == "GlSurfaceSinkCalculator");

View File

@ -154,7 +154,7 @@ EOF
"//third_party:camerax_core", "//third_party:camerax_core",
"//third_party:camerax_camera2", "//third_party:camerax_camera2",
"//third_party:camerax_lifecycle", "//third_party:camerax_lifecycle",
"@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_javalite",
"@maven//:com_google_code_findbugs_jsr305", "@maven//:com_google_code_findbugs_jsr305",
"@maven//:com_google_flogger_flogger", "@maven//:com_google_flogger_flogger",
"@maven//:com_google_flogger_flogger_system_backend", "@maven//:com_google_flogger_flogger_system_backend",
@ -252,7 +252,7 @@ def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
native.genrule( native.genrule(
name = name + "_proto_java_src_generator", name = name + "_proto_java_src_generator",
srcs = srcs + [ srcs = srcs + [
"@com_google_protobuf//:well_known_protos", "@com_google_protobuf//:lite_well_known_protos",
], ],
outs = [java_lite_out], outs = [java_lite_out],
cmd = "$(location @com_google_protobuf//:protoc) " + cmd = "$(location @com_google_protobuf//:protoc) " +

View File

@ -38,7 +38,6 @@ constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kLabelsTag[] = "LABELS"; constexpr char kLabelsTag[] = "LABELS";
constexpr char kLabelsCsvTag[] = "LABELS_CSV"; constexpr char kLabelsCsvTag[] = "LABELS_CSV";
using mediapipe::ContainsKey;
using mediapipe::RE2; using mediapipe::RE2;
using Detections = std::vector<Detection>; using Detections = std::vector<Detection>;
using Strings = std::vector<std::string>; using Strings = std::vector<std::string>;
@ -161,24 +160,24 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) || limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
cc->InputSidePackets().HasTag(kLabelsCsvTag); cc->InputSidePackets().HasTag(kLabelsCsvTag);
if (limit_labels_) { if (limit_labels_) {
Strings whitelist_labels; Strings allowlist_labels;
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
whitelist_labels = absl::StrSplit( allowlist_labels = absl::StrSplit(
cc->InputSidePackets().Tag(kLabelsCsvTag).Get<std::string>(), ',', cc->InputSidePackets().Tag(kLabelsCsvTag).Get<std::string>(), ',',
absl::SkipWhitespace()); absl::SkipWhitespace());
for (auto& e : whitelist_labels) { for (auto& e : allowlist_labels) {
absl::StripAsciiWhitespace(&e); absl::StripAsciiWhitespace(&e);
} }
} else { } else {
whitelist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>(); allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
} }
allowed_labels_.insert(whitelist_labels.begin(), whitelist_labels.end()); allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
} }
if (limit_labels_ && allowed_labels_.empty()) { if (limit_labels_ && allowed_labels_.empty()) {
if (options_.fail_on_empty_labels()) { if (options_.fail_on_empty_labels()) {
cc->GetCounter("VideosWithEmptyLabelsWhitelist")->Increment(); cc->GetCounter("VideosWithEmptyLabelsAllowlist")->Increment();
return tool::StatusFail( return tool::StatusFail(
"FilterDetectionCalculator received empty whitelist with " "FilterDetectionCalculator received empty allowlist with "
"fail_on_empty_labels = true."); "fail_on_empty_labels = true.");
} }
if (options_.empty_allowed_labels_means_allow_everything()) { if (options_.empty_allowed_labels_means_allow_everything()) {

View File

@ -211,6 +211,33 @@ class GraphTest(absltest.TestCase):
self.assertEqual( self.assertEqual(
mp.packet_getter.get_uint(graph.get_output_side_packet('number')), 42) mp.packet_getter.get_uint(graph.get_output_side_packet('number')), 42)
def test_sequence_input(self):
text_config = """
max_queue_size: 1
input_stream: 'in'
output_stream: 'out'
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
"""
hello_world_packet = mp.packet_creator.create_string('hello world')
out = []
graph = mp.CalculatorGraph(graph_config=text_config)
graph.observe_output_stream('out', lambda _, packet: out.append(packet))
graph.start_run()
sequence_size = 1000
for i in range(sequence_size):
graph.add_packet_to_input_stream(
stream='in', packet=hello_world_packet, timestamp=i)
graph.wait_until_idle()
self.assertLen(out, sequence_size)
for i in range(sequence_size):
self.assertEqual(out[i].timestamp, i)
self.assertEqual(mp.packet_getter.get_str(out[i]), 'hello world')
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -121,6 +121,7 @@ pybind_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/status:statusor",
], ],
) )

View File

@ -165,8 +165,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
" can't be the timestamp of a Packet in a stream.") " can't be the timestamp of a Packet in a stream.")
.c_str()); .c_str());
} }
py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk( RaisePyErrorIfNotOk(
self->AddPacketToInputStream(stream, packet.At(packet_timestamp))); self->AddPacketToInputStream(stream, packet.At(packet_timestamp)),
/**acquire_gil=*/true);
}, },
R"doc(Add a packet to a graph input stream. R"doc(Add a packet to a graph input stream.
@ -347,7 +349,9 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
calculator_graph.def( calculator_graph.def(
"wait_for_observed_output", "wait_for_observed_output",
[](CalculatorGraph* self) { [](CalculatorGraph* self) {
RaisePyErrorIfNotOk(self->WaitForObservedOutput()); py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitForObservedOutput(),
/**acquire_gil=*/true);
}, },
R"doc(Wait until a packet is emitted on one of the observed output streams. R"doc(Wait until a packet is emitted on one of the observed output streams.

View File

@ -14,6 +14,7 @@
#include "mediapipe/python/pybind/packet_getter.h" #include "mediapipe/python/pybind/packet_getter.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"

View File

@ -14,6 +14,7 @@
"""Tests for mediapipe.python.solutions.hands.""" """Tests for mediapipe.python.solutions.hands."""
import json
import os import os
import tempfile # pylint: disable=unused-import import tempfile # pylint: disable=unused-import
from typing import NamedTuple from typing import NamedTuple
@ -52,6 +53,21 @@ EXPECTED_HAND_COORDINATES_PREDICTION = [[[580, 34], [504, 50], [459, 94],
class HandsTest(parameterized.TestCase): class HandsTest(parameterized.TestCase):
def _get_output_path(self, name):
return os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + name)
def _landmarks_list_to_array(self, landmark_list, image_shape):
rows, cols, _ = image_shape
return np.asarray([(lmk.x * cols, lmk.y * rows, lmk.z * cols)
for lmk in landmark_list.landmark])
def _world_landmarks_list_to_array(self, landmark_list):
return np.asarray([(lmk.x, lmk.y, lmk.z)
for lmk in landmark_list.landmark])
def _assert_diff_less(self, array1, array2, threshold):
npt.assert_array_less(np.abs(array1 - array2), threshold)
def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int): def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int):
for hand_landmarks in results.multi_hand_landmarks: for hand_landmarks in results.multi_hand_landmarks:
mp_drawing.draw_landmarks( mp_drawing.draw_landmarks(
@ -112,6 +128,91 @@ class HandsTest(parameterized.TestCase):
diff_threshold = LITE_MODEL_DIFF_THRESHOLD if model_complexity == 0 else FULL_MODEL_DIFF_THRESHOLD diff_threshold = LITE_MODEL_DIFF_THRESHOLD if model_complexity == 0 else FULL_MODEL_DIFF_THRESHOLD
npt.assert_array_less(prediction_error, diff_threshold) npt.assert_array_less(prediction_error, diff_threshold)
def _process_video(self, model_complexity, video_path,
max_num_hands=1,
num_landmarks=21,
num_dimensions=3):
# Predict pose landmarks for each frame.
video_cap = cv2.VideoCapture(video_path)
landmarks_per_frame = []
w_landmarks_per_frame = []
with mp_hands.Hands(
static_image_mode=False,
max_num_hands=max_num_hands,
model_complexity=model_complexity,
min_detection_confidence=0.5) as hands:
while True:
# Get next frame of the video.
success, input_frame = video_cap.read()
if not success:
break
# Run pose tracker.
input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
frame_shape = input_frame.shape
result = hands.process(image=input_frame)
frame_landmarks = np.zeros([max_num_hands,
num_landmarks, num_dimensions]) * np.nan
frame_w_landmarks = np.zeros([max_num_hands,
num_landmarks, num_dimensions]) * np.nan
if result.multi_hand_landmarks:
for idx, landmarks in enumerate(result.multi_hand_landmarks):
landmarks = self._landmarks_list_to_array(landmarks, frame_shape)
frame_landmarks[idx] = landmarks
if result.multi_hand_world_landmarks:
for idx, w_landmarks in enumerate(result.multi_hand_world_landmarks):
w_landmarks = self._world_landmarks_list_to_array(w_landmarks)
frame_w_landmarks[idx] = w_landmarks
landmarks_per_frame.append(frame_landmarks)
w_landmarks_per_frame.append(frame_w_landmarks)
return (np.array(landmarks_per_frame), np.array(w_landmarks_per_frame))
@parameterized.named_parameters(
('full', 1, 'asl_hand.full.npz'))
def test_on_video(self, model_complexity, expected_name):
"""Tests hand models on a video."""
# Set threshold for comparing actual and expected predictions in pixels.
diff_threshold = 18
world_diff_threshold = 0.05
video_path = os.path.join(os.path.dirname(__file__),
'testdata/asl_hand.25fps.mp4')
expected_path = os.path.join(os.path.dirname(__file__),
'testdata/{}'.format(expected_name))
actual, actual_world = self._process_video(model_complexity, video_path)
# Dump actual .npz.
npz_path = self._get_output_path(expected_name)
np.savez(npz_path, predictions=actual, w_predictions=actual_world)
# Dump actual JSON.
json_path = self._get_output_path(expected_name.replace('.npz', '.json'))
with open(json_path, 'w') as fl:
dump_data = {
'predictions': np.around(actual, 3).tolist(),
'predictions_world': np.around(actual_world, 3).tolist()
}
fl.write(json.dumps(dump_data, indent=2, separators=(',', ': ')))
# Validate actual vs. expected landmarks.
expected = np.load(expected_path)['predictions']
assert actual.shape == expected.shape, (
'Unexpected shape of predictions: {} instead of {}'.format(
actual.shape, expected.shape))
self._assert_diff_less(
actual[..., :2], expected[..., :2], threshold=diff_threshold)
# Validate actual vs. expected world landmarks.
expected_world = np.load(expected_path)['w_predictions']
assert actual_world.shape == expected_world.shape, (
'Unexpected shape of world predictions: {} instead of {}'.format(
actual_world.shape, expected_world.shape))
self._assert_diff_less(
actual_world, expected_world, threshold=world_diff_threshold)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -269,7 +269,6 @@ cc_library(
":time_series_util", ":time_series_util",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
@ -289,7 +288,6 @@ cc_test(
deps = [ deps = [
":time_series_util", ":time_series_util",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",

View File

@ -27,6 +27,7 @@
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
@ -45,8 +46,7 @@ namespace {
constexpr uint32 kBufferLength = 64; constexpr uint32 kBufferLength = 64;
absl::StatusOr<std::string> GetFilePath(int cpu) { absl::StatusOr<std::string> GetFilePath(int cpu) {
if (absl::GetFlag(FLAGS_system_cpu_max_freq_file).find("$0") == if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
std::string::npos) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Invalid frequency file: ", absl::StrCat("Invalid frequency file: ",
absl::GetFlag(FLAGS_system_cpu_max_freq_file))); absl::GetFlag(FLAGS_system_cpu_max_freq_file)));

View File

@ -25,7 +25,6 @@
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"

View File

@ -16,7 +16,6 @@
#include "Eigen/Core" #include "Eigen/Core"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"

View File

@ -716,7 +716,6 @@ cc_test(
deps = [ deps = [
":motion_estimation", ":motion_estimation",
":motion_models", ":motion_models",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",

View File

@ -501,7 +501,7 @@ bool BoxTracker::GetTimedPosition(int id, int64 time_msec, TimedBox* result,
absl::MutexLock lock(&path_mutex_); absl::MutexLock lock(&path_mutex_);
const Path& path = paths_[id]; const Path& path = paths_[id];
if (path.size() < 1) { if (path.empty()) {
LOG(ERROR) << "Empty path!"; LOG(ERROR) << "Empty path!";
return false; return false;
} }

View File

@ -14,7 +14,6 @@
#include "mediapipe/util/tracking/motion_models.h" #include "mediapipe/util/tracking/motion_models.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"

View File

@ -4,5 +4,3 @@ matplotlib
numpy numpy
opencv-contrib-python opencv-contrib-python
protobuf>=3.11.4 protobuf>=3.11.4
six
wheel

125
setup.py
View File

@ -26,28 +26,22 @@ import sys
import setuptools import setuptools
import setuptools.command.build_ext as build_ext import setuptools.command.build_ext as build_ext
import setuptools.command.build_py as build_py
import setuptools.command.install as install import setuptools.command.install as install
# It is recommended to import setuptools prior to importing distutils to avoid
# using legacy behavior from distutils.
from distutils import spawn
import distutils.command.build as build
import distutils.command.clean as clean
__version__ = '0.8' __version__ = 'dev'
IS_WINDOWS = (platform.system() == 'Windows') IS_WINDOWS = (platform.system() == 'Windows')
MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
ROOT_INIT_PY = os.path.join(MP_ROOT_PATH, '__init__.py')
MP_DIR_INIT_PY = os.path.join(MP_ROOT_PATH, 'mediapipe/__init__.py') MP_DIR_INIT_PY = os.path.join(MP_ROOT_PATH, 'mediapipe/__init__.py')
MP_THIRD_PARTY_BUILD = os.path.join(MP_ROOT_PATH, 'third_party/BUILD') MP_THIRD_PARTY_BUILD = os.path.join(MP_ROOT_PATH, 'third_party/BUILD')
SUBDIR_INIT_PY_FILES = [ DIR_INIT_PY_FILES = [
os.path.join(MP_ROOT_PATH, '__init__.py'),
os.path.join(MP_ROOT_PATH, 'mediapipe/calculators/__init__.py'), os.path.join(MP_ROOT_PATH, 'mediapipe/calculators/__init__.py'),
os.path.join(MP_ROOT_PATH, 'mediapipe/modules/__init__.py'), os.path.join(MP_ROOT_PATH, 'mediapipe/modules/__init__.py'),
os.path.join(MP_ROOT_PATH, os.path.join(MP_ROOT_PATH,
'mediapipe/modules/holistic_landmark/__init__.py'), 'mediapipe/modules/holistic_landmark/__init__.py'),
os.path.join(MP_ROOT_PATH, 'mediapipe/modules/objectron/__init__.py') os.path.join(MP_ROOT_PATH, 'mediapipe/modules/objectron/__init__.py')
] ]
if not os.path.exists(ROOT_INIT_PY):
open(ROOT_INIT_PY, 'w').close()
def _normalize_path(path): def _normalize_path(path):
@ -79,7 +73,7 @@ def _get_long_description():
def _check_bazel(): def _check_bazel():
"""Check Bazel binary as well as its version.""" """Check Bazel binary as well as its version."""
if not spawn.find_executable('bazel'): if not shutil.which('bazel'):
sys.stderr.write('could not find bazel executable. Please install bazel to' sys.stderr.write('could not find bazel executable. Please install bazel to'
'build the MediaPipe Python package.') 'build the MediaPipe Python package.')
sys.exit(-1) sys.exit(-1)
@ -126,28 +120,6 @@ def _modify_opencv_cmake_rule(link_opencv):
build_file.close() build_file.close()
class ModifyInitFiles(setuptools.Command):
"""Modify the init files for building MediaPipe Python package."""
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
# Save the original init file.
shutil.copyfile(MP_DIR_INIT_PY, _get_backup_file(MP_DIR_INIT_PY))
mp_dir_init_file = open(MP_DIR_INIT_PY, 'a')
mp_dir_init_file.writelines(
['\n', 'from mediapipe.python import *\n',
'import mediapipe.python.solutions as solutions',
'\n'])
mp_dir_init_file.close()
class GeneratePyProtos(setuptools.Command): class GeneratePyProtos(setuptools.Command):
"""Generate MediaPipe Python protobuf files by Protocol Compiler.""" """Generate MediaPipe Python protobuf files by Protocol Compiler."""
@ -163,18 +135,14 @@ class GeneratePyProtos(setuptools.Command):
if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']): if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']):
self._protoc = os.environ['PROTOC'] self._protoc = os.environ['PROTOC']
else: else:
self._protoc = spawn.find_executable('protoc') self._protoc = shutil.which('protoc')
if self._protoc is None: if self._protoc is None:
sys.stderr.write( sys.stderr.write(
'protoc is not found. Please run \'apt install -y protobuf' 'protoc is not found. Please run \'apt install -y protobuf'
'-compiler\' (linux) or \'brew install protobuf\'(macos) to install ' '-compiler\' (linux) or \'brew install protobuf\'(macos) to install '
'protobuf compiler binary.') 'protobuf compiler binary.')
sys.exit(-1) sys.exit(-1)
# Add __init__.py to make the generated py proto files visiable. self._modify_inits()
for init_py in SUBDIR_INIT_PY_FILES:
if not os.path.exists(init_py):
sys.stderr.write('adding __init__ file: %s\n' % init_py)
open(init_py, 'w').close()
# Build framework and calculator protos. # Build framework and calculator protos.
for pattern in [ for pattern in [
'mediapipe/framework/**/*.proto', 'mediapipe/calculators/**/*.proto', 'mediapipe/framework/**/*.proto', 'mediapipe/calculators/**/*.proto',
@ -198,6 +166,21 @@ class GeneratePyProtos(setuptools.Command):
open(init_py, 'w').close() open(init_py, 'w').close()
self._generate_proto(proto_file) self._generate_proto(proto_file)
def _modify_inits(self):
# Add __init__.py to make the dirs indexable.
for init_py in DIR_INIT_PY_FILES:
if not os.path.exists(init_py):
sys.stderr.write('adding __init__ file: %s\n' % init_py)
open(init_py, 'w').close()
# Save the original init file.
shutil.copyfile(MP_DIR_INIT_PY, _get_backup_file(MP_DIR_INIT_PY))
mp_dir_init_file = open(MP_DIR_INIT_PY, 'a')
mp_dir_init_file.writelines(
['\n', 'from mediapipe.python import *\n',
'import mediapipe.python.solutions as solutions',
'\n'])
mp_dir_init_file.close()
def _generate_proto(self, source): def _generate_proto(self, source):
"""Invokes the Protocol Compiler to generate a _pb2.py.""" """Invokes the Protocol Compiler to generate a _pb2.py."""
@ -216,8 +199,20 @@ class GeneratePyProtos(setuptools.Command):
sys.exit(-1) sys.exit(-1)
class BuildBinaryGraphs(build.build): class BuildBinaryGraphs(build_ext.build_ext):
"""Build binary graphs for Python examples.""" """Build MediaPipe solution binary graphs."""
user_options = build_ext.build_ext.user_options + [
('link-opencv', None, 'if true, build opencv from source.'),
]
boolean_options = build_ext.build_ext.boolean_options + ['link-opencv']
def initialize_options(self):
self.link_opencv = False
build_ext.build_ext.initialize_options(self)
def finalize_options(self):
build_ext.build_ext.finalize_options(self)
def run(self): def run(self):
_check_bazel() _check_bazel()
@ -271,7 +266,7 @@ class BazelExtension(setuptools.Extension):
setuptools.Extension.__init__(self, ext_name, sources=[]) setuptools.Extension.__init__(self, ext_name, sources=[])
class BuildBazelExtension(build_ext.build_ext): class BuildExtension(build_ext.build_ext):
"""A command that runs Bazel to build a C/C++ extension.""" """A command that runs Bazel to build a C/C++ extension."""
user_options = build_ext.build_ext.user_options + [ user_options = build_ext.build_ext.user_options + [
@ -289,10 +284,10 @@ class BuildBazelExtension(build_ext.build_ext):
def run(self): def run(self):
_check_bazel() _check_bazel()
for ext in self.extensions: for ext in self.extensions:
self.bazel_build(ext) self._build_binary(ext)
build_ext.build_ext.run(self) build_ext.build_ext.run(self)
def bazel_build(self, ext): def _build_binary(self, ext):
if not os.path.exists(self.build_temp): if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp) os.makedirs(self.build_temp)
bazel_command = [ bazel_command = [
@ -306,7 +301,8 @@ class BuildBazelExtension(build_ext.build_ext):
] ]
if not self.link_opencv and not IS_WINDOWS: if not self.link_opencv and not IS_WINDOWS:
bazel_command.append('--define=OPENCV=source') bazel_command.append('--define=OPENCV=source')
self.spawn(bazel_command) if subprocess.call(bazel_command) != 0:
sys.exit(-1)
ext_bazel_bin_path = os.path.join('bazel-bin', ext.relpath, ext_bazel_bin_path = os.path.join('bazel-bin', ext.relpath,
ext.target_name + '.so') ext.target_name + '.so')
ext_dest_path = self.get_ext_fullpath(ext.name) ext_dest_path = self.get_ext_fullpath(ext.name)
@ -320,20 +316,20 @@ class BuildBazelExtension(build_ext.build_ext):
shutil.copy(opencv_dll, ext_dest_dir) shutil.copy(opencv_dll, ext_dest_dir)
class Build(build.build): class BuildPy(build_py.build_py):
"""Build command that builds binary graphs and extension and does a cleanup afterwards.""" """Build command that generates protos, builds binary graphs and extension, builds python source, and performs a cleanup afterwards."""
user_options = build.build.user_options + [ user_options = build_py.build_py.user_options + [
('link-opencv', None, 'if true, use the installed opencv library.'), ('link-opencv', None, 'if true, use the installed opencv library.'),
] ]
boolean_options = build.build.boolean_options + ['link-opencv'] boolean_options = build_py.build_py.boolean_options + ['link-opencv']
def initialize_options(self): def initialize_options(self):
self.link_opencv = False self.link_opencv = False
build.build.initialize_options(self) build_py.build_py.initialize_options(self)
def finalize_options(self): def finalize_options(self):
build.build.finalize_options(self) build_py.build_py.finalize_options(self)
def run(self): def run(self):
_modify_opencv_cmake_rule(self.link_opencv) _modify_opencv_cmake_rule(self.link_opencv)
@ -344,13 +340,12 @@ class Build(build.build):
build_ext_obj.link_opencv = self.link_opencv build_ext_obj.link_opencv = self.link_opencv
self.run_command('build_binary_graphs') self.run_command('build_binary_graphs')
self.run_command('build_ext') self.run_command('build_ext')
self.run_command('modify_inits') build_py.build_py.run(self)
build.build.run(self)
self.run_command('remove_generated') self.run_command('remove_generated')
class Install(install.install): class Install(install.install):
"""Install command that builds binary graphs and extension and does a cleanup afterwards.""" """Install command that generates protos, builds binary graphs and extension, builds python source, and performs a cleanup afterwards."""
user_options = install.install.user_options + [ user_options = install.install.user_options + [
('link-opencv', None, 'if true, use the installed opencv library.'), ('link-opencv', None, 'if true, use the installed opencv library.'),
@ -373,14 +368,21 @@ class Install(install.install):
build_ext_obj.link_opencv = self.link_opencv build_ext_obj.link_opencv = self.link_opencv
self.run_command('build_binary_graphs') self.run_command('build_binary_graphs')
self.run_command('build_ext') self.run_command('build_ext')
self.run_command('modify_inits')
install.install.run(self) install.install.run(self)
self.run_command('remove_generated') self.run_command('remove_generated')
class RemoveGenerated(clean.clean): class RemoveGenerated(setuptools.Command):
"""Remove the generated files.""" """Remove the generated files."""
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self): def run(self):
for pattern in [ for pattern in [
'mediapipe/calculators/**/*pb2.py', 'mediapipe/calculators/**/*pb2.py',
@ -409,9 +411,8 @@ class RemoveGenerated(clean.clean):
if os.path.exists(_get_backup_file(MP_THIRD_PARTY_BUILD)): if os.path.exists(_get_backup_file(MP_THIRD_PARTY_BUILD)):
os.remove(MP_THIRD_PARTY_BUILD) os.remove(MP_THIRD_PARTY_BUILD)
shutil.move(_get_backup_file(MP_THIRD_PARTY_BUILD), MP_THIRD_PARTY_BUILD) shutil.move(_get_backup_file(MP_THIRD_PARTY_BUILD), MP_THIRD_PARTY_BUILD)
for init_py in SUBDIR_INIT_PY_FILES: for init_py in DIR_INIT_PY_FILES:
os.remove(init_py) os.remove(init_py)
clean.clean.run(self)
setuptools.setup( setuptools.setup(
@ -426,11 +427,10 @@ setuptools.setup(
packages=setuptools.find_packages(exclude=['mediapipe.examples.desktop.*']), packages=setuptools.find_packages(exclude=['mediapipe.examples.desktop.*']),
install_requires=_parse_requirements('requirements.txt'), install_requires=_parse_requirements('requirements.txt'),
cmdclass={ cmdclass={
'build': Build, 'build_py': BuildPy,
'gen_protos': GeneratePyProtos, 'gen_protos': GeneratePyProtos,
'modify_inits': ModifyInitFiles,
'build_binary_graphs': BuildBinaryGraphs, 'build_binary_graphs': BuildBinaryGraphs,
'build_ext': BuildBazelExtension, 'build_ext': BuildExtension,
'install': Install, 'install': Install,
'remove_generated': RemoveGenerated, 'remove_generated': RemoveGenerated,
}, },
@ -451,6 +451,7 @@ setuptools.setup(
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3 :: Only',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
@ -461,5 +462,3 @@ setuptools.setup(
license='Apache 2.0', license='Apache 2.0',
keywords='mediapipe', keywords='mediapipe',
) )
os.remove(ROOT_INIT_PY)

View File

@ -17,7 +17,7 @@
# Script to setup Android SDK and NDK. # Script to setup Android SDK and NDK.
# usage: # usage:
# $ cd <mediapipe root dir> # $ cd <mediapipe root dir>
# $ bash ./setup_android_sdk_and_ndk.sh ~/Android/Sdk ~/Android/Ndk r19c # $ bash ./setup_android_sdk_and_ndk.sh ~/Android/Sdk ~/Android/Ndk r21
set -e set -e
@ -54,8 +54,8 @@ fi
if [ -z $3 ] if [ -z $3 ]
then then
echo "Warning: ndk_version (argument 3) is not specified. Fallback to r19c." echo "Warning: ndk_version (argument 3) is not specified. Fallback to r21."
ndk_version="r19c" ndk_version="r21"
fi fi
if [ -d "$android_sdk_path" ] if [ -d "$android_sdk_path" ]
@ -64,11 +64,11 @@ then
else else
rm -rf /tmp/android_sdk/ rm -rf /tmp/android_sdk/
mkdir /tmp/android_sdk/ mkdir /tmp/android_sdk/
curl https://dl.google.com/android/repository/commandlinetools-${platform_android_sdk}-6609375_latest.zip -o /tmp/android_sdk/commandline_tools.zip curl https://dl.google.com/android/repository/commandlinetools-${platform_android_sdk}-7583922_latest.zip -o /tmp/android_sdk/commandline_tools.zip
unzip /tmp/android_sdk/commandline_tools.zip -d /tmp/android_sdk/ unzip /tmp/android_sdk/commandline_tools.zip -d /tmp/android_sdk/
mkdir -p $android_sdk_path mkdir -p $android_sdk_path
/tmp/android_sdk/tools/bin/sdkmanager --update --sdk_root=${android_sdk_path} /tmp/android_sdk/cmdline-tools/bin/sdkmanager --update --sdk_root=${android_sdk_path}
/tmp/android_sdk/tools/bin/sdkmanager "build-tools;29.0.1" "platform-tools" "platforms;android-29" --sdk_root=${android_sdk_path} /tmp/android_sdk/cmdline-tools/bin/sdkmanager "build-tools;30.0.3" "platform-tools" "platforms;android-30" "extras;android;m2repository" --sdk_root=${android_sdk_path}
rm -rf /tmp/android_sdk/ rm -rf /tmp/android_sdk/
echo "Android SDK is now installed. Consider setting \$ANDROID_HOME environment variable to be ${android_sdk_path}" echo "Android SDK is now installed. Consider setting \$ANDROID_HOME environment variable to be ${android_sdk_path}"
fi fi
@ -88,22 +88,6 @@ fi
echo "Set android_ndk_repository and android_sdk_repository in WORKSPACE" echo "Set android_ndk_repository and android_sdk_repository in WORKSPACE"
workspace_file="$( cd "$(dirname "$0")" ; pwd -P )"/WORKSPACE workspace_file="$( cd "$(dirname "$0")" ; pwd -P )"/WORKSPACE
echo "android_sdk_repository(name = \"androidsdk\", path = \"${android_sdk_path}\")" >> $workspace_file
ndk_block=$(grep -n 'android_ndk_repository(' $workspace_file | awk -F ":" '{print $1}') echo "android_ndk_repository(name = \"androidndk\", path = \"${android_ndk_path}/android-ndk-${ndk_version}\")" >> $workspace_file
ndk_path_line=$((ndk_block+2))'i'
sdk_block=$(grep -n 'android_sdk_repository(' $workspace_file | awk -F ":" '{print $1}')
sdk_path_line=$((sdk_block+3))'i'
if [ $platform == "darwin" ]; then
sed -i -e "$ndk_path_line\\
\ \ \ \ path = \"${android_ndk_path}/android-ndk-${ndk_version}\",
" $workspace_file
sed -i -e "$sdk_path_line\\
\ \ \ \ path = \"${android_sdk_path}\",
" $workspace_file
elif [ $platform == "linux" ]; then
sed -i "$ndk_path_line \ path = \"${android_ndk_path}/android-ndk-${ndk_version}\"," $workspace_file
sed -i "$sdk_path_line \ path = \"${android_sdk_path}\"," $workspace_file
fi
echo "Done" echo "Done"

View File

@ -12,3 +12,15 @@ index ce170b2..bb5aa82 100644
+ "@com_github_glog_glog//:glog", + "@com_github_glog_glog//:glog",
], ],
) )
diff --git a/bazel/ceres.bzl b/bazel/ceres.bzl
index ce170b2..8dd62c5 100644
--- a/bazel/ceres.bzl
+++ b/bazel/ceres.bzl
@@ -116,7 +116,6 @@ CERES_SRCS = ["internal/ceres/" + filename for filename in [
"sparse_cholesky.cc",
"sparse_matrix.cc",
"sparse_normal_cholesky_solver.cc",
- "split.cc",
"stringprintf.cc",
"subset_preconditioner.cc",
"suitesparse.cc",

View File

@ -1,8 +1,8 @@
diff --git a/BUILD b/BUILD diff --git a/BUILD b/BUILD
index 79871d621..51b3a063f 100644 index 1690d4219..e13ca8338 100644
--- a/BUILD --- a/BUILD
+++ b/BUILD +++ b/BUILD
@@ -26,7 +26,7 @@ config_setting( @@ -19,7 +19,7 @@ exports_files(["LICENSE"])
# ZLIB configuration # ZLIB configuration
################################################################################ ################################################################################
@ -11,7 +11,7 @@ index 79871d621..51b3a063f 100644
################################################################################ ################################################################################
# Protobuf Runtime Library # Protobuf Runtime Library
@@ -157,6 +157,7 @@ cc_library( @@ -197,6 +197,7 @@ cc_library(
includes = ["src/"], includes = ["src/"],
linkopts = LINK_OPTS, linkopts = LINK_OPTS,
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -19,7 +19,7 @@ index 79871d621..51b3a063f 100644
) )
PROTOBUF_DEPS = select({ PROTOBUF_DEPS = select({
@@ -230,6 +231,7 @@ cc_library( @@ -271,6 +272,7 @@ cc_library(
linkopts = LINK_OPTS, linkopts = LINK_OPTS,
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":protobuf_lite"] + PROTOBUF_DEPS, deps = [":protobuf_lite"] + PROTOBUF_DEPS,
@ -27,63 +27,16 @@ index 79871d621..51b3a063f 100644
) )
# This provides just the header files for use in projects that need to build # This provides just the header files for use in projects that need to build
@@ -318,13 +320,13 @@ cc_proto_library(
[native_cc_proto_library(
name = proto + "_cc_proto",
- deps = [proto + "_proto"],
visibility = ["//visibility:private"],
+ deps = [proto + "_proto"],
) for proto in WELL_KNOWN_PROTO_MAP.keys()]
cc_proto_blacklist_test(
name = "cc_proto_blacklist_test",
- deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()]
+ deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()],
)
################################################################################
@@ -900,7 +902,6 @@ py_proto_library(
py_extra_srcs = glob(["python/**/__init__.py"]),
py_libs = [
":python_srcs",
- "@six//:six",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
@@ -1002,7 +1003,9 @@ cc_library(
# Note: We use `native_proto_common` here because we depend on an implementation-detail of
# `proto_lang_toolchain`, which may not be available on `proto_common`.
reject_blacklisted_files = hasattr(native_proto_common, "proto_lang_toolchain_rejects_files_do_not_use_or_we_will_break_you_without_mercy")
+
cc_toolchain_blacklisted_protos = [proto + "_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()] if reject_blacklisted_files else [":well_known_protos"]
+
proto_lang_toolchain(
name = "cc_toolchain",
blacklisted_protos = cc_toolchain_blacklisted_protos,
diff --git a/protobuf.bzl b/protobuf.bzl
index 829464d44..4ac23594b 100644
--- a/protobuf.bzl
+++ b/protobuf.bzl
@@ -87,6 +87,8 @@ def _proto_gen_impl(ctx):
for dep in ctx.attr.deps:
import_flags += dep.proto.import_flags
deps += dep.proto.deps
+ import_flags = depset(import_flags).to_list()
+ deps = depset(deps).to_list()
if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin:
return struct(
diff --git a/src/google/protobuf/io/gzip_stream.h b/src/google/protobuf/io/gzip_stream.h diff --git a/src/google/protobuf/io/gzip_stream.h b/src/google/protobuf/io/gzip_stream.h
index b1ce1d36c..d5d560ea7 100644 index f0283e86f..436c6ce4b 100644
--- a/src/google/protobuf/io/gzip_stream.h --- a/src/google/protobuf/io/gzip_stream.h
+++ b/src/google/protobuf/io/gzip_stream.h +++ b/src/google/protobuf/io/gzip_stream.h
@@ -47,10 +47,12 @@ @@ -47,10 +47,13 @@
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/io/zero_copy_stream.h> #include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/port.h> #include <google/protobuf/port.h>
-#include <zlib.h> -#include <zlib.h>
-
#include <google/protobuf/port_def.inc> #include <google/protobuf/port_def.inc>
+#if HAVE_ZLIB +#if HAVE_ZLIB
@ -93,7 +46,7 @@ index b1ce1d36c..d5d560ea7 100644
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace io { namespace io {
@@ -76,8 +78,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { @@ -76,8 +79,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream {
virtual ~GzipInputStream(); virtual ~GzipInputStream();
// Return last error message or NULL if no error. // Return last error message or NULL if no error.
@ -103,8 +56,8 @@ index b1ce1d36c..d5d560ea7 100644
+ #endif // HAVE_ZLIB + #endif // HAVE_ZLIB
// implements ZeroCopyInputStream ---------------------------------- // implements ZeroCopyInputStream ----------------------------------
bool Next(const void** data, int* size); bool Next(const void** data, int* size) override;
@@ -90,8 +94,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { @@ -90,8 +95,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream {
ZeroCopyInputStream* sub_stream_; ZeroCopyInputStream* sub_stream_;
@ -115,19 +68,18 @@ index b1ce1d36c..d5d560ea7 100644
void* output_buffer_; void* output_buffer_;
void* output_position_; void* output_position_;
@@ -142,9 +148,11 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { @@ -143,8 +150,10 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream {
virtual ~GzipOutputStream(); virtual ~GzipOutputStream();
+#if HAVE_ZLIB
// Return last error message or NULL if no error. // Return last error message or NULL if no error.
+ #if HAVE_ZLIB
inline const char* ZlibErrorMessage() const { return zcontext_.msg; } inline const char* ZlibErrorMessage() const { return zcontext_.msg; }
inline int ZlibErrorCode() const { return zerror_; } inline int ZlibErrorCode() const { return zerror_; }
+ #endif // HAVE_ZLIB + #endif // HAVE_ZLIB
// Flushes data written so far to zipped data in the underlying stream. // Flushes data written so far to zipped data in the underlying stream.
// It is the caller's responsibility to flush the underlying stream if // It is the caller's responsibility to flush the underlying stream if
@@ -177,8 +185,10 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { @@ -177,8 +186,10 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream {
void* sub_data_; void* sub_data_;
int sub_data_size_; int sub_data_size_;

View File

@ -24,13 +24,13 @@ index b7c22ae77ba..d0ba7b48b4b 100644
} }
} }
diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h
index 94b4853a810..75589d04a60 100644 index b598b6ee1e4..51c013a2d62 100644
--- a/tensorflow/core/platform/test.h --- a/tensorflow/core/platform/test.h
+++ b/tensorflow/core/platform/test.h +++ b/tensorflow/core/platform/test.h
@@ -40,7 +40,6 @@ limitations under the License. @@ -40,7 +40,6 @@ limitations under the License.
// The advantages of using gmock matchers instead of self defined matchers are
// better error messages, more maintainable tests and more test coverage. // better error messages, more maintainable tests and more test coverage.
#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) #if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && \
!defined(PLATFORM_CHROMIUMOS)
-#include <gmock/gmock-generated-matchers.h> -#include <gmock/gmock-generated-matchers.h>
#include <gmock/gmock-matchers.h> #include <gmock/gmock-matchers.h>
#include <gmock/gmock-more-matchers.h> #include <gmock/gmock-more-matchers.h>