diff --git a/.bazelversion b/.bazelversion index 91ff57278..f3b5af39e 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.2.0 +6.1.1 diff --git a/Dockerfile b/Dockerfile index 3df22cc04..03b335823 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,7 +61,7 @@ RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python # Install bazel -ARG BAZEL_VERSION=5.2.0 +ARG BAZEL_VERSION=6.1.1 RUN mkdir /bazel && \ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ diff --git a/README.md b/README.md index 012ea3a27..a82c88ab1 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,20 @@ nav_order: 1 ![MediaPipe](https://mediapipe.dev/images/mediapipe_small.png) +---- + +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +*This notice and web page will be removed on June 1, 2023.* + +---- + +
+ -------------------------------------------------------------------------------- ## Live ML anywhere @@ -21,15 +35,6 @@ ML solutions for live and streaming media. ---- -**Attention:** *Thanks for your interest in MediaPipe! We are moving to -[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) -as the primary developer documentation -site for MediaPipe starting April 3, 2023.* - -*This notice and web page will be removed on April 3, 2023.* - ----- - ## ML solutions in MediaPipe Face Detection | Face Mesh | Iris | Hands | Pose | Holistic diff --git a/WORKSPACE b/WORKSPACE index 17e96c0b2..199b6a000 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -54,6 +54,76 @@ load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependen rules_foreign_cc_dependencies() +http_archive( + name = "com_google_protobuf", + sha256 = "87407cd28e7a9c95d9f61a098a53cf031109d451a7763e7dd1253abf8b4df422", + strip_prefix = "protobuf-3.19.1", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"], + patches = [ + "@//third_party:com_google_protobuf_fixes.diff" + ], + patch_args = [ + "-p1", + ], +) + +# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee +# that the target @zlib//:mini_zlib is available +http_archive( + name = "zlib", + build_file = "@//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], + patches = [ + "@//third_party:zlib.diff", + ], + patch_args = [ + "-p1", + ], +) + +# iOS basic build deps. +http_archive( + name = "build_bazel_rules_apple", + sha256 = "3e2c7ae0ddd181c4053b6491dad1d01ae29011bc322ca87eea45957c76d3a0c3", + url = "https://github.com/bazelbuild/rules_apple/releases/download/2.1.0/rules_apple.2.1.0.tar.gz", + patches = [ + # Bypass checking ios unit test runner when building MP ios applications. + "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" + ], + patch_args = [ + "-p1", + ], +) + +load( + "@build_bazel_rules_apple//apple:repositories.bzl", + "apple_rules_dependencies", +) +apple_rules_dependencies() + +load( + "@build_bazel_rules_swift//swift:repositories.bzl", + "swift_rules_dependencies", +) +swift_rules_dependencies() + +load( + "@build_bazel_rules_swift//swift:extras.bzl", + "swift_rules_extra_dependencies", +) +swift_rules_extra_dependencies() + +load( + "@build_bazel_apple_support//lib:repositories.bzl", + "apple_support_dependencies", +) +apple_support_dependencies() + # This is used to select all contents of the archives for CMake-based packages to give CMake access to them. 
all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])""" @@ -133,19 +203,6 @@ http_archive( urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"], ) -http_archive( - name = "com_google_protobuf", - sha256 = "87407cd28e7a9c95d9f61a098a53cf031109d451a7763e7dd1253abf8b4df422", - strip_prefix = "protobuf-3.19.1", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.19.1.tar.gz"], - patches = [ - "@//third_party:com_google_protobuf_fixes.diff" - ], - patch_args = [ - "-p1", - ], -) - load("@//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() @@ -319,63 +376,6 @@ http_archive( ], ) -# Load Zlib before initializing TensorFlow and the iOS build rules to guarantee -# that the target @zlib//:mini_zlib is available -http_archive( - name = "zlib", - build_file = "@//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], - patches = [ - "@//third_party:zlib.diff", - ], - patch_args = [ - "-p1", - ], -) - -# iOS basic build deps. -http_archive( - name = "build_bazel_rules_apple", - sha256 = "f94e6dddf74739ef5cb30f000e13a2a613f6ebfa5e63588305a71fce8a8a9911", - url = "https://github.com/bazelbuild/rules_apple/releases/download/1.1.3/rules_apple.1.1.3.tar.gz", - patches = [ - # Bypass checking ios unit test runner when building MP ios applications. - "@//third_party:build_bazel_rules_apple_bypass_test_runner_check.diff" - ], - patch_args = [ - "-p1", - ], -) - -load( - "@build_bazel_rules_apple//apple:repositories.bzl", - "apple_rules_dependencies", -) -apple_rules_dependencies() - -load( - "@build_bazel_rules_swift//swift:repositories.bzl", - "swift_rules_dependencies", -) -swift_rules_dependencies() - -load( - "@build_bazel_rules_swift//swift:extras.bzl", - "swift_rules_extra_dependencies", -) -swift_rules_extra_dependencies() - -load( - "@build_bazel_apple_support//lib:repositories.bzl", - "apple_support_dependencies", -) -apple_support_dependencies() - # More iOS deps. http_archive( diff --git a/docs/framework_concepts/building_graphs_cpp.md b/docs/framework_concepts/building_graphs_cpp.md index 250cd89e2..26cdfe1e4 100644 --- a/docs/framework_concepts/building_graphs_cpp.md +++ b/docs/framework_concepts/building_graphs_cpp.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/framework/framework_concepts/graphs_cpp title: Building Graphs in C++ parent: Graphs nav_order: 1 @@ -12,6 +13,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + C++ graph builder is a powerful tool for: * Building complex graphs diff --git a/docs/framework_concepts/calculators.md b/docs/framework_concepts/calculators.md index 5c51a3ec5..3a3661dd4 100644 --- a/docs/framework_concepts/calculators.md +++ b/docs/framework_concepts/calculators.md @@ -13,6 +13,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! 
We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + Each calculator is a node of a graph. We describe how to create a new calculator, how to initialize a calculator, how to perform its calculations, input and output streams, timestamps, and options. Each node in the graph is diff --git a/docs/framework_concepts/framework_concepts.md b/docs/framework_concepts/framework_concepts.md index 5d953480e..004c75cff 100644 --- a/docs/framework_concepts/framework_concepts.md +++ b/docs/framework_concepts/framework_concepts.md @@ -14,6 +14,12 @@ has_toc: false {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## The basics ### Packet diff --git a/docs/framework_concepts/gpu.md b/docs/framework_concepts/gpu.md index 3c411d55c..8900ab3b4 100644 --- a/docs/framework_concepts/gpu.md +++ b/docs/framework_concepts/gpu.md @@ -13,6 +13,12 @@ nav_order: 5 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Overview MediaPipe supports calculator nodes for GPU compute and rendering, and allows combining multiple GPU nodes, as well as mixing them with CPU based calculator nodes. There exist several GPU APIs on mobile platforms (eg, OpenGL ES, Metal and Vulkan). MediaPipe does not attempt to offer a single cross-API GPU abstraction. Individual nodes can be written using different APIs, allowing them to take advantage of platform specific features when needed. diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index 0d38c75fc..5f9c68e08 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -13,6 +13,12 @@ nav_order: 2 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Graph A `CalculatorGraphConfig` proto specifies the topology and functionality of a diff --git a/docs/framework_concepts/packets.md b/docs/framework_concepts/packets.md index 100bc6b01..1bfad376d 100644 --- a/docs/framework_concepts/packets.md +++ b/docs/framework_concepts/packets.md @@ -13,6 +13,12 @@ nav_order: 3 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + Calculators communicate by sending and receiving packets. Typically a single packet is sent along each input stream at each input timestamp. A packet can contain any kind of data, such as a single frame of video or a single integer diff --git a/docs/framework_concepts/realtime_streams.md b/docs/framework_concepts/realtime_streams.md index 43d147f55..60f586cb9 100644 --- a/docs/framework_concepts/realtime_streams.md +++ b/docs/framework_concepts/realtime_streams.md @@ -13,6 +13,12 @@ nav_order: 6 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! 
We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Real-time timestamps MediaPipe calculator graphs are often used to process streams of video or audio diff --git a/docs/framework_concepts/synchronization.md b/docs/framework_concepts/synchronization.md index e12d077a7..8a0a907c5 100644 --- a/docs/framework_concepts/synchronization.md +++ b/docs/framework_concepts/synchronization.md @@ -13,6 +13,12 @@ nav_order: 4 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Scheduling mechanics Data processing in a MediaPipe graph occurs inside processing nodes defined as diff --git a/docs/getting_started/android.md b/docs/getting_started/android.md index cb99a6fef..83fbd1c93 100644 --- a/docs/getting_started/android.md +++ b/docs/getting_started/android.md @@ -15,6 +15,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + Please follow instructions below to build Android example apps in the supported MediaPipe [solutions](../solutions/solutions.md). To learn more about these example apps, start from [Hello World! on Android](./hello_world_android.md). diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index 9b92ef498..7d98b32c5 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -14,6 +14,12 @@ nav_order: 3 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ***Experimental Only*** The MediaPipe Android Archive (AAR) library is a convenient way to use MediaPipe diff --git a/docs/getting_started/android_solutions.md b/docs/getting_started/android_solutions.md index 0c492c1bb..159d1358d 100644 --- a/docs/getting_started/android_solutions.md +++ b/docs/getting_started/android_solutions.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: MediaPipe Android Solutions parent: MediaPipe on Android grand_parent: Getting Started @@ -13,14 +14,9 @@ nav_order: 2 {:toc} --- -**Attention:** *Thanks for your interest in MediaPipe! We are moving to +**Attention:** *Thanks for your interest in MediaPipe! We have moved to [https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) -as the primary developer documentation -site for MediaPipe starting April 3, 2023. 
This content will not be moved to -the new site, but will remain available in the source code repository on an -as-is basis.* - -*This notice and web page will be removed on April 3, 2023.* +as the primary developer documentation site for MediaPipe as of April 3, 2023.* ---- diff --git a/docs/getting_started/building_examples.md b/docs/getting_started/building_examples.md index 20c30bef2..a77f6ea66 100644 --- a/docs/getting_started/building_examples.md +++ b/docs/getting_started/building_examples.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: Building MediaPipe Examples parent: Getting Started nav_exclude: true @@ -12,14 +13,9 @@ nav_exclude: true {:toc} --- -**Attention:** *Thanks for your interest in MediaPipe! We are moving to +**Attention:** *Thanks for your interest in MediaPipe! We have moved to [https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) -as the primary developer documentation -site for MediaPipe starting April 3, 2023. This content will not be moved to -the new site, but will remain available in the source code repository on an -as-is basis.* - -*This notice and web page will be removed on April 3, 2023.* +as the primary developer documentation site for MediaPipe as of April 3, 2023.* ---- diff --git a/docs/getting_started/cpp.md b/docs/getting_started/cpp.md index 3afde767f..d708866a7 100644 --- a/docs/getting_started/cpp.md +++ b/docs/getting_started/cpp.md @@ -15,6 +15,12 @@ nav_order: 5 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + Please follow instructions below to build C++ command-line example apps in the supported MediaPipe [solutions](../solutions/solutions.md). To learn more about these example apps, start from [Hello World! in C++](./hello_world_cpp.md). diff --git a/docs/getting_started/faq.md b/docs/getting_started/faq.md index b7c24e6ec..ca50ae530 100644 --- a/docs/getting_started/faq.md +++ b/docs/getting_started/faq.md @@ -13,6 +13,12 @@ nav_order: 9 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ### How to convert ImageFrames and GpuBuffers The Calculators [`ImageFrameToGpuBufferCalculator`] and diff --git a/docs/getting_started/getting_started.md b/docs/getting_started/getting_started.md index fea9cfa73..db605b4b4 100644 --- a/docs/getting_started/getting_started.md +++ b/docs/getting_started/getting_started.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: Getting Started nav_order: 2 has_children: true @@ -12,13 +13,8 @@ has_children: true {:toc} --- -**Attention:** *Thanks for your interest in MediaPipe! We are moving to +**Attention:** *Thanks for your interest in MediaPipe! We have moved to [https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) -as the primary developer documentation -site for MediaPipe starting April 3, 2023. 
This content will not be moved to -the new site, but will remain available in the source code repository on an -as-is basis.* - -*This notice and web page will be removed on April 3, 2023.* +as the primary developer documentation site for MediaPipe as of April 3, 2023.* ---- diff --git a/docs/getting_started/gpu_support.md b/docs/getting_started/gpu_support.md index 4bd1efeb8..6c0e8be0f 100644 --- a/docs/getting_started/gpu_support.md +++ b/docs/getting_started/gpu_support.md @@ -13,6 +13,12 @@ nav_order: 7 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## OpenGL ES Support MediaPipe supports OpenGL ES up to version 3.2 on Android/Linux and up to ES 3.0 diff --git a/docs/getting_started/hello_world_android.md b/docs/getting_started/hello_world_android.md index 012743048..1148ff5a9 100644 --- a/docs/getting_started/hello_world_android.md +++ b/docs/getting_started/hello_world_android.md @@ -14,6 +14,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Introduction This codelab uses MediaPipe on an Android device. diff --git a/docs/getting_started/hello_world_cpp.md b/docs/getting_started/hello_world_cpp.md index 880248725..7c8f9be3e 100644 --- a/docs/getting_started/hello_world_cpp.md +++ b/docs/getting_started/hello_world_cpp.md @@ -14,6 +14,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + 1. Ensure you have a working version of MediaPipe. See [installation instructions](./install.md). diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 713dbc79a..4be097646 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -14,6 +14,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Introduction This codelab uses MediaPipe on an iOS device. diff --git a/docs/getting_started/help.md b/docs/getting_started/help.md index 3ba052741..2cb6b9e68 100644 --- a/docs/getting_started/help.md +++ b/docs/getting_started/help.md @@ -13,6 +13,12 @@ nav_order: 8 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Technical questions For help with technical or algorithmic questions, visit diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index cc5c0241d..b30284779 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -13,6 +13,12 @@ nav_order: 6 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! 
We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + Note: To interoperate with OpenCV, OpenCV 3.x to 4.1 are preferred. OpenCV 2.x currently works but interoperability support may be deprecated in the future. @@ -577,7 +583,7 @@ next section. Option 1. Follow [the official Bazel documentation](https://docs.bazel.build/versions/master/install-windows.html) - to install Bazel 5.2.0 or higher. + to install Bazel 6.1.1 or higher. Option 2. Follow the official [Bazel documentation](https://docs.bazel.build/versions/master/install-bazelisk.html) diff --git a/docs/getting_started/ios.md b/docs/getting_started/ios.md index c4b8fb99e..798017aa3 100644 --- a/docs/getting_started/ios.md +++ b/docs/getting_started/ios.md @@ -15,6 +15,12 @@ nav_order: 2 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + Please follow instructions below to build iOS example apps in the supported MediaPipe [solutions](../solutions/solutions.md). To learn more about these example apps, start from, start from diff --git a/docs/getting_started/javascript.md b/docs/getting_started/javascript.md index 79269827b..e68d40917 100644 --- a/docs/getting_started/javascript.md +++ b/docs/getting_started/javascript.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: MediaPipe in JavaScript parent: Getting Started nav_order: 4 @@ -14,12 +15,7 @@ nav_order: 4 **Attention:** *Thanks for your interest in MediaPipe! We are moving to [https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) -as the primary developer documentation -site for MediaPipe starting April 3, 2023. This content will not be moved to -the new site, but will remain available in the source code repository on an -as-is basis.* - -*This notice and web page will be removed on April 3, 2023.* +as the primary developer documentation site for MediaPipe starting April 3, 2023.* ---- diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md index 880d5c85d..43f452a50 100644 --- a/docs/getting_started/python.md +++ b/docs/getting_started/python.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: MediaPipe in Python parent: Getting Started has_children: true @@ -14,6 +15,12 @@ nav_order: 3 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Ready-to-use Python Solutions MediaPipe offers ready-to-use yet customizable Python solutions as a prebuilt diff --git a/docs/getting_started/python_framework.md b/docs/getting_started/python_framework.md index db5bc0cd4..6d4b7d450 100644 --- a/docs/getting_started/python_framework.md +++ b/docs/getting_started/python_framework.md @@ -12,6 +12,11 @@ nav_order: 1 1. TOC {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! 
We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- The MediaPipe Python framework grants direct access to the core components of the MediaPipe C++ framework such as Timestamp, Packet, and CalculatorGraph, diff --git a/docs/getting_started/troubleshooting.md b/docs/getting_started/troubleshooting.md index 0da25497d..e7dff332c 100644 --- a/docs/getting_started/troubleshooting.md +++ b/docs/getting_started/troubleshooting.md @@ -13,6 +13,12 @@ nav_order: 10 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + ## Missing Python binary path The error message: diff --git a/docs/index.md b/docs/index.md index 012ea3a27..a82c88ab1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,6 +6,20 @@ nav_order: 1 ![MediaPipe](https://mediapipe.dev/images/mediapipe_small.png) +---- + +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +*This notice and web page will be removed on June 1, 2023.* + +---- + +
+ -------------------------------------------------------------------------------- ## Live ML anywhere @@ -21,15 +35,6 @@ ML solutions for live and streaming media. ---- -**Attention:** *Thanks for your interest in MediaPipe! We are moving to -[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) -as the primary developer documentation -site for MediaPipe starting April 3, 2023.* - -*This notice and web page will be removed on April 3, 2023.* - ----- - ## ML solutions in MediaPipe Face Detection | Face Mesh | Iris | Hands | Pose | Holistic diff --git a/docs/index.rst b/docs/index.rst index 4563284bd..fc7a6f50f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,3 +1,3 @@ MediaPipe ===================================== -Please see https://docs.mediapipe.dev. +Please see https://developers.google.com/mediapipe/ diff --git a/docs/solutions/autoflip.md b/docs/solutions/autoflip.md index d0a763436..a9e1e7052 100644 --- a/docs/solutions/autoflip.md +++ b/docs/solutions/autoflip.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: AutoFlip (Saliency-aware Video Cropping) -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 14 --- @@ -20,12 +21,10 @@ nav_order: 14 **Attention:** *Thank you for your interest in MediaPipe Solutions. We have ended support for this MediaPipe Legacy Solution as of March 1, 2023. -For more information, see the new +For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/box_tracking.md b/docs/solutions/box_tracking.md index 4fecc5150..537916ac4 100644 --- a/docs/solutions/box_tracking.md +++ b/docs/solutions/box_tracking.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Box Tracking -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 10 --- @@ -20,12 +21,10 @@ nav_order: 10 **Attention:** *Thank you for your interest in MediaPipe Solutions. We have ended support for this MediaPipe Legacy Solution as of March 1, 2023. -For more information, see the new +For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 789d9b3dd..f060d062c 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/face_detector/ title: Face Detection -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 1 --- @@ -20,12 +21,10 @@ nav_order: 1 **Attention:** *Thank you for your interest in MediaPipe Solutions. As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new +Solution. 
For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index 84fbb22a5..ab34ba401 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/face_landmarker/ title: Face Mesh -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 2 --- @@ -20,12 +21,10 @@ nav_order: 2 **Attention:** *Thank you for your interest in MediaPipe Solutions. As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new +Solution. For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/hair_segmentation.md b/docs/solutions/hair_segmentation.md index 481cd0058..feb40f9c0 100644 --- a/docs/solutions/hair_segmentation.md +++ b/docs/solutions/hair_segmentation.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/image_segmenter/ title: Hair Segmentation -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 8 --- @@ -19,13 +20,11 @@ nav_order: 8 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +As of April 4, 2023, this solution was upgraded to a new MediaPipe +Solution. For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/image_segmenter/) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ![hair_segmentation_android_gpu_gif](https://mediapipe.dev/images/mobile/hair_segmentation_android_gpu.gif) diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index a4cd90baa..6cf2264ed 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/hand_landmarker title: Hands -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 4 --- @@ -19,13 +20,11 @@ nav_order: 4 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +As of March 1, 2023, this solution was upgraded to a new MediaPipe +Solution. 
For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 876a88572..25288ab55 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md title: Holistic -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 6 --- @@ -20,12 +21,10 @@ nav_order: 6 **Attention:** *Thank you for your interest in MediaPipe Solutions. As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new +Solution. For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/instant_motion_tracking.md b/docs/solutions/instant_motion_tracking.md index 1e714bdc8..361bc91ff 100644 --- a/docs/solutions/instant_motion_tracking.md +++ b/docs/solutions/instant_motion_tracking.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Instant Motion Tracking -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 11 --- @@ -20,12 +21,10 @@ nav_order: 11 **Attention:** *Thank you for your interest in MediaPipe Solutions. We have ended support for this MediaPipe Legacy Solution as of March 1, 2023. -For more information, see the new +For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/iris.md b/docs/solutions/iris.md index b8459a0e3..eab3dedf6 100644 --- a/docs/solutions/iris.md +++ b/docs/solutions/iris.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/face_landmarker/ title: Iris -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 3 --- @@ -20,12 +21,10 @@ nav_order: 3 **Attention:** *Thank you for your interest in MediaPipe Solutions. As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new +Solution. For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/knift.md b/docs/solutions/knift.md index ad5d39f22..19e04cb5e 100644 --- a/docs/solutions/knift.md +++ b/docs/solutions/knift.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: KNIFT (Template-based Feature Matching) -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 13 --- @@ -20,12 +21,10 @@ nav_order: 13 **Attention:** *Thank you for your interest in MediaPipe Solutions. We have ended support for this MediaPipe Legacy Solution as of March 1, 2023. 
-For more information, see the new +For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/media_sequence.md b/docs/solutions/media_sequence.md index 5c479ea4c..5224dd371 100644 --- a/docs/solutions/media_sequence.md +++ b/docs/solutions/media_sequence.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Dataset Preparation with MediaSequence -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 15 --- @@ -24,8 +25,6 @@ For more information, see the new [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/models.md b/docs/solutions/models.md index c45aefa44..0af91eb48 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Models and Model Cards -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 30 --- @@ -22,8 +23,6 @@ MediaPipe Legacy Solutions will continue to be provided on an as-is basis. We encourage you to check out the new MediaPipe Solutions at: [https://developers.google.com/mediapipe/solutions](https://developers.google.com/mediapipe/solutions)* -*This notice and web page will be removed on April 3, 2023.* - ---- ### [Face Detection](https://google.github.io/mediapipe/solutions/face_detection) diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index ef7db8671..efa2e5266 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/object_detector/ title: Object Detection -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 9 --- @@ -19,13 +20,11 @@ nav_order: 9 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +As of March 1, 2023, this solution was upgraded to a new MediaPipe +Solution. For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/object_detector/) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ![object_detection_android_gpu.gif](https://mediapipe.dev/images/mobile/object_detection_android_gpu.gif) diff --git a/docs/solutions/object_detection_saved_model.md b/docs/solutions/object_detection_saved_model.md index 6acac0a1b..1c67bca68 100644 --- a/docs/solutions/object_detection_saved_model.md +++ b/docs/solutions/object_detection_saved_model.md @@ -1,4 +1,31 @@ -## TensorFlow/TFLite Object Detection Model +--- +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/object_detector +title: Object Detection +parent: MediaPipe Legacy Solutions +nav_order: 9 +--- + +# MediaPipe Object Detection +{: .no_toc } + +
+ + Table of contents + + {: .text-delta } +1. TOC +{:toc} +
+--- + +**Attention:** *Thank you for your interest in MediaPipe Solutions. +As of March 1, 2023, this solution was upgraded to a new MediaPipe +Solution. For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/object_detector) +site.* + +---- ### TensorFlow model diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 4ffb27bd0..09f8028bc 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: Objectron (3D Object Detection) -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 12 --- @@ -20,12 +21,10 @@ nav_order: 12 **Attention:** *Thank you for your interest in MediaPipe Solutions. We have ended support for this MediaPipe Legacy Solution as of March 1, 2023. -For more information, see the new +For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index ce0197ebd..3c9f14f54 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/ title: Pose -parent: Solutions +parent: MediaPipe Legacy Solutions has_children: true has_toc: false nav_order: 5 @@ -22,12 +23,10 @@ nav_order: 5 **Attention:** *Thank you for your interest in MediaPipe Solutions. As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +Solution. For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/pose_classification.md b/docs/solutions/pose_classification.md index 24f20f727..8420e2d7c 100644 --- a/docs/solutions/pose_classification.md +++ b/docs/solutions/pose_classification.md @@ -1,8 +1,9 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/ title: Pose Classification parent: Pose -grand_parent: Solutions +grand_parent: MediaPipe Legacy Solutions nav_order: 1 --- @@ -21,12 +22,10 @@ nav_order: 1 **Attention:** *Thank you for your interest in MediaPipe Solutions. As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +Solution. 
For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/selfie_segmentation.md b/docs/solutions/selfie_segmentation.md index 5febf29f0..17e6fc252 100644 --- a/docs/solutions/selfie_segmentation.md +++ b/docs/solutions/selfie_segmentation.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/vision/image_segmenter/ title: Selfie Segmentation -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 7 --- @@ -19,13 +20,11 @@ nav_order: 7 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe -Solution. For more information, see the new -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +As of April 4, 2023, this solution was upgraded to a new MediaPipe +Solution. For more information, see the +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/image_segmenter/) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- ## Overview diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index b65390af7..7bc32d169 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -1,12 +1,12 @@ --- layout: default -title: Solutions +title: MediaPipe Legacy Solutions nav_order: 3 has_children: true has_toc: false --- -# Solutions +# MediaPipe Legacy Solutions {: .no_toc } 1. TOC @@ -29,6 +29,12 @@ Solutions at: ---- +
+ +---- + MediaPipe offers open source cross-platform, customizable ML solutions for live and streaming media. diff --git a/docs/solutions/youtube_8m.md b/docs/solutions/youtube_8m.md index 2e82b85d3..80fb9d9a6 100644 --- a/docs/solutions/youtube_8m.md +++ b/docs/solutions/youtube_8m.md @@ -1,7 +1,8 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: YouTube-8M Feature Extraction and Model Inference -parent: Solutions +parent: MediaPipe Legacy Solutions nav_order: 16 --- @@ -20,12 +21,10 @@ nav_order: 16 **Attention:** *Thank you for your interest in MediaPipe Solutions. We have ended support for this MediaPipe Legacy Solution as of March 1, 2023. -For more information, see the new +For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) site.* -*This notice and web page will be removed on April 3, 2023.* - ---- MediaPipe is a useful and general framework for media processing that can assist diff --git a/docs/tools/performance_benchmarking.md b/docs/tools/performance_benchmarking.md index f0d334f58..fedbb6e8a 100644 --- a/docs/tools/performance_benchmarking.md +++ b/docs/tools/performance_benchmarking.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: Performance Benchmarking parent: Tools nav_order: 3 @@ -12,6 +13,12 @@ nav_order: 3 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +--- + *Coming soon.* Future mediapipe releases will include tools for visualizing and analysing the diff --git a/docs/tools/tools.md b/docs/tools/tools.md index 568ba76a7..8e4c55db3 100644 --- a/docs/tools/tools.md +++ b/docs/tools/tools.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: Tools nav_order: 4 has_children: true @@ -11,3 +12,9 @@ has_children: true 1. TOC {:toc} --- + +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- diff --git a/docs/tools/tracing_and_profiling.md b/docs/tools/tracing_and_profiling.md index 861564c99..0ed6f57ab 100644 --- a/docs/tools/tracing_and_profiling.md +++ b/docs/tools/tracing_and_profiling.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/ title: Tracing and Profiling parent: Tools nav_order: 2 @@ -12,6 +13,12 @@ nav_order: 2 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +---- + The MediaPipe framework includes a built-in tracer and profiler. The tracer records various timing events related to packet processing, including the start and end time of each Calculator::Process call. The tracer writes trace log files diff --git a/docs/tools/visualizer.md b/docs/tools/visualizer.md index 45111a36e..eb24a7fc5 100644 --- a/docs/tools/visualizer.md +++ b/docs/tools/visualizer.md @@ -13,6 +13,12 @@ nav_order: 1 {:toc} --- +**Attention:** *Thanks for your interest in MediaPipe! 
We have moved to +[https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) +as the primary developer documentation site for MediaPipe as of April 3, 2023.* + +--- + To help users understand the structure of their calculator graphs and to understand the overall behavior of their machine learning inference pipelines, we have built the [MediaPipe Visualizer](https://viz.mediapipe.dev/) diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.cc b/mediapipe/calculators/audio/basic_time_series_calculators.cc index f7b24f6f6..5006a0b54 100644 --- a/mediapipe/calculators/audio/basic_time_series_calculators.cc +++ b/mediapipe/calculators/audio/basic_time_series_calculators.cc @@ -26,10 +26,11 @@ namespace mediapipe { namespace { static bool SafeMultiply(int x, int y, int* result) { - static_assert(sizeof(int64) >= 2 * sizeof(int), + static_assert(sizeof(int64_t) >= 2 * sizeof(int), "Unable to detect overflow after multiplication"); - const int64 big = static_cast(x) * static_cast(y); - if (big > static_cast(INT_MIN) && big < static_cast(INT_MAX)) { + const int64_t big = static_cast(x) * static_cast(y); + if (big > static_cast(INT_MIN) && + big < static_cast(INT_MAX)) { if (result != nullptr) *result = static_cast(big); return true; } else { diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index bd4d8f3bf..939e721ab 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -182,12 +182,12 @@ class SpectrogramCalculator : public CalculatorBase { int frame_duration_samples_; int frame_overlap_samples_; // How many samples we've been passed, used for checking input time stamps. - int64 cumulative_input_samples_; + int64_t cumulative_input_samples_; // How many frames we've emitted, used for calculating output time stamps. - int64 cumulative_completed_frames_; + int64_t cumulative_completed_frames_; // How many frames were emitted last, used for estimating the timestamp on // Close when use_local_timestamp_ is true; - int64 last_completed_frames_; + int64_t last_completed_frames_; Timestamp initial_input_timestamp_; int num_input_channels_; // How many frequency bins we emit (=N_FFT/2 + 1). 
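The `basic_time_series_calculators.cc` hunk above swaps the legacy `int64` typedef for the standard fixed-width `int64_t` inside the overflow check. Below is a minimal standalone sketch of that pattern (not part of the diff): it reproduces the `SafeMultiply` logic outside the MediaPipe tree, with the `static_cast<int64_t>` template arguments written out and a small `main` added purely for illustration; the operands are arbitrary.

```cpp
// Standalone sketch of the overflow check shown in the hunk above, with the
// static_cast<int64_t> template arguments written out. Outside MediaPipe;
// main() and its operands are only for illustration.
#include <climits>
#include <cstdint>
#include <iostream>

static bool SafeMultiply(int x, int y, int* result) {
  // int64_t is exactly 64 bits, so it can hold the product of two ints as
  // long as int is no wider than 32 bits -- which is what this assert checks.
  static_assert(sizeof(int64_t) >= 2 * sizeof(int),
                "Unable to detect overflow after multiplication");
  const int64_t big = static_cast<int64_t>(x) * static_cast<int64_t>(y);
  if (big > static_cast<int64_t>(INT_MIN) &&
      big < static_cast<int64_t>(INT_MAX)) {
    if (result != nullptr) *result = static_cast<int>(big);
    return true;
  }
  return false;
}

int main() {
  int product = 0;
  std::cout << SafeMultiply(1 << 20, 1 << 10, &product) << "\n";  // 1: fits
  std::cout << SafeMultiply(1 << 20, 1 << 20, &product) << "\n";  // 0: overflow
  return 0;
}
```

The `static_assert` is what makes the `int64_t` product a safe place to detect `int` overflow: it fails at compile time on any platform where `int` is wider than 32 bits.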
diff --git a/mediapipe/calculators/audio/spectrogram_calculator_test.cc b/mediapipe/calculators/audio/spectrogram_calculator_test.cc index 3c2b8435d..b35f30583 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator_test.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator_test.cc @@ -92,8 +92,8 @@ class SpectrogramCalculatorTest .cos() .transpose(); } - int64 input_timestamp = round(packet_start_time_seconds * - Timestamp::kTimestampUnitsPerSecond); + int64_t input_timestamp = round(packet_start_time_seconds * + Timestamp::kTimestampUnitsPerSecond); AppendInputPacket(packet_data, input_timestamp); total_num_input_samples += packet_size_samples; } @@ -116,8 +116,8 @@ class SpectrogramCalculatorTest double packet_start_time_seconds = kInitialTimestampOffsetMicroseconds * 1e-6 + total_num_input_samples / input_sample_rate_; - int64 input_timestamp = round(packet_start_time_seconds * - Timestamp::kTimestampUnitsPerSecond); + int64_t input_timestamp = round(packet_start_time_seconds * + Timestamp::kTimestampUnitsPerSecond); std::unique_ptr impulse( new Matrix(Matrix::Zero(1, packet_sizes_samples[i]))); (*impulse)(0, impulse_offsets_samples[i]) = 1.0; @@ -157,8 +157,8 @@ class SpectrogramCalculatorTest .cos() .transpose(); } - int64 input_timestamp = round(packet_start_time_seconds * - Timestamp::kTimestampUnitsPerSecond); + int64_t input_timestamp = round(packet_start_time_seconds * + Timestamp::kTimestampUnitsPerSecond); AppendInputPacket(packet_data, input_timestamp); total_num_input_samples += packet_size_samples; } @@ -218,7 +218,7 @@ class SpectrogramCalculatorTest const double expected_timestamp_seconds = packet_timestamp_offset_seconds + cumulative_output_frames * frame_step_seconds; - const int64 expected_timestamp_ticks = + const int64_t expected_timestamp_ticks = expected_timestamp_seconds * Timestamp::kTimestampUnitsPerSecond; EXPECT_EQ(expected_timestamp_ticks, packet.Timestamp().Value()); // Accept the timestamp of the first packet as the baseline for checking diff --git a/mediapipe/calculators/audio/stabilized_log_calculator_test.cc b/mediapipe/calculators/audio/stabilized_log_calculator_test.cc index e6e0b5c6f..f04202676 100644 --- a/mediapipe/calculators/audio/stabilized_log_calculator_test.cc +++ b/mediapipe/calculators/audio/stabilized_log_calculator_test.cc @@ -54,7 +54,8 @@ TEST_F(StabilizedLogCalculatorTest, BasicOperation) { std::vector input_data_matrices; for (int input_packet = 0; input_packet < kNumPackets; ++input_packet) { - const int64 timestamp = input_packet * Timestamp::kTimestampUnitsPerSecond; + const int64_t timestamp = + input_packet * Timestamp::kTimestampUnitsPerSecond; Matrix input_data_matrix = Matrix::Random(kNumChannels, kNumSamples).array().abs(); input_data_matrices.push_back(input_data_matrix); @@ -80,7 +81,8 @@ TEST_F(StabilizedLogCalculatorTest, OutputScaleWorks) { std::vector input_data_matrices; for (int input_packet = 0; input_packet < kNumPackets; ++input_packet) { - const int64 timestamp = input_packet * Timestamp::kTimestampUnitsPerSecond; + const int64_t timestamp = + input_packet * Timestamp::kTimestampUnitsPerSecond; Matrix input_data_matrix = Matrix::Random(kNumChannels, kNumSamples).array().abs(); input_data_matrices.push_back(input_data_matrix); diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc index fbbf34226..a200b898a 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator.cc +++ 
b/mediapipe/calculators/audio/time_series_framer_calculator.cc @@ -109,7 +109,7 @@ class TimeSeriesFramerCalculator : public CalculatorBase { // Returns the timestamp of a sample on a base, which is usually the time // stamp of a packet. Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base, - int64 number_of_samples) { + int64_t number_of_samples) { return timestamp_base + round(number_of_samples / sample_rate_ * Timestamp::kTimestampUnitsPerSecond); } @@ -118,10 +118,10 @@ class TimeSeriesFramerCalculator : public CalculatorBase { // emitted. int next_frame_step_samples() const { // All numbers are in input samples. - const int64 current_output_frame_start = static_cast( + const int64_t current_output_frame_start = static_cast( round(cumulative_output_frames_ * average_frame_step_samples_)); CHECK_EQ(current_output_frame_start, cumulative_completed_samples_); - const int64 next_output_frame_start = static_cast( + const int64_t next_output_frame_start = static_cast( round((cumulative_output_frames_ + 1) * average_frame_step_samples_)); return next_output_frame_start - current_output_frame_start; } @@ -134,11 +134,11 @@ class TimeSeriesFramerCalculator : public CalculatorBase { // emulate_fractional_frame_overlap is true. double average_frame_step_samples_; int samples_still_to_drop_; - int64 cumulative_output_frames_; + int64_t cumulative_output_frames_; // "Completed" samples are samples that are no longer needed because // the framer has completely stepped past them (taking into account // any overlap). - int64 cumulative_completed_samples_; + int64_t cumulative_completed_samples_; Timestamp initial_input_timestamp_; // The current timestamp is updated along with the incoming packets. Timestamp current_timestamp_; diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc index ca88cebb5..72e9c88f7 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc @@ -49,7 +49,7 @@ class TimeSeriesFramerCalculatorTest // Returns a float value with the channel and timestamp separated by // an order of magnitude, for easy parsing by humans. - float TestValue(int64 timestamp_in_microseconds, int channel) { + float TestValue(int64_t timestamp_in_microseconds, int channel) { return timestamp_in_microseconds + channel / 10.0; } @@ -59,7 +59,7 @@ class TimeSeriesFramerCalculatorTest auto matrix = new Matrix(num_channels, num_samples); for (int c = 0; c < num_channels; ++c) { for (int i = 0; i < num_samples; ++i) { - int64 timestamp = time_series_util::SecondsToSamples( + int64_t timestamp = time_series_util::SecondsToSamples( starting_timestamp_seconds + i / input_sample_rate_, Timestamp::kTimestampUnitsPerSecond); (*matrix)(c, i) = TestValue(timestamp, c); @@ -429,7 +429,7 @@ class TimeSeriesFramerCalculatorTimestampingTest num_full_packets -= 1; } - int64 num_samples = 0; + int64_t num_samples = 0; for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) { const Packet& packet = output().packets[packet_num]; num_samples += FrameDurationSamples(); diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index 8647e3f3f..a92a2f252 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -12,25 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
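The framer and spectrogram hunks above keep their cumulative sample and frame counters in `int64_t` and derive packet timestamps by scaling with `Timestamp::kTimestampUnitsPerSecond` and rounding (see `CurrentSampleTimestamp`, `next_frame_step_samples`, and the test helpers). The sketch below shows that conversion in isolation: a plain `int64_t` in microseconds stands in for `mediapipe::Timestamp`, and the 16 kHz sample rate is only an example, not a value from the diff.

```cpp
// Standalone sketch of the sample-count -> timestamp arithmetic used by the
// audio calculators above, with the counters held in int64_t. A plain
// int64_t (microseconds) stands in for mediapipe::Timestamp.
#include <cmath>
#include <cstdint>
#include <iostream>

// MediaPipe timestamps count microseconds, so one second is 1e6 units.
constexpr double kTimestampUnitsPerSecond = 1000000.0;

// Timestamp of the sample that sits `number_of_samples` after the sample at
// `timestamp_base_us` (usually the timestamp of the packet it arrived in).
int64_t CurrentSampleTimestamp(int64_t timestamp_base_us,
                               int64_t number_of_samples,
                               double sample_rate) {
  return timestamp_base_us +
         static_cast<int64_t>(std::round(
             number_of_samples / sample_rate * kTimestampUnitsPerSecond));
}

int main() {
  const double sample_rate = 16000.0;  // example rate, not from the diff
  // 160 samples at 16 kHz is 10 ms, i.e. 10000 timestamp units.
  std::cout << CurrentSampleTimestamp(/*timestamp_base_us=*/0,
                                      /*number_of_samples=*/160, sample_rate)
            << "\n";
  return 0;
}
```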
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:private"]) -proto_library( +mediapipe_proto_library( name = "callback_packet_calculator_proto", srcs = ["callback_packet_calculator.proto"], visibility = ["//mediapipe/framework:__subpackages__"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "callback_packet_calculator_cc_proto", - srcs = ["callback_packet_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//mediapipe/framework:__subpackages__"], - deps = [":callback_packet_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index fd926a8fe..9ae884253 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -467,10 +467,6 @@ cc_library( "-x objective-c++", "-fobjc-arc", # enable reference-counting ], - linkopts = [ - "-framework CoreVideo", - "-framework MetalKit", - ], tags = ["ios"], deps = [ "inference_calculator_interface", @@ -486,7 +482,13 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", - ], + ] + select({ + "//mediapipe:apple": [ + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:MetalKit", + ], + "//conditions:default": [], + }), alwayslink = 1, ) @@ -721,13 +723,6 @@ cc_library( "//conditions:default": [], }), features = ["-layering_check"], # allow depending on tensors_to_detections_calculator_gpu_deps - linkopts = select({ - "//mediapipe:apple": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -744,6 +739,12 @@ cc_library( ] + selects.with_or({ ":compute_shader_unavailable": [], "//conditions:default": [":tensors_to_detections_calculator_gpu_deps"], + }) + select({ + "//mediapipe:apple": [ + "//third_party/apple_frameworks:CoreVideo", + "//third_party/apple_frameworks:MetalKit", + ], + "//conditions:default": [], }), alwayslink = 1, ) @@ -1333,6 +1334,7 @@ cc_library( "//mediapipe:ios": [ "//mediapipe/gpu:MPPMetalUtil", "//mediapipe/gpu:MPPMetalHelper", + "//third_party/apple_frameworks:MetalKit", ], "//conditions:default": [ "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 95e38f89c..bb4c6de79 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -92,13 +92,14 @@ class OpenCvProcessor : public ImageToTensorConverter { const int dst_data_type = output_channels == 1 ? 
mat_gray_type_ : mat_type_; switch (tensor_type_) { case Tensor::ElementType::kInt8: - RET_CHECK_GE(output_shape.num_elements(), - tensor_buffer_offset / sizeof(int8) + num_elements_per_img) + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(int8_t) + num_elements_per_img) << "The buffer offset + the input image size is larger than the " "allocated tensor buffer."; - dst = cv::Mat( - output_height, output_width, dst_data_type, - buffer_view.buffer() + tensor_buffer_offset / sizeof(int8)); + dst = cv::Mat(output_height, output_width, dst_data_type, + buffer_view.buffer() + + tensor_buffer_offset / sizeof(int8_t)); break; case Tensor::ElementType::kFloat32: RET_CHECK_GE( @@ -113,12 +114,12 @@ class OpenCvProcessor : public ImageToTensorConverter { case Tensor::ElementType::kUInt8: RET_CHECK_GE( output_shape.num_elements(), - tensor_buffer_offset / sizeof(uint8) + num_elements_per_img) + tensor_buffer_offset / sizeof(uint8_t) + num_elements_per_img) << "The buffer offset + the input image size is larger than the " "allocated tensor buffer."; - dst = cv::Mat( - output_height, output_width, dst_data_type, - buffer_view.buffer() + tensor_buffer_offset / sizeof(uint8)); + dst = cv::Mat(output_height, output_width, dst_data_type, + buffer_view.buffer() + + tensor_buffer_offset / sizeof(uint8_t)); break; default: return InvalidArgumentError( diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc index bdea0795e..2cfbd3d1e 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc @@ -41,7 +41,7 @@ constexpr char kTransposeOptionsString[] = using RandomEngine = std::mt19937_64; using testing::Eq; -const uint32 kSeed = 1234; +const uint32_t kSeed = 1234; const int kNumSizes = 8; const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, {5, 3}, {7, 13}, {16, 32}, {101, 2}}; @@ -49,7 +49,7 @@ const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, class TensorConverterCalculatorTest : public ::testing::Test { protected: // Adds a packet with a matrix filled with random values in [0,1]. 
- void AddRandomMatrix(int num_rows, int num_columns, uint32 seed, + void AddRandomMatrix(int num_rows, int num_columns, uint32_t seed, bool row_major_matrix = false) { RandomEngine random(kSeed); std::uniform_real_distribution<> uniform_dist(0, 1.0); @@ -229,7 +229,7 @@ TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) { MP_ASSERT_OK(graph.StartRun({})); auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); cv::Mat mat = mediapipe::formats::MatView(input_image.get()); - mat.at(0, 0) = 200; + mat.at(0, 0) = 200; MP_ASSERT_OK(graph.AddPacketToInputStream( "input_image", Adopt(input_image.release()).At(Timestamp(0)))); @@ -286,7 +286,7 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) { MP_ASSERT_OK(graph.StartRun({})); auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 1); cv::Mat mat = mediapipe::formats::MatView(input_image.get()); - mat.at(0, 0) = 200; + mat.at(0, 0) = 200; MP_ASSERT_OK(graph.AddPacketToInputStream( "input_image", Adopt(input_image.release()).At(Timestamp(0)))); diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc index 5bfc00ed7..7041c02e4 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -84,7 +84,7 @@ class TensorsToClassificationCalculator : public Node { private: int top_k_ = 0; bool sort_by_descending_score_ = false; - proto_ns::Map local_label_map_; + proto_ns::Map local_label_map_; bool label_map_loaded_ = false; bool is_binary_classification_ = false; float min_score_threshold_ = std::numeric_limits::lowest(); @@ -98,7 +98,8 @@ class TensorsToClassificationCalculator : public Node { // These are used to filter out the output classification results. ClassIndexSet class_index_set_; bool IsClassIndexAllowed(int class_index); - const proto_ns::Map& GetLabelMap(CalculatorContext* cc); + const proto_ns::Map& GetLabelMap( + CalculatorContext* cc); }; MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator); @@ -252,7 +253,7 @@ bool TensorsToClassificationCalculator::IsClassIndexAllowed(int class_index) { } } -const proto_ns::Map& +const proto_ns::Map& TensorsToClassificationCalculator::GetLabelMap(CalculatorContext* cc) { return !local_label_map_.empty() ? local_label_map_ diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index e7cc9cc94..0b30689eb 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -399,7 +399,7 @@ cc_library( # On android, this calculator is configured to run with lite protos. Therefore, # compile your binary with the flag TENSORFLOW_PROTOS=lite. cc_library( - name = "tensorflow_inference_calculator", + name = "tensorflow_inference_calculator_no_envelope_loader", srcs = ["tensorflow_inference_calculator.cc"], deps = [ ":tensorflow_inference_calculator_cc_proto", @@ -432,6 +432,19 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensorflow_inference_calculator", + deps = [ + ":tensorflow_inference_calculator_no_envelope_loader", + ] + select({ + # Since "select" has "exactly one match" rule, we will need default condition to avoid + # "no matching conditions" error. Since all necessary dependencies are specified in + # "tensorflow_inference_calculator_no_envelope_loader" dependency, it is empty here. 
+ "//conditions:default": [], + }), + alwayslink = 1, +) + cc_library( name = "tensorflow_session", hdrs = [ diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 435ea9fc1..333de2069 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -193,13 +193,6 @@ cc_library( ":edge_tpu_pci": ["MEDIAPIPE_EDGE_TPU=pci"], ":edge_tpu_all": ["MEDIAPIPE_EDGE_TPU=all"], }), - linkopts = select({ - "//mediapipe:ios": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tflite_inference_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -222,6 +215,8 @@ cc_library( "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", + "//third_party/apple_frameworks:MetalKit", + "//third_party/apple_frameworks:CoreVideo", ], "//conditions:default": [ "//mediapipe/util/tflite:tflite_gpu_runner", @@ -271,13 +266,6 @@ cc_library( ], "//conditions:default": [], }), - linkopts = select({ - "//mediapipe:ios": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tflite_converter_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -296,6 +284,8 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + "//third_party/apple_frameworks:MetalKit", + "//third_party/apple_frameworks:CoreVideo", ], "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", @@ -393,13 +383,6 @@ cc_library( ], "//conditions:default": [], }), - linkopts = select({ - "//mediapipe:ios": [ - "-framework CoreVideo", - "-framework MetalKit", - ], - "//conditions:default": [], - }), deps = [ ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -420,6 +403,8 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/objc:mediapipe_framework_ios", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + "//third_party/apple_frameworks:MetalKit", + "//third_party/apple_frameworks:CoreVideo", ], "//conditions:default": [ "//mediapipe/gpu:gl_calculator_helper", diff --git a/mediapipe/calculators/util/clock_latency_calculator.cc b/mediapipe/calculators/util/clock_latency_calculator.cc index 5c5711731..beaa41e66 100644 --- a/mediapipe/calculators/util/clock_latency_calculator.cc +++ b/mediapipe/calculators/util/clock_latency_calculator.cc @@ -66,17 +66,17 @@ class ClockLatencyCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override; private: - int64 num_packet_streams_ = -1; + int64_t num_packet_streams_ = -1; }; REGISTER_CALCULATOR(ClockLatencyCalculator); absl::Status ClockLatencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_GT(cc->Inputs().NumEntries(), 1); - int64 num_packet_streams = cc->Inputs().NumEntries() - 1; + int64_t num_packet_streams = cc->Inputs().NumEntries() - 1; RET_CHECK_EQ(cc->Outputs().NumEntries(), num_packet_streams); - for (int64 i = 0; i < num_packet_streams; ++i) { + for (int64_t i = 0; i < num_packet_streams; ++i) { cc->Inputs().Index(i).Set(); cc->Outputs().Index(i).Set(); } @@ -99,7 +99,7 @@ absl::Status ClockLatencyCalculator::Process(CalculatorContext* cc) { cc->Inputs().Tag(kReferenceTag).Get(); // Push Duration packets for every input stream we have. 
- for (int64 i = 0; i < num_packet_streams_; ++i) { + for (int64_t i = 0; i < num_packet_streams_; ++i) { if (!cc->Inputs().Index(i).IsEmpty()) { const absl::Time& input_stream_time = cc->Inputs().Index(i).Get(); diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc index 62eb1d8ae..71cba9430 100644 --- a/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc +++ b/mediapipe/calculators/util/collection_has_min_size_calculator_test.cc @@ -33,7 +33,7 @@ typedef CollectionHasMinSizeCalculator> TestIntCollectionHasMinSizeCalculator; REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator); -void AddInputVector(const std::vector& input, int64 timestamp, +void AddInputVector(const std::vector& input, int64_t timestamp, CalculatorRunner* runner) { runner->MutableInputs() ->Tag(kIterableTag) diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 0b8dde20d..0c1d6892e 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -57,9 +57,10 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase { private: // Local label map built from the calculator options' `label_map_path` or // `label` field. - proto_ns::Map local_label_map_; + proto_ns::Map local_label_map_; bool keep_label_id_; - const proto_ns::Map& GetLabelMap(CalculatorContext* cc); + const proto_ns::Map& GetLabelMap( + CalculatorContext* cc); }; REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); @@ -115,7 +116,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { output_detections.push_back(input_detection); Detection& output_detection = output_detections.back(); bool has_text_label = false; - for (const int32 label_id : output_detection.label_id()) { + for (const int32_t label_id : output_detection.label_id()) { if (GetLabelMap(cc).contains(label_id)) { auto item = GetLabelMap(cc).at(label_id); output_detection.add_label(item.name()); @@ -136,7 +137,7 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -const proto_ns::Map& +const proto_ns::Map& DetectionLabelIdToTextCalculator::GetLabelMap(CalculatorContext* cc) { return !local_label_map_.empty() ? 
local_label_map_ diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc index c4f084363..75dd93cc3 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator_test.cc @@ -40,7 +40,7 @@ LocationData CreateRelativeLocationData(double xmin, double ymin, double width, } Detection CreateDetection(const std::vector& labels, - const std::vector& label_ids, + const std::vector& label_ids, const std::vector& scores, const LocationData& location_data, const std::string& feature_tag) { diff --git a/mediapipe/calculators/util/detection_transformation_calculator_test.cc b/mediapipe/calculators/util/detection_transformation_calculator_test.cc index e280b5153..30d1bc64b 100644 --- a/mediapipe/calculators/util/detection_transformation_calculator_test.cc +++ b/mediapipe/calculators/util/detection_transformation_calculator_test.cc @@ -39,8 +39,8 @@ constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; constexpr char kRelativeDetectionListTag[] = "RELATIVE_DETECTION_LIST"; constexpr char kRelativeDetectionsTag[] = "RELATIVE_DETECTIONS"; -Detection DetectionWithBoundingBox(int32 xmin, int32 ymin, int32 width, - int32 height) { +Detection DetectionWithBoundingBox(int32_t xmin, int32_t ymin, int32_t width, + int32_t height) { Detection detection; LocationData* location_data = detection.mutable_location_data(); location_data->set_format(LocationData::BOUNDING_BOX); diff --git a/mediapipe/calculators/util/detection_unique_id_calculator.cc b/mediapipe/calculators/util/detection_unique_id_calculator.cc index ac8889ffb..d5b1cffa3 100644 --- a/mediapipe/calculators/util/detection_unique_id_calculator.cc +++ b/mediapipe/calculators/util/detection_unique_id_calculator.cc @@ -26,7 +26,7 @@ constexpr char kDetectionListTag[] = "DETECTION_LIST"; // Each detection processed by DetectionUniqueIDCalculator will be assigned an // unique id that starts from 1. If a detection already has an ID other than 0, // the ID will be overwritten. 
-static int64 detection_id = 0; +static int64_t detection_id = 0; inline int GetNextDetectionId() { return ++detection_id; } diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index 63de60a60..95e18e90c 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -56,8 +56,8 @@ MATCHER_P4(NormRectEq, x_center, y_center, width, height, "") { testing::Value(arg.height(), testing::FloatEq(height)); } -Detection DetectionWithLocationData(int32 xmin, int32 ymin, int32 width, - int32 height) { +Detection DetectionWithLocationData(int32_t xmin, int32_t ymin, int32_t width, + int32_t height) { Detection detection; LocationData* location_data = detection.mutable_location_data(); location_data->set_format(LocationData::BOUNDING_BOX); diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc index 0d0da2350..6da8c449a 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator_test.cc @@ -43,8 +43,8 @@ void VerifyRenderAnnotationColorThickness( EXPECT_EQ(annotation.thickness(), options.thickness()); } -LocationData CreateLocationData(int32 xmin, int32 ymin, int32 width, - int32 height) { +LocationData CreateLocationData(int32_t xmin, int32_t ymin, int32_t width, + int32_t height) { LocationData location_data; location_data.set_format(LocationData::BOUNDING_BOX); location_data.mutable_bounding_box()->set_xmin(xmin); @@ -66,7 +66,7 @@ LocationData CreateRelativeLocationData(double xmin, double ymin, double width, } Detection CreateDetection(const std::vector& labels, - const std::vector& label_ids, + const std::vector& label_ids, const std::vector& scores, const LocationData& location_data, const std::string& feature_tag) { diff --git a/mediapipe/calculators/util/filter_collection_calculator.cc b/mediapipe/calculators/util/filter_collection_calculator.cc index ab361f450..2cf41ead8 100644 --- a/mediapipe/calculators/util/filter_collection_calculator.cc +++ b/mediapipe/calculators/util/filter_collection_calculator.cc @@ -24,7 +24,7 @@ namespace mediapipe { -typedef FilterCollectionCalculator> +typedef FilterCollectionCalculator> FilterUInt64CollectionCalculator; REGISTER_CALCULATOR(FilterUInt64CollectionCalculator); diff --git a/mediapipe/calculators/util/from_image_calculator.cc b/mediapipe/calculators/util/from_image_calculator.cc index 0ddb342eb..706f8727b 100644 --- a/mediapipe/calculators/util/from_image_calculator.cc +++ b/mediapipe/calculators/util/from_image_calculator.cc @@ -163,8 +163,8 @@ absl::Status FromImageCalculator::Process(CalculatorContext* cc) { std::unique_ptr output = std::make_unique( input.image_format(), input.width(), input.height(), input.step(), - const_cast(input.GetImageFrameSharedPtr()->PixelData()), - [packet_copy_ptr](uint8*) { delete packet_copy_ptr; }); + const_cast(input.GetImageFrameSharedPtr()->PixelData()), + [packet_copy_ptr](uint8_t*) { delete packet_copy_ptr; }); cc->Outputs() .Tag(kImageFrameTag) .Add(output.release(), cc->InputTimestamp()); diff --git a/mediapipe/calculators/util/packet_frequency_calculator.cc b/mediapipe/calculators/util/packet_frequency_calculator.cc index 19ffae70e..f9c28f5ff 100644 --- a/mediapipe/calculators/util/packet_frequency_calculator.cc +++ 
b/mediapipe/calculators/util/packet_frequency_calculator.cc @@ -84,23 +84,24 @@ class PacketFrequencyCalculator : public CalculatorBase { const Timestamp& input_timestamp); // Adds the input timestamp in the particular stream's timestamp buffer. - absl::Status AddPacketTimestampForStream(int stream_id, int64 timestamp); + absl::Status AddPacketTimestampForStream(int stream_id, int64_t timestamp); // For the specified input stream, clears timestamps from buffer that are // older than the configured time_window_sec. - absl::Status ClearOldpacketTimestamps(int stream_id, int64 current_timestamp); + absl::Status ClearOldpacketTimestamps(int stream_id, + int64_t current_timestamp); // Options for the calculator. PacketFrequencyCalculatorOptions options_; // Map where key is the input stream ID and value is the timestamp of the // first packet received on that stream. - std::map first_timestamp_for_stream_id_usec_; + std::map first_timestamp_for_stream_id_usec_; // Map where key is the input stream ID and value is a vector that stores // timestamps of recently received packets on the stream. Timestamps older // than the time_window_sec are continuously deleted for all the streams. - std::map> previous_timestamps_for_stream_id_; + std::map> previous_timestamps_for_stream_id_; }; REGISTER_CALCULATOR(PacketFrequencyCalculator); @@ -166,7 +167,7 @@ absl::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { } absl::Status PacketFrequencyCalculator::AddPacketTimestampForStream( - int stream_id, int64 timestamp_usec) { + int stream_id, int64_t timestamp_usec) { if (previous_timestamps_for_stream_id_.find(stream_id) == previous_timestamps_for_stream_id_.end()) { return absl::InvalidArgumentError("Input stream id is invalid"); @@ -178,19 +179,20 @@ absl::Status PacketFrequencyCalculator::AddPacketTimestampForStream( } absl::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( - int stream_id, int64 current_timestamp_usec) { + int stream_id, int64_t current_timestamp_usec) { if (previous_timestamps_for_stream_id_.find(stream_id) == previous_timestamps_for_stream_id_.end()) { return absl::InvalidArgumentError("Input stream id is invalid"); } auto& timestamps_buffer = previous_timestamps_for_stream_id_[stream_id]; - int64 time_window_usec = options_.time_window_sec() * kSecondsToMicroseconds; + int64_t time_window_usec = + options_.time_window_sec() * kSecondsToMicroseconds; timestamps_buffer.erase( std::remove_if(timestamps_buffer.begin(), timestamps_buffer.end(), [&time_window_usec, - ¤t_timestamp_usec](const int64 timestamp_usec) { + ¤t_timestamp_usec](const int64_t timestamp_usec) { return current_timestamp_usec - timestamp_usec > time_window_usec; }), diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc index 0e5b2e885..6509f016f 100644 --- a/mediapipe/calculators/util/packet_latency_calculator.cc +++ b/mediapipe/calculators/util/packet_latency_calculator.cc @@ -118,24 +118,24 @@ class PacketLatencyCalculator : public CalculatorBase { std::shared_ptr<::mediapipe::Clock> clock_; // Clock time when the first reference packet was received. - int64 first_process_time_usec_ = -1; + int64_t first_process_time_usec_ = -1; // Timestamp of the first reference packet received. - int64 first_reference_timestamp_usec_ = -1; + int64_t first_reference_timestamp_usec_ = -1; // Number of packet streams. - int64 num_packet_streams_ = -1; + int64_t num_packet_streams_ = -1; // Latency output for each packet stream. 
std::vector packet_latencies_; // Running sum and count of latencies for each packet stream. This is required // to compute the average latency. - std::vector sum_latencies_usec_; - std::vector num_latencies_; + std::vector sum_latencies_usec_; + std::vector num_latencies_; // Clock time when last reset was done for histogram and running average. - int64 last_reset_time_usec_ = -1; + int64_t last_reset_time_usec_ = -1; }; REGISTER_CALCULATOR(PacketLatencyCalculator); @@ -143,9 +143,9 @@ absl::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_GT(cc->Inputs().NumEntries(), 1); // Input and output streams. - int64 num_packet_streams = cc->Inputs().NumEntries() - 1; + int64_t num_packet_streams = cc->Inputs().NumEntries() - 1; RET_CHECK_EQ(cc->Outputs().NumEntries(), num_packet_streams); - for (int64 i = 0; i < num_packet_streams; ++i) { + for (int64_t i = 0; i < num_packet_streams; ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).Set(); } @@ -165,8 +165,8 @@ absl::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) { void PacketLatencyCalculator::ResetStatistics() { // Initialize histogram with zero counts and set running average to zero. - for (int64 i = 0; i < num_packet_streams_; ++i) { - for (int64 interval_index = 0; interval_index < options_.num_intervals(); + for (int64_t i = 0; i < num_packet_streams_; ++i) { + for (int64_t interval_index = 0; interval_index < options_.num_intervals(); ++interval_index) { packet_latencies_[i].set_counts(interval_index, 0); } @@ -196,7 +196,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { packet_latencies_.resize(num_packet_streams_); sum_latencies_usec_.resize(num_packet_streams_); num_latencies_.resize(num_packet_streams_); - for (int64 i = 0; i < num_packet_streams_; ++i) { + for (int64_t i = 0; i < num_packet_streams_; ++i) { // Initialize latency histograms with zero counts. packet_latencies_[i].set_num_intervals(options_.num_intervals()); packet_latencies_[i].set_interval_size_usec(options_.interval_size_usec()); @@ -208,7 +208,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { if (labels_provided) { packet_latencies_[i].set_label(options_.packet_labels(i)); } else { - int64 input_stream_index = cc->Inputs().TagMap()->GetId("", i).value(); + int64_t input_stream_index = cc->Inputs().TagMap()->GetId("", i).value(); packet_latencies_[i].set_label( cc->Inputs().TagMap()->Names()[input_stream_index]); } @@ -242,7 +242,7 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { } if (options_.reset_duration_usec() > 0) { - const int64 time_now_usec = absl::ToUnixMicros(clock_->TimeNow()); + const int64_t time_now_usec = absl::ToUnixMicros(clock_->TimeNow()); if (time_now_usec - last_reset_time_usec_ >= options_.reset_duration_usec()) { ResetStatistics(); @@ -251,16 +251,16 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { } // Update latency info if there is any incoming packet. - for (int64 i = 0; i < num_packet_streams_; ++i) { + for (int64_t i = 0; i < num_packet_streams_; ++i) { if (!cc->Inputs().Index(i).IsEmpty()) { const auto& packet_timestamp_usec = cc->InputTimestamp().Value(); // Update latency statistics for this stream. 
- int64 current_clock_time_usec = absl::ToUnixMicros(clock_->TimeNow()); - int64 current_calibrated_timestamp_usec = + int64_t current_clock_time_usec = absl::ToUnixMicros(clock_->TimeNow()); + int64_t current_calibrated_timestamp_usec = (current_clock_time_usec - first_process_time_usec_) + first_reference_timestamp_usec_; - int64 packet_latency_usec = + int64_t packet_latency_usec = current_calibrated_timestamp_usec - packet_timestamp_usec; // Invalid timestamps in input signals could result in negative latencies. @@ -270,7 +270,7 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { // Update the latency, running average and histogram for this stream. packet_latencies_[i].set_current_latency_usec(packet_latency_usec); - int64 interval_index = + int64_t interval_index = packet_latency_usec / packet_latencies_[i].interval_size_usec(); if (interval_index >= packet_latencies_[i].num_intervals()) { interval_index = packet_latencies_[i].num_intervals() - 1; diff --git a/mediapipe/calculators/util/packet_latency_calculator_test.cc b/mediapipe/calculators/util/packet_latency_calculator_test.cc index 6f03f2e75..d323a14f9 100644 --- a/mediapipe/calculators/util/packet_latency_calculator_test.cc +++ b/mediapipe/calculators/util/packet_latency_calculator_test.cc @@ -169,10 +169,10 @@ class PacketLatencyCalculatorTest : public ::testing::Test { } PacketLatency CreatePacketLatency(const double latency_usec, - const int64 num_intervals, - const int64 interval_size_usec, + const int64_t num_intervals, + const int64_t interval_size_usec, const std::vector& counts, - const int64 avg_latency_usec, + const int64_t avg_latency_usec, const std::string& label) { PacketLatency latency_info; latency_info.set_current_latency_usec(latency_usec); diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index e4aa1bff8..7245b13c2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -13,7 +13,7 @@ # limitations under the License. 
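The packet_latency_calculator.cc hunks above keep all of the clock and timestamp bookkeeping in int64_t microseconds. A small sketch of the calibration step they implement follows; the constants are made-up inputs chosen only to make the arithmetic concrete, not values taken from the calculator or its tests.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical inputs, all in microseconds.
constexpr int64_t kFirstProcessTimeUsec = 5'000'000;         // wall clock at first reference packet
constexpr int64_t kFirstReferenceTimestampUsec = 1'000'000;  // stream timestamp of that packet
constexpr int64_t kIntervalSizeUsec = 10'000;
constexpr int64_t kNumIntervals = 100;

int main() {
  const int64_t current_clock_usec = 5'750'000;      // wall clock now
  const int64_t packet_timestamp_usec = 1'700'000;   // timestamp of the measured packet

  // Map the current wall-clock time back onto the stream's timestamp domain,
  // mirroring the arithmetic in PacketLatencyCalculator::Process() above.
  const int64_t calibrated_usec =
      (current_clock_usec - kFirstProcessTimeUsec) + kFirstReferenceTimestampUsec;
  const int64_t latency_usec = calibrated_usec - packet_timestamp_usec;

  // Clamp into the last histogram bucket, as the calculator does with
  // interval_index.
  const int64_t interval_index =
      std::min(latency_usec / kIntervalSizeUsec, kNumIntervals - 1);

  std::cout << "latency_usec=" << latency_usec
            << " bucket=" << interval_index << "\n";  // latency_usec=50000 bucket=5
}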
# -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load( "//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_binary_graph", @@ -23,28 +23,35 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -proto_library( +mediapipe_proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "motion_analysis_calculator_proto", srcs = ["motion_analysis_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:motion_analysis_proto", ], ) -proto_library( +mediapipe_proto_library( name = "flow_packager_calculator_proto", srcs = ["flow_packager_calculator.proto"], deps = [ @@ -54,114 +61,45 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", ], ) -proto_library( +mediapipe_proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", ], ) -proto_library( +mediapipe_proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_detector_proto", ], ) -proto_library( +mediapipe_proto_library( name = "video_pre_stream_calculator_proto", srcs = ["video_pre_stream_calculator.proto"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "motion_analysis_calculator_cc_proto", - srcs = ["motion_analysis_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:motion_analysis_cc_proto", - ], - deps = [":motion_analysis_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "flow_packager_calculator_cc_proto", - srcs = ["flow_packager_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", - ], - deps = [":flow_packager_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "box_tracker_calculator_cc_proto", - srcs = ["box_tracker_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:box_tracker_cc_proto", - ], - deps = [":box_tracker_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "tracked_detection_manager_calculator_cc_proto", - srcs = ["tracked_detection_manager_calculator.proto"], - 
cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", - ], - deps = [":tracked_detection_manager_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "box_detector_calculator_cc_proto", - srcs = ["box_detector_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util/tracking:box_detector_cc_proto", - ], - deps = [":box_detector_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "video_pre_stream_calculator_cc_proto", - srcs = ["video_pre_stream_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - deps = [":video_pre_stream_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "flow_to_image_calculator_cc_proto", - srcs = ["flow_to_image_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - deps = [":flow_to_image_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "opencv_video_encoder_calculator_cc_proto", - srcs = ["opencv_video_encoder_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - deps = [":opencv_video_encoder_calculator_proto"], -) - cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 340205caa..fe994e2e0 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -1,4 +1,4 @@ -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") # Copyright 2019 The MediaPipe Authors. # @@ -22,7 +22,7 @@ package(default_visibility = [ "//photos/editing/mobile/mediapipe/proto:__subpackages__", ]) -proto_library( +mediapipe_proto_library( name = "autoflip_messages_proto", srcs = ["autoflip_messages.proto"], deps = [ @@ -30,29 +30,6 @@ proto_library( ], ) -java_lite_proto_library( - name = "autoflip_messages_java_proto_lite", - visibility = [ - "//java/com/google/android/apps/photos:__subpackages__", - "//javatests/com/google/android/apps/photos:__subpackages__", - ], - deps = [ - ":autoflip_messages_proto", - ], -) - -mediapipe_cc_proto_library( - name = "autoflip_messages_cc_proto", - srcs = ["autoflip_messages.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = [ - "//mediapipe/examples:__subpackages__", - "//photos/editing/mobile/mediapipe/calculators:__pkg__", - "//photos/editing/mobile/mediapipe/calculators:__subpackages__", - ], - deps = [":autoflip_messages_proto"], -) - cc_binary( name = "run_autoflip", data = [ diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD index 18f56cc4f..a3b2ace2a 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/BUILD +++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD @@ -1,4 +1,4 @@ -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") # Copyright 2019 The MediaPipe Authors. 
# @@ -40,22 +40,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "border_detection_calculator_proto", srcs = ["border_detection_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "border_detection_calculator_cc_proto", - srcs = ["border_detection_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":border_detection_calculator_proto"], -) - cc_library( name = "content_zooming_calculator_state", hdrs = ["content_zooming_calculator_state.h"], @@ -85,27 +79,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "content_zooming_calculator_proto", srcs = ["content_zooming_calculator.proto"], - deps = [ - "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -mediapipe_cc_proto_library( - name = "content_zooming_calculator_cc_proto", - srcs = ["content_zooming_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], visibility = [ "//mediapipe/examples:__subpackages__", ], deps = [ - ":content_zooming_calculator_proto", + "//mediapipe/examples/desktop/autoflip/quality:kinematic_path_solver_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", ], ) @@ -177,23 +160,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "video_filtering_calculator_proto", srcs = ["video_filtering_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "video_filtering_calculator_cc_proto", - srcs = ["video_filtering_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":video_filtering_calculator_proto"], -) - cc_test( name = "video_filtering_calculator_test", srcs = ["video_filtering_calculator_test.cc"], @@ -209,27 +185,17 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "scene_cropping_calculator_proto", srcs = ["scene_cropping_calculator.proto"], visibility = ["//visibility:public"], deps = [ "//mediapipe/examples/desktop/autoflip/quality:cropping_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "scene_cropping_calculator_cc_proto", - srcs = ["scene_cropping_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip/quality:cropping_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":scene_cropping_calculator_proto"], -) - cc_library( name = "scene_cropping_calculator", srcs = ["scene_cropping_calculator.cc"], @@ -296,26 +262,17 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "signal_fusing_calculator_proto", srcs = ["signal_fusing_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ "//mediapipe/examples/desktop/autoflip:autoflip_messages_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) 
-mediapipe_cc_proto_library( - name = "signal_fusing_calculator_cc_proto", - srcs = ["signal_fusing_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":signal_fusing_calculator_proto"], -) - cc_test( name = "signal_fusing_calculator_test", srcs = ["signal_fusing_calculator_test.cc"], @@ -353,18 +310,14 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "shot_boundary_calculator_proto", srcs = ["shot_boundary_calculator.proto"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "shot_boundary_calculator_cc_proto", - srcs = ["shot_boundary_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":shot_boundary_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_test( @@ -413,26 +366,17 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "face_to_region_calculator_proto", srcs = ["face_to_region_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ "//mediapipe/examples/desktop/autoflip/quality:visual_scorer_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "face_to_region_calculator_cc_proto", - srcs = ["face_to_region_calculator.proto"], - cc_deps = [ - "//mediapipe/examples/desktop/autoflip/quality:visual_scorer_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":face_to_region_calculator_proto"], -) - cc_test( name = "face_to_region_calculator_test", srcs = ["face_to_region_calculator_test.cc"], @@ -454,22 +398,16 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "localization_to_region_calculator_proto", srcs = ["localization_to_region_calculator.proto"], + visibility = ["//mediapipe/examples:__subpackages__"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "localization_to_region_calculator_cc_proto", - srcs = ["localization_to_region_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":localization_to_region_calculator_proto"], -) - cc_library( name = "localization_to_region_calculator", srcs = ["localization_to_region_calculator.cc"], diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc index caaa368a7..238bcf8be 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc @@ -214,7 +214,7 @@ double BorderDetectionCalculator::ColorCount(const Color& mask_color, const cv::Mat& image) const { int background_count = 0; for (int i = 0; i < image.rows; i++) { - const uint8* row_ptr = image.ptr(i); + const uint8_t* row_ptr = image.ptr(i); for (int j = 0; j < image.cols * 3; j += 3) { if (std::abs(mask_color.r() - static_cast(row_ptr[j + 2])) <= options_.color_tolerance() && diff --git 
a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index 823080786..5241f56e4 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -142,7 +142,7 @@ class ContentZoomingCalculator : public CalculatorBase { // Stores the first crop rectangle. mediapipe::NormalizedRect first_rect_; // Stores the time of the last "only_required" input. - int64 last_only_required_detection_; + int64_t last_only_required_detection_; // Rect values of last message with detection(s). int last_measured_height_; int last_measured_x_offset_; @@ -500,7 +500,7 @@ bool ContentZoomingCalculator::IsAnimatingToFirstRect( return false; } - const int64 delta_us = (timestamp - first_rect_timestamp_).Value(); + const int64_t delta_us = (timestamp - first_rect_timestamp_).Value(); return (0 <= delta_us && delta_us <= options_.us_to_first_rect()); } @@ -522,8 +522,8 @@ absl::StatusOr ContentZoomingCalculator::GetAnimationRect( RET_CHECK(IsAnimatingToFirstRect(timestamp)) << "Must only be called if animating to first rect."; - const int64 delta_us = (timestamp - first_rect_timestamp_).Value(); - const int64 delay = options_.us_to_first_rect_delay(); + const int64_t delta_us = (timestamp - first_rect_timestamp_).Value(); + const int64_t delay = options_.us_to_first_rect_delay(); const double interpolation = easeInOutQuad(std::max( 0.0, (delta_us - delay) / static_cast(options_.us_to_first_rect() - delay))); diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc index 48e4a28a8..0e817b260 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc @@ -226,7 +226,7 @@ struct AddDetectionFlags { std::optional max_zoom_factor_percent; }; -void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, +void AddDetectionFrameSize(const cv::Rect_& position, const int64_t time, const int width, const int height, CalculatorRunner* runner, const AddDetectionFlags& flags = {}) { @@ -275,7 +275,7 @@ void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, } } -void AddDetection(const cv::Rect_& position, const int64 time, +void AddDetection(const cv::Rect_& position, const int64_t time, CalculatorRunner* runner) { AddDetectionFrameSize(position, time, 1000, 1000, runner); } diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc index 7e286b743..f4cc98674 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc @@ -200,7 +200,7 @@ absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string, } void ConstructExternalRenderMessage( const cv::Rect& crop_from_location, const cv::Rect& render_to_location, - const cv::Scalar& padding_color, const uint64 timestamp_us, + const cv::Scalar& padding_color, const uint64_t timestamp_us, ExternalRenderFrame* external_render_message, int frame_width, int frame_height) { auto crop_from_message = @@ -717,7 +717,7 @@ absl::Status 
SceneCroppingCalculator::FormatAndOutputCroppedFrames( for (int i = 0; i < num_frames; ++i) { // Set default padding color to white. cv::Scalar padding_color_to_add = cv::Scalar(255, 255, 255); - const int64 time_ms = scene_frame_timestamps_[i]; + const int64_t time_ms = scene_frame_timestamps_[i]; if (*apply_padding) { if (has_solid_background_) { double lab[3]; @@ -747,7 +747,7 @@ absl::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( // Resizes cropped frames, pads frames, and output frames. for (int i = 0; i < num_frames; ++i) { - const int64 time_ms = scene_frame_timestamps_[i]; + const int64_t time_ms = scene_frame_timestamps_[i]; const Timestamp timestamp(time_ms); auto scaled_frame = absl::make_unique( frame_format_, scaled_width, scaled_height); diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc index c3285ea58..74535022d 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc @@ -175,7 +175,7 @@ constexpr int kMinNumDetections = 0; constexpr int kMaxNumDetections = 10; constexpr int kDownSampleRate = 4; -constexpr int64 kTimestampDiff = 20000; +constexpr int64_t kTimestampDiff = 20000; // Returns a singleton random engine for generating random values. The seed is // fixed for reproducibility. @@ -254,7 +254,7 @@ std::unique_ptr MakeImageFrameFromColor(const cv::Scalar& color, // Randomly generates a number of detections in the range of kMinNumDetections // and kMaxNumDetections. Optionally add a key image frame of random solid color // and given size. -void AddKeyFrameFeatures(const int64 time_ms, const int key_frame_width, +void AddKeyFrameFeatures(const int64_t time_ms, const int key_frame_width, const int key_frame_height, bool randomize, CalculatorRunner::StreamContentsSet* inputs) { Timestamp timestamp(time_ms); @@ -286,7 +286,7 @@ void AddScene(const int start_frame_index, const int num_scene_frames, const int key_frame_width, const int key_frame_height, const int DownSampleRate, CalculatorRunner::StreamContentsSet* inputs) { - int64 time_ms = start_frame_index * kTimestampDiff; + int64_t time_ms = start_frame_index * kTimestampDiff; for (int i = 0; i < num_scene_frames; ++i) { Timestamp timestamp(time_ms); if (inputs->HasTag(kVideoFramesTag)) { @@ -657,7 +657,7 @@ TEST(SceneCroppingCalculatorTest, PadsWithSolidColorFromStaticFeatures) { // Add inputs. auto* inputs = runner->MutableInputs(); - int64 time_ms = 0; + int64_t time_ms = 0; int num_static_features = 0; for (int i = 0; i < kSceneSize; ++i) { Timestamp timestamp(time_ms); diff --git a/mediapipe/examples/desktop/autoflip/quality/BUILD b/mediapipe/examples/desktop/autoflip/quality/BUILD index 0b5970ee9..20e286107 100644 --- a/mediapipe/examples/desktop/autoflip/quality/BUILD +++ b/mediapipe/examples/desktop/autoflip/quality/BUILD @@ -1,4 +1,4 @@ -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") # Copyright 2019 The MediaPipe Authors. 
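The content_zooming_calculator.cc hunks further above derive the animation progress toward the first crop rectangle from int64_t microsecond deltas and pass it through an ease-in-out curve. The sketch below uses a common quadratic ease-in-out written from scratch, which may differ in detail from the calculator's easeInOutQuad helper, and the delay, duration, and delta values are invented; it is only meant to show the shape of the interpolation.

#include <algorithm>
#include <cstdint>
#include <iostream>

// A common quadratic ease-in-out: slow start, slow finish.
double EaseInOutQuad(double t) {
  return t < 0.5 ? 2.0 * t * t : 1.0 - 2.0 * (1.0 - t) * (1.0 - t);
}

int main() {
  // Invented stand-ins for us_to_first_rect(), us_to_first_rect_delay(), and
  // the (timestamp - first_rect_timestamp_) delta from the hunk above.
  const int64_t us_to_first_rect = 1'000'000;
  const int64_t delay_us = 200'000;
  const int64_t delta_us = 800'000;

  const double progress = std::max(
      0.0, (delta_us - delay_us) /
               static_cast<double>(us_to_first_rect - delay_us));
  std::cout << EaseInOutQuad(progress) << "\n";  // 0.875 at progress 0.75
}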
# @@ -20,7 +20,7 @@ package(default_visibility = [ "//mediapipe/examples:__subpackages__", ]) -proto_library( +mediapipe_proto_library( name = "cropping_proto", srcs = ["cropping.proto"], deps = [ @@ -29,41 +29,18 @@ proto_library( ], ) -mediapipe_cc_proto_library( - name = "cropping_cc_proto", - srcs = ["cropping.proto"], - cc_deps = [ - ":kinematic_path_solver_cc_proto", - "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", - ], - visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":cropping_proto"], -) - -proto_library( +mediapipe_proto_library( name = "kinematic_path_solver_proto", srcs = ["kinematic_path_solver.proto"], -) - -mediapipe_cc_proto_library( - name = "kinematic_path_solver_cc_proto", - srcs = ["kinematic_path_solver.proto"], visibility = [ "//mediapipe/examples:__subpackages__", ], - deps = [":kinematic_path_solver_proto"], ) -proto_library( +mediapipe_proto_library( name = "focus_point_proto", srcs = ["focus_point.proto"], -) - -mediapipe_cc_proto_library( - name = "focus_point_cc_proto", - srcs = ["focus_point.proto"], visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":focus_point_proto"], ) cc_library( @@ -333,16 +310,10 @@ cc_test( ], ) -proto_library( +mediapipe_proto_library( name = "visual_scorer_proto", srcs = ["visual_scorer.proto"], -) - -mediapipe_cc_proto_library( - name = "visual_scorer_cc_proto", - srcs = ["visual_scorer.proto"], visibility = ["//mediapipe/examples:__subpackages__"], - deps = [":visual_scorer_proto"], ) cc_library( diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc index 0bfe72548..96fc5f888 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc @@ -34,7 +34,7 @@ absl::Status SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, - const std::vector& scene_frame_timestamps, + const std::vector& scene_frame_timestamps, const bool has_solid_color_background, SceneKeyFrameCropSummary* scene_summary, std::vector* focus_point_frames, @@ -45,7 +45,7 @@ absl::Status SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( key_frame_crop_options, key_frame_crop_results, scene_frame_width, scene_frame_height, scene_summary)); - const int64 scene_span_ms = + const int64_t scene_span_ms = scene_frame_timestamps.empty() ? 
0 : scene_frame_timestamps.back() - scene_frame_timestamps.front(); @@ -103,7 +103,7 @@ absl::Status SceneCameraMotionAnalyzer::ToUseSweepingMotion( absl::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( const KeyFrameCropOptions& key_frame_crop_options, - const double scene_span_sec, const int64 end_time_us, + const double scene_span_sec, const int64_t end_time_us, SceneKeyFrameCropSummary* scene_summary, SceneCameraMotion* scene_camera_motion) const { RET_CHECK_GE(scene_span_sec, 0.0) << "Scene time span is negative."; @@ -298,7 +298,7 @@ absl::Status SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( const SceneKeyFrameCropSummary& scene_summary, const SceneCameraMotion& scene_camera_motion, - const std::vector& scene_frame_timestamps, + const std::vector& scene_frame_timestamps, std::vector* focus_point_frames) const { RET_CHECK_NE(focus_point_frames, nullptr) << "Output vector of FocusPointFrame is null."; @@ -380,7 +380,7 @@ absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( const SceneKeyFrameCropSummary& scene_summary, const FocusPointFrameType focus_point_frame_type, - const std::vector& scene_frame_timestamps, + const std::vector& scene_frame_timestamps, std::vector* focus_point_frames) const { RET_CHECK_GE(scene_summary.key_frame_max_score(), 0.0) << "Maximum score is negative."; @@ -392,7 +392,7 @@ absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( const int scene_frame_height = scene_summary.scene_frame_height(); PiecewiseLinearFunction center_x_function, center_y_function, score_function; - const int64 timestamp_offset = key_frame_compact_infos[0].timestamp_ms(); + const int64_t timestamp_offset = key_frame_compact_infos[0].timestamp_ms(); for (int i = 0; i < num_key_frames; ++i) { const float center_x = key_frame_compact_infos[i].center_x(); const float center_y = key_frame_compact_infos[i].center_y(); diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index da09acc83..ee9796e49 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -425,7 +425,10 @@ using GenericNode = Node; template class Node : public NodeBase { public: - Node() : NodeBase(std::string(Calc::kCalculatorName)) {} + Node() + : NodeBase( + FunctionRegistry::GetLookupName(Calc::kCalculatorName)) {} + // Overrides the built-in calculator type string with the provided argument. // Can be used to create nodes from pure interfaces. // TODO: only use this for pure interfaces @@ -546,6 +549,7 @@ class Graph { // Creates a node of a specific type. Should be used for pure interfaces, // which do not have a built-in type string. + // `type` is a calculator type-name with dot-separated namespaces. template Node& AddNode(absl::string_view type) { auto node = @@ -557,6 +561,7 @@ class Graph { // Creates a generic node, with no compile-time checking of inputs and // outputs. This can be used for calculators whose contract is not visible. + // `type` is a calculator type-name with dot-separated namespaces. 
GenericNode& AddNode(absl::string_view type) { auto node = std::make_unique(std::string(type.data(), type.size())); diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index b49930b7a..06a57fa6d 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -192,8 +192,7 @@ absl::Status CalculatorGraph::InitializeStreams() { auto input_tag_map, tool::TagMap::Create(validated_graph_->Config().input_stream())); for (const auto& stream_name : input_tag_map->Names()) { - RET_CHECK(!mediapipe::ContainsKey(graph_input_streams_, stream_name)) - .SetNoLogging() + RET_CHECK(!graph_input_streams_.contains(stream_name)).SetNoLogging() << "CalculatorGraph Initialization failed, graph input stream \"" << stream_name << "\" was specified twice."; int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index d149337cc..81ce9902c 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -679,7 +679,7 @@ REGISTER_CALCULATOR(BoundToPacketCalculator); // A Calculator that produces packets at timestamps beyond the input timestamp. class FuturePacketCalculator : public CalculatorBase { public: - static constexpr int64 kOutputFutureMicros = 3; + static constexpr int64_t kOutputFutureMicros = 3; static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); diff --git a/mediapipe/framework/calculator_graph_side_packet_test.cc b/mediapipe/framework/calculator_graph_side_packet_test.cc index 57fcff866..a9567c805 100644 --- a/mediapipe/framework/calculator_graph_side_packet_test.cc +++ b/mediapipe/framework/calculator_graph_side_packet_test.cc @@ -188,21 +188,21 @@ class Uint64PacketGenerator : public PacketGenerator { static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { - output_side_packets->Index(0).Set(); + output_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status Generate(const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { - output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5)); + output_side_packets->Index(0) = Adopt(new uint64_t(15LL << 32 | 5)); return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(Uint64PacketGenerator); TEST(CalculatorGraph, OutputSidePacketInProcess) { - const int64 offset = 100; + const int64_t offset = 100; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "offset" @@ -400,7 +400,7 @@ TEST(CalculatorGraph, SharePacketGeneratorGraph) { } TEST(CalculatorGraph, OutputSidePacketAlreadySet) { - const int64 offset = 100; + const int64_t offset = 100; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "offset" @@ -427,7 +427,7 @@ TEST(CalculatorGraph, OutputSidePacketAlreadySet) { } TEST(CalculatorGraph, OutputSidePacketWithTimestamp) { - const int64 offset = 100; + const int64_t offset = 100; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "offset" @@ -716,7 +716,7 @@ TEST(CalculatorGraph, GetOutputSidePacket) { // Run the graph twice. 
int max_count = 100; std::map extra_side_packets; - extra_side_packets.insert({"input_uint64", MakePacket(1123)}); + extra_side_packets.insert({"input_uint64", MakePacket(1123)}); for (int run = 0; run < 1; ++run) { MP_ASSERT_OK(graph.StartRun(extra_side_packets)); status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 6ca206ab1..2e7d99ef6 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -439,7 +439,7 @@ class GlobalCountSourceCalculator : public CalculatorBase { ++local_count_; } - int64 local_count_ = 0; + int64_t local_count_ = 0; }; const int GlobalCountSourceCalculator::kNumOutputPackets = 5; REGISTER_CALCULATOR(GlobalCountSourceCalculator); @@ -765,7 +765,7 @@ class TypedStatusHandler : public StatusHandler { } }; typedef TypedStatusHandler StringStatusHandler; -typedef TypedStatusHandler Uint32StatusHandler; +typedef TypedStatusHandler Uint32StatusHandler; REGISTER_STATUS_HANDLER(StringStatusHandler); REGISTER_STATUS_HANDLER(Uint32StatusHandler); @@ -1398,9 +1398,9 @@ void RunComprehensiveTest(CalculatorGraph* graph, MP_ASSERT_OK(graph->Initialize(proto)); std::map extra_side_packets; - extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 3))); + extra_side_packets.emplace("node_3", Adopt(new uint64_t((15LL << 32) | 3))); if (define_node_5) { - extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5))); + extra_side_packets.emplace("node_5", Adopt(new uint64_t((15LL << 32) | 5))); } // Call graph->Run() several times, to make sure that the appropriate @@ -1452,9 +1452,9 @@ void RunComprehensiveTest(CalculatorGraph* graph, // Verify that the graph can still run (but not successfully) when // one of the nodes is caused to fail. extra_side_packets.clear(); - extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 0))); + extra_side_packets.emplace("node_3", Adopt(new uint64_t((15LL << 32) | 0))); if (define_node_5) { - extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5))); + extra_side_packets.emplace("node_5", Adopt(new uint64_t((15LL << 32) | 5))); } dumped_final_sum_packet = Packet(); dumped_final_stddev_packet = Packet(); @@ -1579,14 +1579,14 @@ class Uint64PacketGenerator : public PacketGenerator { static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { - output_side_packets->Index(0).Set(); + output_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status Generate(const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { - output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5)); + output_side_packets->Index(0) = Adopt(new uint64_t(15LL << 32 | 5)); return absl::OkStatus(); } }; @@ -1759,7 +1759,7 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { )pb"); MP_ASSERT_OK(graph->Initialize(config)); Packet extra_string = Adopt(new std::string("foo")); - Packet a_uint64 = Adopt(new uint64(0)); + Packet a_uint64 = Adopt(new uint64_t(0)); MP_EXPECT_OK( graph->Run({{"extra_string", extra_string}, {"a_uint64", a_uint64}})); @@ -1789,7 +1789,7 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { testing::HasSubstr("string"), // Expected type. 
testing::HasSubstr( - MediaPipeTypeStringOrDemangled()))); + MediaPipeTypeStringOrDemangled()))); // Should fail verification when the type of a to-be-generated packet is // wrong. The added handler now expects a string but will receive the uint32 @@ -1802,14 +1802,14 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { status = graph->Initialize(config); EXPECT_THAT(status.message(), - testing::AllOf( - testing::HasSubstr("StringStatusHandler"), - // The problematic input side packet. - testing::HasSubstr("generated_by_generator"), - // Actual type. - testing::HasSubstr(MediaPipeTypeStringOrDemangled()), - // Expected type. - testing::HasSubstr("string"))); + testing::AllOf(testing::HasSubstr("StringStatusHandler"), + // The problematic input side packet. + testing::HasSubstr("generated_by_generator"), + // Actual type. + testing::HasSubstr( + MediaPipeTypeStringOrDemangled()), + // Expected type. + testing::HasSubstr("string"))); } TEST(CalculatorGraph, GenerateInInitialize) { diff --git a/mediapipe/framework/calculator_runner.cc b/mediapipe/framework/calculator_runner.cc index 833797483..1bd3211ed 100644 --- a/mediapipe/framework/calculator_runner.cc +++ b/mediapipe/framework/calculator_runner.cc @@ -216,7 +216,7 @@ mediapipe::Counter* CalculatorRunner::GetCounter(const std::string& name) { return graph_->GetCounterFactory()->GetCounter(name); } -std::map CalculatorRunner::GetCountersValues() { +std::map CalculatorRunner::GetCountersValues() { return graph_->GetCounterFactory()->GetCounterSet()->GetCountersValues(); } diff --git a/mediapipe/framework/counter_factory.cc b/mediapipe/framework/counter_factory.cc index 94a6a4213..895b44ea6 100644 --- a/mediapipe/framework/counter_factory.cc +++ b/mediapipe/framework/counter_factory.cc @@ -39,14 +39,14 @@ class BasicCounter : public Counter { value_ += amount; } - int64 Get() ABSL_LOCKS_EXCLUDED(mu_) override { + int64_t Get() ABSL_LOCKS_EXCLUDED(mu_) override { absl::ReaderMutexLock lock(&mu_); return value_; } private: absl::Mutex mu_; - int64 value_ ABSL_GUARDED_BY(mu_); + int64_t value_ ABSL_GUARDED_BY(mu_); }; } // namespace @@ -73,10 +73,10 @@ Counter* CounterSet::Get(const std::string& name) ABSL_LOCKS_EXCLUDED(mu_) { return counters_[name].get(); } -std::map CounterSet::GetCountersValues() +std::map CounterSet::GetCountersValues() ABSL_LOCKS_EXCLUDED(mu_) { absl::ReaderMutexLock lock(&mu_); - std::map result; + std::map result; for (const auto& it : counters_) { result[it.first] = it.second->Get(); } diff --git a/mediapipe/framework/deps/mathutil_unittest.cc b/mediapipe/framework/deps/mathutil_unittest.cc index 7468e927a..b25b73306 100644 --- a/mediapipe/framework/deps/mathutil_unittest.cc +++ b/mediapipe/framework/deps/mathutil_unittest.cc @@ -75,17 +75,17 @@ BENCHMARK(BM_IntCast); static void BM_Int64Cast(benchmark::State& state) { double x = 0.1; - int64 sum = 0; + int64_t sum = 0; for (auto _ : state) { - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; - sum += static_cast(x); + sum += static_cast(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. 
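// Illustrative aside, not part of the patch: the mathutil benchmarks above
// exercise plain casts versus MathUtil rounding while their template
// arguments migrate from int64 to int64_t (the <...> arguments are stripped
// in this rendering). A minimal standard-C++ sketch of the difference the
// two operations measure, with hypothetical variable names:
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const double x = 2.7;
  const int64_t truncated = static_cast<int64_t>(x);              // toward zero -> 2
  const int64_t rounded = static_cast<int64_t>(std::llround(x));  // nearest -> 3
  assert(truncated == 2);
  assert(rounded == 3);
  return 0;
}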
@@ -134,15 +134,15 @@ static void BM_Int64Round(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -153,15 +153,15 @@ static void BM_UintRound(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; - sum += mediapipe::MathUtil::Round(x); + sum += mediapipe::MathUtil::Round(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -191,15 +191,15 @@ static void BM_SafeInt64Cast(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; - sum += mediapipe::MathUtil::SafeCast(x); + sum += mediapipe::MathUtil::SafeCast(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -229,15 +229,15 @@ static void BM_SafeInt64Round(benchmark::State& state) { double x = 0.1; int sum = 0; for (auto _ : state) { - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; - sum += mediapipe::MathUtil::SafeRound(x); + sum += mediapipe::MathUtil::SafeRound(x); x += 0.1; } EXPECT_NE(sum, 0); // Don't let 'sum' get optimized away. @@ -262,8 +262,8 @@ TEST(MathUtil, IntRound) { // A double-precision number has a 53-bit mantissa (52 fraction bits), // so the following value can be represented exactly. - int64 value64 = static_cast(0x1234567890abcd00); - EXPECT_EQ(mediapipe::MathUtil::Round(static_cast(value64)), + int64_t value64 = static_cast(0x1234567890abcd00); + EXPECT_EQ(mediapipe::MathUtil::Round(static_cast(value64)), value64); } @@ -369,7 +369,7 @@ class SafeCastTester { if (sizeof(FloatIn) >= 64) { // A double-precision number has a 53-bit mantissa (52 fraction bits), // so the following value can be represented exactly by a double. - int64 value64 = static_cast(0x1234567890abcd00); + int64_t value64 = static_cast(0x1234567890abcd00); const IntOut expected = (sizeof(IntOut) >= 64) ? 
static_cast(value64) : imax; EXPECT_EQ( @@ -536,22 +536,22 @@ class SafeCastTester { }; TEST(MathUtil, SafeCast) { - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); - SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); + SafeCastTester::Run(); // Spot-check SafeCast EXPECT_EQ(mediapipe::MathUtil::SafeCast(static_cast(12345.678)), @@ -682,7 +682,7 @@ class SafeRoundTester { if (sizeof(FloatIn) >= 64) { // A double-precision number has a 53-bit mantissa (52 fraction bits), // so the following value can be represented exactly by a double. - int64 value64 = static_cast(0x1234567890abcd00); + int64_t value64 = static_cast(0x1234567890abcd00); const IntOut expected = (sizeof(IntOut) >= 64) ? static_cast(value64) : imax; EXPECT_EQ( @@ -843,22 +843,22 @@ class SafeRoundTester { }; TEST(MathUtil, SafeRound) { - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); - SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); + SafeRoundTester::Run(); // Spot-check SafeRound EXPECT_EQ(mediapipe::MathUtil::SafeRound(static_cast(12345.678)), diff --git a/mediapipe/framework/deps/monotonic_clock_test.cc b/mediapipe/framework/deps/monotonic_clock_test.cc index 533830e43..0a049392f 100644 --- a/mediapipe/framework/deps/monotonic_clock_test.cc +++ b/mediapipe/framework/deps/monotonic_clock_test.cc @@ -244,7 +244,7 @@ TEST_F(MonotonicClockTest, RealTime) { // Call mono_clock->Now() continuously for FLAGS_real_test_secs seconds. absl::Time start = absl::Now(); absl::Time time = start; - int64 num_calls = 0; + int64_t num_calls = 0; do { absl::Time last_time = time; time = mono_clock->TimeNow(); @@ -406,7 +406,7 @@ class ClockFrenzy { while (Running()) { // 40% of the time, advance a simulated clock. // 50% of the time, read a monotonic clock. - const int32 u = UniformRandom(100); + const int32_t u = UniformRandom(100); if (u < 40) { // Pick a simulated clock and advance it. const int nclocks = sim_clocks_.size(); @@ -463,9 +463,9 @@ class ClockFrenzy { // Thread-safe random number generation functions for use by other class // member functions. 
- int32 UniformRandom(int32 n) { + int32_t UniformRandom(int32_t n) { absl::MutexLock l(&lock_); - return std::uniform_int_distribution(0, n - 1)(*random_); + return std::uniform_int_distribution(0, n - 1)(*random_); } float RndFloatRandom() { diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index cc8ba03fe..7965539b6 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -301,6 +301,18 @@ class FunctionRegistry { return cxx_name; } + // Returns a type name with '.' separated namespaces. + static std::string GetLookupName(const absl::string_view cxx_type_name) { + constexpr absl::string_view kCxxSep = "::"; + constexpr absl::string_view kNameSep = "."; + std::vector names = + absl::StrSplit(cxx_type_name, kCxxSep); + if (names[0].empty()) { + names.erase(names.begin()); + } + return absl::StrJoin(names, kNameSep); + } + private: mutable absl::Mutex lock_; absl::flat_hash_map functions_ ABSL_GUARDED_BY(lock_); diff --git a/mediapipe/framework/deps/safe_int_test.cc b/mediapipe/framework/deps/safe_int_test.cc index 7f385848f..83932d551 100644 --- a/mediapipe/framework/deps/safe_int_test.cc +++ b/mediapipe/framework/deps/safe_int_test.cc @@ -20,21 +20,21 @@ #include "mediapipe/framework/port/gtest.h" -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt8, int8, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt8, int8_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt8, uint8, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt8, uint8_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt16, int16, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt16, int16_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt16, uint16, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt16, uint16_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt32, int32, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt32, int32_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt64, int64, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeInt64, int64_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt32, uint32, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt32, uint32_t, mediapipe::intops::LogFatalOnError); -MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt64, uint64, +MEDIAPIPE_DEFINE_SAFE_INT_TYPE(SafeUInt64, uint64_t, mediapipe::intops::LogFatalOnError); namespace mediapipe { @@ -102,8 +102,8 @@ TYPED_TEST(SignNeutralSafeIntTest, TestCtorFailures) { typedef typename T::ValueType V; { // Test out-of-bounds construction. - if (std::numeric_limits::is_signed || sizeof(V) < sizeof(uint64)) { - EXPECT_DEATH((T(std::numeric_limits::max())), "bounds"); + if (std::numeric_limits::is_signed || sizeof(V) < sizeof(uint64_t)) { + EXPECT_DEATH((T(std::numeric_limits::max())), "bounds"); } } { // Test out-of-bounds construction from float. @@ -233,20 +233,20 @@ TYPED_TEST(SignNeutralSafeIntTest, TestMultiply) { typedef typename T::ValueType V; // Test positive vs. positive multiplication across types. - TEST_T_OP_NUM(9, *, int32, 3); - TEST_T_OP_NUM(9, *, uint32, 3); + TEST_T_OP_NUM(9, *, int32_t, 3); + TEST_T_OP_NUM(9, *, uint32_t, 3); TEST_T_OP_NUM(9, *, float, 3); TEST_T_OP_NUM(9, *, double, 3); // Test positive vs. zero multiplication commutatively across types. This // was a real bug. 
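// Illustrative aside, not part of the patch: the GetLookupName() helper added
// to registration.h above turns a C++ type name into a '.'-separated registry
// key, but its template arguments are stripped in this rendering. A hedged,
// self-contained sketch of the same idea (ToLookupName is a stand-in name):
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// "::mediapipe::MyCalculator" -> "mediapipe.MyCalculator".
std::string ToLookupName(absl::string_view cxx_type_name) {
  std::vector<std::string> names = absl::StrSplit(cxx_type_name, "::");
  if (!names.empty() && names[0].empty()) {
    // A leading "::" yields an empty first segment; drop it.
    names.erase(names.begin());
  }
  return absl::StrJoin(names, ".");
}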
- TEST_T_OP_NUM(93, *, int32, 0); - TEST_T_OP_NUM(93, *, uint32, 0); + TEST_T_OP_NUM(93, *, int32_t, 0); + TEST_T_OP_NUM(93, *, uint32_t, 0); TEST_T_OP_NUM(93, *, float, 0); TEST_T_OP_NUM(93, *, double, 0); - TEST_T_OP_NUM(0, *, int32, 76); - TEST_T_OP_NUM(0, *, uint32, 76); + TEST_T_OP_NUM(0, *, int32_t, 76); + TEST_T_OP_NUM(0, *, uint32_t, 76); TEST_T_OP_NUM(0, *, float, 76); TEST_T_OP_NUM(0, *, double, 76); @@ -279,14 +279,14 @@ TYPED_TEST(SignNeutralSafeIntTest, TestDivide) { typedef typename T::ValueType V; // Test positive vs. positive division across types. - TEST_T_OP_NUM(9, /, int32, 3); - TEST_T_OP_NUM(9, /, uint32, 3); + TEST_T_OP_NUM(9, /, int32_t, 3); + TEST_T_OP_NUM(9, /, uint32_t, 3); TEST_T_OP_NUM(9, /, float, 3); TEST_T_OP_NUM(9, /, double, 3); // Test zero vs. positive division across types. - TEST_T_OP_NUM(0, /, int32, 76); - TEST_T_OP_NUM(0, /, uint32, 76); + TEST_T_OP_NUM(0, /, int32_t, 76); + TEST_T_OP_NUM(0, /, uint32_t, 76); TEST_T_OP_NUM(0, /, float, 76); TEST_T_OP_NUM(0, /, double, 76); } @@ -307,12 +307,12 @@ TYPED_TEST(SignNeutralSafeIntTest, TestModulo) { typedef typename T::ValueType V; // Test positive vs. positive modulo across signedness. - TEST_T_OP_NUM(7, %, int32, 6); - TEST_T_OP_NUM(7, %, uint32, 6); + TEST_T_OP_NUM(7, %, int32_t, 6); + TEST_T_OP_NUM(7, %, uint32_t, 6); // Test zero vs. positive modulo across signedness. - TEST_T_OP_NUM(0, %, int32, 6); - TEST_T_OP_NUM(0, %, uint32, 6); + TEST_T_OP_NUM(0, %, int32_t, 6); + TEST_T_OP_NUM(0, %, uint32_t, 6); } TYPED_TEST(SignNeutralSafeIntTest, TestModuloFailures) { @@ -534,28 +534,28 @@ TYPED_TEST(SignedSafeIntTest, TestMultiply) { typedef typename T::ValueType V; // Test negative vs. positive multiplication across types. - TEST_T_OP_NUM(-9, *, int32, 3); - TEST_T_OP_NUM(-9, *, uint32, 3); + TEST_T_OP_NUM(-9, *, int32_t, 3); + TEST_T_OP_NUM(-9, *, uint32_t, 3); TEST_T_OP_NUM(-9, *, float, 3); TEST_T_OP_NUM(-9, *, double, 3); // Test positive vs. negative multiplication across types. - TEST_T_OP_NUM(9, *, int32, -3); + TEST_T_OP_NUM(9, *, int32_t, -3); // Don't cover unsigneds that are initialized from negative values. TEST_T_OP_NUM(9, *, float, -3); TEST_T_OP_NUM(9, *, double, -3); // Test negative vs. negative multiplication across types. - TEST_T_OP_NUM(-9, *, int32, -3); + TEST_T_OP_NUM(-9, *, int32_t, -3); // Don't cover unsigneds that are initialized from negative values. TEST_T_OP_NUM(-9, *, float, -3); TEST_T_OP_NUM(-9, *, double, -3); // Test negative vs. zero multiplication commutatively across types. - TEST_T_OP_NUM(-93, *, int32, 0); - TEST_T_OP_NUM(-93, *, uint32, 0); + TEST_T_OP_NUM(-93, *, int32_t, 0); + TEST_T_OP_NUM(-93, *, uint32_t, 0); TEST_T_OP_NUM(-93, *, float, 0); TEST_T_OP_NUM(-93, *, double, 0); - TEST_T_OP_NUM(0, *, int32, -76); - TEST_T_OP_NUM(0, *, uint32, -76); + TEST_T_OP_NUM(0, *, int32_t, -76); + TEST_T_OP_NUM(0, *, uint32_t, -76); TEST_T_OP_NUM(0, *, float, -76); TEST_T_OP_NUM(0, *, double, -76); @@ -600,24 +600,24 @@ TYPED_TEST(SignedSafeIntTest, TestDivide) { typedef typename T::ValueType V; // Test negative vs. positive division across types. - TEST_T_OP_NUM(-9, /, int32, 3); - TEST_T_OP_NUM(-9, /, uint32, 3); + TEST_T_OP_NUM(-9, /, int32_t, 3); + TEST_T_OP_NUM(-9, /, uint32_t, 3); TEST_T_OP_NUM(-9, /, float, 3); TEST_T_OP_NUM(-9, /, double, 3); // Test positive vs. negative division across types. 
- TEST_T_OP_NUM(9, /, int32, -3); - TEST_T_OP_NUM(9, /, uint32, -3); + TEST_T_OP_NUM(9, /, int32_t, -3); + TEST_T_OP_NUM(9, /, uint32_t, -3); TEST_T_OP_NUM(9, /, float, -3); TEST_T_OP_NUM(9, /, double, -3); // Test negative vs. negative division across types. - TEST_T_OP_NUM(-9, /, int32, -3); - TEST_T_OP_NUM(-9, /, uint32, -3); + TEST_T_OP_NUM(-9, /, int32_t, -3); + TEST_T_OP_NUM(-9, /, uint32_t, -3); TEST_T_OP_NUM(-9, /, float, -3); TEST_T_OP_NUM(-9, /, double, -3); // Test zero vs. negative division across types. - TEST_T_OP_NUM(0, /, int32, -76); - TEST_T_OP_NUM(0, /, uint32, -76); + TEST_T_OP_NUM(0, /, int32_t, -76); + TEST_T_OP_NUM(0, /, uint32_t, -76); TEST_T_OP_NUM(0, /, float, -76); TEST_T_OP_NUM(0, /, double, -76); } @@ -638,18 +638,18 @@ TYPED_TEST(SignedSafeIntTest, TestModulo) { typedef typename T::ValueType V; // Test negative vs. positive modulo across signedness. - TEST_T_OP_NUM(-7, %, int32, 6); - TEST_T_OP_NUM(-7, %, uint32, 6); + TEST_T_OP_NUM(-7, %, int32_t, 6); + TEST_T_OP_NUM(-7, %, uint32_t, 6); // Test positive vs. negative modulo across signedness. - TEST_T_OP_NUM(7, %, int32, -6); - TEST_T_OP_NUM(7, %, uint32, -6); + TEST_T_OP_NUM(7, %, int32_t, -6); + TEST_T_OP_NUM(7, %, uint32_t, -6); // Test negative vs. negative modulo across signedness. - TEST_T_OP_NUM(-7, %, int32, -6); - TEST_T_OP_NUM(-7, %, uint32, -6); + TEST_T_OP_NUM(-7, %, int32_t, -6); + TEST_T_OP_NUM(-7, %, uint32_t, -6); // Test zero vs. negative modulo across signedness. - TEST_T_OP_NUM(0, %, int32, -6); - TEST_T_OP_NUM(0, %, uint32, -6); + TEST_T_OP_NUM(0, %, int32_t, -6); + TEST_T_OP_NUM(0, %, uint32_t, -6); } TYPED_TEST(SignedSafeIntTest, TestModuloFailures) { diff --git a/mediapipe/framework/formats/image_format.proto b/mediapipe/framework/formats/image_format.proto index 61e004ac6..e9b69a4c1 100644 --- a/mediapipe/framework/formats/image_format.proto +++ b/mediapipe/framework/formats/image_format.proto @@ -69,6 +69,9 @@ message ImageFormat { // Two floats per pixel. VEC32F2 = 12; + // Four floats per pixel. + VEC32F4 = 13; + // LAB, interleaved: one byte for L, then one byte for a, then one // byte for b for each pixel. 
LAB8 = 10; diff --git a/mediapipe/framework/formats/image_frame.cc b/mediapipe/framework/formats/image_frame.cc index 913ffae24..2de819a35 100644 --- a/mediapipe/framework/formats/image_frame.cc +++ b/mediapipe/framework/formats/image_frame.cc @@ -33,7 +33,7 @@ namespace mediapipe { namespace { -int CountOnes(uint32 n) { +int CountOnes(uint32_t n) { #if (defined(__i386__) || defined(__x86_64__)) && defined(__POPCNT__) && \ defined(__GNUC__) return __builtin_popcount(n); @@ -47,20 +47,21 @@ int CountOnes(uint32 n) { } // namespace const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kArrayDelete = - std::default_delete(); + std::default_delete(); const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kFree = free; const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kAlignedFree = aligned_free; -const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kNone = [](uint8* x) {}; +const ImageFrame::Deleter ImageFrame::PixelDataDeleter::kNone = [](uint8_t* x) { +}; -const uint32 ImageFrame::kDefaultAlignmentBoundary; -const uint32 ImageFrame::kGlDefaultAlignmentBoundary; +const uint32_t ImageFrame::kDefaultAlignmentBoundary; +const uint32_t ImageFrame::kGlDefaultAlignmentBoundary; ImageFrame::ImageFrame() : format_(ImageFormat::UNKNOWN), width_(0), height_(0), width_step_(0) {} ImageFrame::ImageFrame(ImageFormat::Format format, int width, int height, - uint32 alignment_boundary) + uint32_t alignment_boundary) : format_(format), width_(width), height_(height) { Reset(format, width, height, alignment_boundary); } @@ -71,7 +72,7 @@ ImageFrame::ImageFrame(ImageFormat::Format format, int width, int height) } ImageFrame::ImageFrame(ImageFormat::Format format, int width, int height, - int width_step, uint8* pixel_data, + int width_step, uint8_t* pixel_data, ImageFrame::Deleter deleter) { AdoptPixelData(format, width, height, width_step, pixel_data, deleter); } @@ -93,7 +94,7 @@ ImageFrame& ImageFrame::operator=(ImageFrame&& move_from) { } void ImageFrame::Reset(ImageFormat::Format format, int width, int height, - uint32 alignment_boundary) { + uint32_t alignment_boundary) { format_ = format; width_ = width; height_ = height; @@ -101,7 +102,7 @@ void ImageFrame::Reset(ImageFormat::Format format, int width, int height, CHECK(IsValidAlignmentNumber(alignment_boundary)); width_step_ = width * NumberOfChannels() * ByteDepth(); if (alignment_boundary == 1) { - pixel_data_ = {new uint8[height * width_step_], + pixel_data_ = {new uint8_t[height * width_step_], PixelDataDeleter::kArrayDelete}; } else { // Increase width_step_ to the smallest multiple of alignment_boundary @@ -109,14 +110,14 @@ void ImageFrame::Reset(ImageFormat::Format format, int width, int height, // twiddling bits. alignment_boundary - 1 is a mask which sets all // the low order bits. 
width_step_ = ((width_step_ - 1) | (alignment_boundary - 1)) + 1; - pixel_data_ = {reinterpret_cast<uint8*>(aligned_malloc(height * width_step_, - alignment_boundary)), + pixel_data_ = {reinterpret_cast<uint8_t*>(aligned_malloc( + height * width_step_, alignment_boundary)), PixelDataDeleter::kAlignedFree}; } } void ImageFrame::AdoptPixelData(ImageFormat::Format format, int width, - int height, int width_step, uint8* pixel_data, + int height, int width_step, uint8_t* pixel_data, ImageFrame::Deleter deleter) { format_ = format; width_ = width; @@ -129,12 +130,12 @@ void ImageFrame::AdoptPixelData(ImageFormat::Format format, int width, pixel_data_ = {pixel_data, deleter}; } -std::unique_ptr<uint8[], ImageFrame::Deleter> ImageFrame::Release() { +std::unique_ptr<uint8_t[], ImageFrame::Deleter> ImageFrame::Release() { return std::move(pixel_data_); } void ImageFrame::InternalCopyFrom(int width, int height, int width_step, - int channel_size, const uint8* pixel_data) { + int channel_size, const uint8_t* pixel_data) { CHECK_EQ(width_, width); CHECK_EQ(height_, height); // row_bytes = channel_size * num_channels * width @@ -192,9 +193,9 @@ void ImageFrame::SetAlignmentPaddingAreas() { const int pixel_size = ByteDepth() * NumberOfChannels(); const int padding_size = width_step_ - width_ * pixel_size; for (int row = 0; row < height_; ++row) { - uint8* row_start = pixel_data_.get() + width_step_ * row; - uint8* last_pixel_in_row = row_start + (width_ - 1) * pixel_size; - uint8* padding = row_start + width_ * pixel_size; + uint8_t* row_start = pixel_data_.get() + width_step_ * row; + uint8_t* last_pixel_in_row = row_start + (width_ - 1) * pixel_size; + uint8_t* padding = row_start + width_ * pixel_size; int padding_index = 0; while (padding_index + pixel_size - 1 < padding_size) { // Copy the entire last pixel in the row into this padding pixel.
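// Illustrative aside, not part of the patch: the Reset() hunk above rounds
// width_step_ up to the next multiple of alignment_boundary with
// ((width_step_ - 1) | (alignment_boundary - 1)) + 1, which relies on the
// boundary being a power of two (IsValidAlignmentNumber checks this via a
// popcount). A small standalone sketch of that bit trick:
#include <cassert>
#include <cstdint>

uint32_t RoundUpToBoundary(uint32_t value, uint32_t power_of_two_boundary) {
  // boundary - 1 sets every low-order bit; OR-ing and adding 1 rounds up.
  return ((value - 1) | (power_of_two_boundary - 1)) + 1;
}

int main() {
  assert(RoundUpToBoundary(10, 16) == 16);  // rows are padded up to 16 bytes
  assert(RoundUpToBoundary(32, 16) == 32);  // exact multiples are unchanged
  assert(RoundUpToBoundary(33, 16) == 48);
  return 0;
}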
@@ -220,7 +221,7 @@ bool ImageFrame::IsContiguous() const { return width_step_ == width_ * NumberOfChannels() * ByteDepth(); } -bool ImageFrame::IsAligned(uint32 alignment_boundary) const { +bool ImageFrame::IsAligned(uint32_t alignment_boundary) const { CHECK(IsValidAlignmentNumber(alignment_boundary)); if (!pixel_data_) { return false; @@ -236,7 +237,7 @@ bool ImageFrame::IsAligned(uint32 alignment_boundary) const { } // static -bool ImageFrame::IsValidAlignmentNumber(uint32 alignment_boundary) { +bool ImageFrame::IsValidAlignmentNumber(uint32_t alignment_boundary) { return CountOnes(alignment_boundary) == 1; } @@ -279,6 +280,8 @@ int ImageFrame::NumberOfChannelsForFormat(ImageFormat::Format format) { return 1; case ImageFormat::VEC32F2: return 2; + case ImageFormat::VEC32F4: + return 4; case ImageFormat::LAB8: return 3; case ImageFormat::SBGRA: @@ -293,25 +296,27 @@ int ImageFrame::ChannelSize() const { return ChannelSizeForFormat(format_); } int ImageFrame::ChannelSizeForFormat(ImageFormat::Format format) { switch (format) { case ImageFormat::GRAY8: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::SRGB: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::SRGBA: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::GRAY16: - return sizeof(uint16); + return sizeof(uint16_t); case ImageFormat::SRGB48: - return sizeof(uint16); + return sizeof(uint16_t); case ImageFormat::SRGBA64: - return sizeof(uint16); + return sizeof(uint16_t); case ImageFormat::VEC32F1: return sizeof(float); case ImageFormat::VEC32F2: return sizeof(float); + case ImageFormat::VEC32F4: + return sizeof(float); case ImageFormat::LAB8: - return sizeof(uint8); + return sizeof(uint8_t); case ImageFormat::SBGRA: - return sizeof(uint8); + return sizeof(uint8_t); default: LOG(FATAL) << InvalidFormatString(format); } @@ -337,6 +342,8 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) { return 4; case ImageFormat::VEC32F2: return 4; + case ImageFormat::VEC32F4: + return 4; case ImageFormat::LAB8: return 1; case ImageFormat::SBGRA: @@ -347,7 +354,7 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) { } void ImageFrame::CopyFrom(const ImageFrame& image_frame, - uint32 alignment_boundary) { + uint32_t alignment_boundary) { // Reset the current image. Reset(image_frame.Format(), image_frame.Width(), image_frame.Height(), alignment_boundary); @@ -359,29 +366,29 @@ void ImageFrame::CopyFrom(const ImageFrame& image_frame, } void ImageFrame::CopyPixelData(ImageFormat::Format format, int width, - int height, const uint8* pixel_data, - uint32 alignment_boundary) { + int height, const uint8_t* pixel_data, + uint32_t alignment_boundary) { CopyPixelData(format, width, height, 0 /* contiguous storage */, pixel_data, alignment_boundary); } void ImageFrame::CopyPixelData(ImageFormat::Format format, int width, int height, int width_step, - const uint8* pixel_data, - uint32 alignment_boundary) { + const uint8_t* pixel_data, + uint32_t alignment_boundary) { Reset(format, width, height, alignment_boundary); InternalCopyFrom(width, height, width_step, ChannelSizeForFormat(format), pixel_data); } -void ImageFrame::CopyToBuffer(uint8* buffer, int buffer_size) const { +void ImageFrame::CopyToBuffer(uint8_t* buffer, int buffer_size) const { CHECK(buffer); CHECK_EQ(1, ByteDepth()); const int data_size = width_ * height_ * NumberOfChannels(); CHECK_LE(data_size, buffer_size); if (IsContiguous()) { // The data is stored contiguously, we can just copy. 
- const uint8* src = reinterpret_cast<const uint8*>(pixel_data_.get()); + const uint8_t* src = reinterpret_cast<const uint8_t*>(pixel_data_.get()); std::copy_n(src, data_size, buffer); } else { InternalCopyToBuffer(0 /* contiguous storage */, @@ -389,14 +396,14 @@ void ImageFrame::CopyToBuffer(uint8* buffer, int buffer_size) const { } } -void ImageFrame::CopyToBuffer(uint16* buffer, int buffer_size) const { +void ImageFrame::CopyToBuffer(uint16_t* buffer, int buffer_size) const { CHECK(buffer); CHECK_EQ(2, ByteDepth()); const int data_size = width_ * height_ * NumberOfChannels(); CHECK_LE(data_size, buffer_size); if (IsContiguous()) { // The data is stored contiguously, we can just copy. - const uint16* src = reinterpret_cast<const uint16*>(pixel_data_.get()); + const uint16_t* src = reinterpret_cast<const uint16_t*>(pixel_data_.get()); std::copy_n(src, data_size, buffer); } else { InternalCopyToBuffer(0 /* contiguous storage */, diff --git a/mediapipe/framework/formats/image_frame_opencv.cc b/mediapipe/framework/formats/image_frame_opencv.cc index 940e18263..1ba8c719f 100644 --- a/mediapipe/framework/formats/image_frame_opencv.cc +++ b/mediapipe/framework/formats/image_frame_opencv.cc @@ -59,6 +59,9 @@ int GetMatType(const mediapipe::ImageFormat::Format format) { case mediapipe::ImageFormat::VEC32F2: type = CV_32FC2; break; + case mediapipe::ImageFormat::VEC32F4: + type = CV_32FC4; + break; case mediapipe::ImageFormat::LAB8: type = CV_8U; break; diff --git a/mediapipe/framework/formats/image_frame_opencv_test.cc b/mediapipe/framework/formats/image_frame_opencv_test.cc index f75915d06..87d2ffb36 100644 --- a/mediapipe/framework/formats/image_frame_opencv_test.cc +++ b/mediapipe/framework/formats/image_frame_opencv_test.cc @@ -51,8 +51,8 @@ TEST(ImageFrameOpencvTest, ConvertToMat) { // Check adding constant images. const uint8_t frame1_val = 12; const uint8_t frame2_val = 34; - SetToColor<uint8>(&frame1_val, &frame1); - SetToColor<uint8>(&frame2_val, &frame2); + SetToColor<uint8_t>(&frame1_val, &frame1); + SetToColor<uint8_t>(&frame2_val, &frame2); // Get Mat wrapper around ImageFrame memory (zero copy). cv::Mat frame1_mat = formats::MatView(&frame1); cv::Mat frame2_mat = formats::MatView(&frame2); @@ -62,7 +62,7 @@ TEST(ImageFrameOpencvTest, ConvertToMat) { EXPECT_EQ(frame_avg, frame1_val + frame2_val); // Check setting min/max pixels. - uint8* frame1_ptr = frame1.MutablePixelData(); + uint8_t* frame1_ptr = frame1.MutablePixelData(); frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1; frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100; double min, max; @@ -84,8 +84,8 @@ TEST(ImageFrameOpencvTest, ConvertToIpl) { // Check adding constant images. const uint8_t frame1_val = 12; const uint8_t frame2_val = 34; - SetToColor<uint8>(&frame1_val, &frame1); - SetToColor<uint8>(&frame2_val, &frame2); + SetToColor<uint8_t>(&frame1_val, &frame1); + SetToColor<uint8_t>(&frame2_val, &frame2); const cv::Mat frame1_mat = formats::MatView(&frame1); const cv::Mat frame2_mat = formats::MatView(&frame2); const cv::Mat frame_sum = frame1_mat + frame2_mat; @@ -93,7 +93,7 @@ TEST(ImageFrameOpencvTest, ConvertToIpl) { EXPECT_EQ(frame_avg, frame1_val + frame2_val); // Check setting min/max pixels.
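// Illustrative aside, not part of the patch: the new VEC32F4 image format is
// four floats per pixel and, in the OpenCV views above, maps to CV_32FC4.
// A minimal sketch in plain OpenCV of what that type implies (the 48x64 size
// is arbitrary):
#include <cassert>

#include <opencv2/core.hpp>

int main() {
  cv::Mat four_channel_float(48, 64, CV_32FC4, cv::Scalar(0, 0, 0, 0));
  assert(four_channel_float.type() == CV_32FC4);
  assert(four_channel_float.channels() == 4);
  assert(four_channel_float.elemSize() == 4 * sizeof(float));  // 16 bytes per pixel
  return 0;
}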
- uint8* frame1_ptr = frame1.MutablePixelData(); + uint8_t* frame1_ptr = frame1.MutablePixelData(); frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1; frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100; double min, max; @@ -113,6 +113,7 @@ TEST(ImageFrameOpencvTest, ImageFormats) { ImageFrame frame_g16(ImageFormat::GRAY16, i_width, i_height); ImageFrame frame_v32f1(ImageFormat::VEC32F1, i_width, i_height); ImageFrame frame_v32f2(ImageFormat::VEC32F2, i_width, i_height); + ImageFrame frame_v32f4(ImageFormat::VEC32F4, i_width, i_height); ImageFrame frame_c3(ImageFormat::SRGB, i_width, i_height); ImageFrame frame_c4(ImageFormat::SRGBA, i_width, i_height); @@ -120,6 +121,7 @@ cv::Mat mat_g16 = formats::MatView(&frame_g16); cv::Mat mat_v32f1 = formats::MatView(&frame_v32f1); cv::Mat mat_v32f2 = formats::MatView(&frame_v32f2); + cv::Mat mat_v32f4 = formats::MatView(&frame_v32f4); cv::Mat mat_c3 = formats::MatView(&frame_c3); cv::Mat mat_c4 = formats::MatView(&frame_c4); @@ -127,6 +129,7 @@ EXPECT_EQ(mat_g16.type(), CV_16UC1); EXPECT_EQ(mat_v32f1.type(), CV_32FC1); EXPECT_EQ(mat_v32f2.type(), CV_32FC2); + EXPECT_EQ(mat_v32f4.type(), CV_32FC4); EXPECT_EQ(mat_c3.type(), CV_8UC3); EXPECT_EQ(mat_c4.type(), CV_8UC4); } diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc index 9ccaa632b..498c7831f 100644 --- a/mediapipe/framework/formats/image_opencv.cc +++ b/mediapipe/framework/formats/image_opencv.cc @@ -60,6 +60,9 @@ int GetMatType(const mediapipe::ImageFormat::Format format) { case mediapipe::ImageFormat::VEC32F2: type = CV_32FC2; break; + case mediapipe::ImageFormat::VEC32F4: + type = CV_32FC4; + break; case mediapipe::ImageFormat::LAB8: type = CV_8U; break; @@ -96,7 +99,7 @@ std::shared_ptr<cv::Mat> MatView(const mediapipe::Image* image) { image->image_format()))}; auto owner = std::make_shared(const_cast(image)); - uint8* data_ptr = owner->lock.Pixels(); + uint8_t* data_ptr = owner->lock.Pixels(); CHECK(data_ptr != nullptr); // Use Image to initialize in-place. Image still owns memory. if (steps[0] == sizes[1] * image->channels() * diff --git a/mediapipe/framework/formats/location_opencv.cc b/mediapipe/framework/formats/location_opencv.cc index de59633ca..6e15b299a 100644 --- a/mediapipe/framework/formats/location_opencv.cc +++ b/mediapipe/framework/formats/location_opencv.cc @@ -91,7 +91,7 @@ std::unique_ptr<cv::Mat> GetCvMask(const Location& location) { new cv::Mat(mask.height(), mask.width(), CV_8UC1, cv::Scalar(0))); for (const auto& interval : location_data.mask().rasterization().interval()) { for (int x = interval.left_x(); x <= interval.right_x(); ++x) { - mat->at<uint8>(interval.y(), x) = 255; + mat->at<uint8_t>(interval.y(), x) = 255; } } return mat; @@ -174,7 +174,7 @@ void EnlargeLocation(Location& location, const float factor) { } else { cv::erode(*mask, *mask, morph_element); } - CreateCvMaskLocation<uint8>(*mask).ConvertToProto(&location_data); + CreateCvMaskLocation<uint8_t>(*mask).ConvertToProto(&location_data); break; } } diff --git a/mediapipe/framework/formats/location_opencv_test.cc b/mediapipe/framework/formats/location_opencv_test.cc index 5740d2b17..6e3a89b58 100644 --- a/mediapipe/framework/formats/location_opencv_test.cc +++ b/mediapipe/framework/formats/location_opencv_test.cc @@ -25,8 +25,8 @@ namespace mediapipe { // segments per row.
static const int kWidth = 7; static const int kHeight = 3; -const std::vector kTestPatternVector = {0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, - 0, 0, 0, 1, 0, 1, 0, 1, 0, 0}; +const std::vector kTestPatternVector = { + 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0}; // Interval {y, x_start, x_end} representation of kTestPatternVector. const std::vector> kTestPatternIntervals = { @@ -67,8 +67,8 @@ TEST(LocationOpencvTest, CreateBBoxLocation) { } TEST(LocationOpencvTest, CreateCvMaskLocation) { - cv::Mat_ test_mask(kHeight, kWidth, - const_cast(kTestPatternVector.data())); + cv::Mat_ test_mask(kHeight, kWidth, + const_cast(kTestPatternVector.data())); Location location = CreateCvMaskLocation(test_mask); auto intervals = location.ConvertToProto().mask().rasterization().interval(); EXPECT_EQ(intervals.size(), kTestPatternIntervals.size()); @@ -157,8 +157,8 @@ TEST(LocationOpenCvTest, GetCvMask) { auto cv_mask = *GetCvMask(test_location); EXPECT_EQ(cv_mask.cols * cv_mask.rows, kTestPatternVector.size()); int flat_idx = 0; - for (auto it = cv_mask.begin(); it != cv_mask.end(); ++it) { - const uint8 expected_value = kTestPatternVector[flat_idx] == 0 ? 0 : 255; + for (auto it = cv_mask.begin(); it != cv_mask.end(); ++it) { + const uint8_t expected_value = kTestPatternVector[flat_idx] == 0 ? 0 : 255; EXPECT_EQ(*it, expected_value); flat_idx++; } diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index c9bb8b4ff..919b82406 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,23 +16,17 @@ # Description: # Working with dense optical flow in mediapipe. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -proto_library( +mediapipe_proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], ) -mediapipe_cc_proto_library( - name = "optical_flow_field_data_cc_proto", - srcs = ["optical_flow_field_data.proto"], - deps = [":optical_flow_field_data_proto"], -) - cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], diff --git a/mediapipe/framework/formats/motion/optical_flow_field.cc b/mediapipe/framework/formats/motion/optical_flow_field.cc index 1e6adef48..a96504192 100644 --- a/mediapipe/framework/formats/motion/optical_flow_field.cc +++ b/mediapipe/framework/formats/motion/optical_flow_field.cc @@ -66,12 +66,12 @@ cv::Mat MakeVisualizationHsv(const cv::Mat_& angles, cv::Mat hsv(angles.size(), CV_8UC3); for (int r = 0; r < hsv.rows; ++r) { for (int c = 0; c < hsv.cols; ++c) { - const uint8 hue = static_cast(255.0f * angles(r, c) / 360.0f); - uint8 saturation = 255; + const uint8_t hue = static_cast(255.0f * angles(r, c) / 360.0f); + uint8_t saturation = 255; if (magnitudes(r, c) < max_mag) { - saturation = static_cast(255.0f * magnitudes(r, c) / max_mag); + saturation = static_cast(255.0f * magnitudes(r, c) / max_mag); } - const uint8 value = 255; + const uint8_t value = 255; hsv.at(r, c) = cv::Vec3b(hue, saturation, value); } @@ -282,7 +282,7 @@ void OpticalFlowField::EstimateMotionConsistencyOcclusions( Location OpticalFlowField::FindMotionInconsistentPixels( const OpticalFlowField& forward, const OpticalFlowField& backward, double spatial_distance_threshold) { - const uint8 kOccludedPixelValue = 1; + const uint8_t kOccludedPixelValue = 1; const double 
threshold_sq = spatial_distance_threshold * spatial_distance_threshold; cv::Mat occluded = cv::Mat::zeros(forward.height(), forward.width(), CV_8UC1); @@ -301,10 +301,10 @@ Location OpticalFlowField::FindMotionInconsistentPixels( if (!in_bounds_in_next_frame || Point2_f(x - round_trip_x, y - round_trip_y).ToVector().Norm2() > threshold_sq) { - occluded.at(y, x) = kOccludedPixelValue; + occluded.at(y, x) = kOccludedPixelValue; } } } - return CreateCvMaskLocation(occluded); + return CreateCvMaskLocation(occluded); } } // namespace mediapipe diff --git a/mediapipe/framework/formats/motion/optical_flow_field_test.cc b/mediapipe/framework/formats/motion/optical_flow_field_test.cc index 521256c48..fdce418fa 100644 --- a/mediapipe/framework/formats/motion/optical_flow_field_test.cc +++ b/mediapipe/framework/formats/motion/optical_flow_field_test.cc @@ -300,15 +300,15 @@ TEST(OpticalFlowField, Occlusions) { for (int y = 0; y < occlusion_mat->rows; ++y) { // Bottom row and pixel at (x, y) = (1, 0) are occluded. if (y == occlusion_mat->rows - 1 || (x == 1 && y == 0)) { - EXPECT_GT(occlusion_mat->at(y, x), 0); + EXPECT_GT(occlusion_mat->at(y, x), 0); } else { - EXPECT_EQ(0, occlusion_mat->at(y, x)); + EXPECT_EQ(0, occlusion_mat->at(y, x)); } // Top row and pixel at (x, y) = (1, 2) are disoccluded. if (y == 0 || (x == 1 && y == 2)) { - EXPECT_GT(disocclusion_mat->at(y, x), 0); + EXPECT_GT(disocclusion_mat->at(y, x), 0); } else { - EXPECT_EQ(0, disocclusion_mat->at(y, x)); + EXPECT_EQ(0, disocclusion_mat->at(y, x)); } } } diff --git a/mediapipe/framework/scheduler_queue.cc b/mediapipe/framework/scheduler_queue.cc index efad97282..33214cf64 100644 --- a/mediapipe/framework/scheduler_queue.cc +++ b/mediapipe/framework/scheduler_queue.cc @@ -240,7 +240,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, // we should not run any more sources. Close the node if it is a source. if (shared_->stopping && node->IsSource()) { VLOG(4) << "Closing " << node->DebugName() << " due to StatusStop()."; - int64 start_time = shared_->timer.StartNode(); + int64_t start_time = shared_->timer.StartNode(); // It's OK to not reset/release the prepared CalculatorContext since a // source node always reuses the same CalculatorContext and Close() doesn't // access any inputs. @@ -256,7 +256,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, } else { // Note that we don't need a lock because only one thread can execute this // due to the lock on running_nodes. - int64 start_time = shared_->timer.StartNode(); + int64_t start_time = shared_->timer.StartNode(); const absl::Status result = node->ProcessNode(cc); shared_->timer.EndNode(start_time); @@ -283,7 +283,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, void SchedulerQueue::OpenCalculatorNode(CalculatorNode* node) { VLOG(3) << "Opening " << node->DebugName(); - int64 start_time = shared_->timer.StartNode(); + int64_t start_time = shared_->timer.StartNode(); const absl::Status result = node->OpenNode(); shared_->timer.EndNode(start_time); if (!result.ok()) { diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 68a9af52d..8b54ade8b 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,7 +13,7 @@ # limitations under the License. 
# -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) @@ -22,56 +22,32 @@ package( features = ["-layering_check"], ) -proto_library( +mediapipe_proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], + alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], + alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], + alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], deps = ["//mediapipe/framework:mediapipe_options_proto"], -) - -mediapipe_cc_proto_library( - name = "default_input_stream_handler_cc_proto", - srcs = ["default_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":default_input_stream_handler_proto"], -) - -mediapipe_cc_proto_library( - name = "fixed_size_input_stream_handler_cc_proto", - srcs = ["fixed_size_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":fixed_size_input_stream_handler_proto"], -) - -mediapipe_cc_proto_library( - name = "sync_set_input_stream_handler_cc_proto", - srcs = ["sync_set_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":sync_set_input_stream_handler_proto"], -) - -mediapipe_cc_proto_library( - name = "timestamp_align_input_stream_handler_cc_proto", - srcs = ["timestamp_align_input_stream_handler.proto"], - cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - deps = [":timestamp_align_input_stream_handler_proto"], + alwayslink = 1, ) cc_library( diff --git a/mediapipe/framework/tool/executor_util.cc b/mediapipe/framework/tool/executor_util.cc index 91089cc71..6d967768e 100644 --- a/mediapipe/framework/tool/executor_util.cc +++ b/mediapipe/framework/tool/executor_util.cc @@ -22,7 +22,7 @@ namespace mediapipe { namespace tool { -void EnsureMinimumDefaultExecutorStackSize(const int32 min_stack_size, +void EnsureMinimumDefaultExecutorStackSize(const int32_t min_stack_size, CalculatorGraphConfig* config) { mediapipe::ExecutorConfig* default_executor_config = nullptr; for (mediapipe::ExecutorConfig& executor_config : diff --git a/mediapipe/framework/tool/mediapipe_proto.bzl b/mediapipe/framework/tool/mediapipe_proto.bzl index 7ed87aba9..527774ff3 100644 --- a/mediapipe/framework/tool/mediapipe_proto.bzl +++ b/mediapipe/framework/tool/mediapipe_proto.bzl @@ -90,7 +90,6 @@ def mediapipe_proto_library_impl( visibility = visibility, testonly = testonly, compatible_with = compatible_with, - alwayslink = alwayslink, )) if def_cc_proto: diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc index 483b023b9..308932d4f 100644 --- a/mediapipe/framework/tool/options_field_util.cc +++ b/mediapipe/framework/tool/options_field_util.cc @@ -487,24 +487,24 @@ FieldData AsFieldData(const proto_ns::MessageLite& message) { // Represents a protobuf enum value stored in a Packet. 
struct ProtoEnum { - ProtoEnum(int32 v) : value(v) {} - int32 value; + ProtoEnum(int32_t v) : value(v) {} + int32_t value; }; absl::StatusOr AsPacket(const FieldData& data) { Packet result; switch (data.value_case()) { case FieldData::ValueCase::kInt32Value: - result = MakePacket(data.int32_value()); + result = MakePacket(data.int32_value()); break; case FieldData::ValueCase::kInt64Value: - result = MakePacket(data.int64_value()); + result = MakePacket(data.int64_value()); break; case FieldData::ValueCase::kUint32Value: - result = MakePacket(data.uint32_value()); + result = MakePacket(data.uint32_value()); break; case FieldData::ValueCase::kUint64Value: - result = MakePacket(data.uint64_value()); + result = MakePacket(data.uint64_value()); break; case FieldData::ValueCase::kDoubleValue: result = MakePacket(data.double_value()); @@ -538,11 +538,11 @@ absl::StatusOr AsPacket(const FieldData& data) { } absl::StatusOr AsFieldData(Packet packet) { - static const auto* kTypeIds = new std::map{ - {kTypeId, WireFormatLite::CPPTYPE_INT32}, - {kTypeId, WireFormatLite::CPPTYPE_INT64}, - {kTypeId, WireFormatLite::CPPTYPE_UINT32}, - {kTypeId, WireFormatLite::CPPTYPE_UINT64}, + static const auto* kTypeIds = new std::map{ + {kTypeId, WireFormatLite::CPPTYPE_INT32}, + {kTypeId, WireFormatLite::CPPTYPE_INT64}, + {kTypeId, WireFormatLite::CPPTYPE_UINT32}, + {kTypeId, WireFormatLite::CPPTYPE_UINT64}, {kTypeId, WireFormatLite::CPPTYPE_DOUBLE}, {kTypeId, WireFormatLite::CPPTYPE_FLOAT}, {kTypeId, WireFormatLite::CPPTYPE_BOOL}, @@ -566,16 +566,16 @@ absl::StatusOr AsFieldData(Packet packet) { switch (kTypeIds->at(packet.GetTypeId())) { case WireFormatLite::CPPTYPE_INT32: - result.set_int32_value(packet.Get()); + result.set_int32_value(packet.Get()); break; case WireFormatLite::CPPTYPE_INT64: - result.set_int64_value(packet.Get()); + result.set_int64_value(packet.Get()); break; case WireFormatLite::CPPTYPE_UINT32: - result.set_uint32_value(packet.Get()); + result.set_uint32_value(packet.Get()); break; case WireFormatLite::CPPTYPE_UINT64: - result.set_uint64_value(packet.Get()); + result.set_uint64_value(packet.Get()); break; case WireFormatLite::CPPTYPE_DOUBLE: result.set_double_value(packet.Get()); diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index a810ce129..745f4a13b 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -48,11 +48,11 @@ bool IsLengthDelimited(WireFormatLite::WireType wire_type) { } // Reads a single data value for a wire type. -absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in, +absl::Status ReadFieldValue(uint32_t tag, CodedInputStream* in, std::string* result) { WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (IsLengthDelimited(wire_type)) { - uint32 length; + uint32_t length; RET_CHECK_NO_LOG(in->ReadVarint32(&length)); RET_CHECK_NO_LOG(in->ReadString(result, length)); } else { @@ -72,10 +72,10 @@ absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in, absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, CodedInputStream* in, std::vector* field_values) { - uint32 data_size; + uint32_t data_size; RET_CHECK_NO_LOG(in->ReadVarint32(&data_size)); // fake_tag encodes the wire-type for calls to WireFormatLite::SkipField. 
- uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type); + uint32_t fake_tag = WireFormatLite::MakeTag(1, wire_type); while (data_size > 0) { std::string number; MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number)); @@ -88,10 +88,10 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. -absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, +absl::Status GetFieldValues(uint32_t field_id, CodedInputStream* in, CodedOutputStream* out, std::vector* field_values) { - uint32 tag; + uint32_t tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); @@ -112,10 +112,10 @@ absl::Status GetFieldValues(uint32 field_id, CodedInputStream* in, } // Injects the data value(s) for one field into a serialized message. -void SetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, +void SetFieldValues(uint32_t field_id, WireFormatLite::WireType wire_type, const std::vector& field_values, CodedOutputStream* out) { - uint32 tag = WireFormatLite::MakeTag(field_id, wire_type); + uint32_t tag = WireFormatLite::MakeTag(field_id, wire_type); for (const std::string& field_value : field_values) { out->WriteVarint32(tag); if (IsLengthDelimited(wire_type)) { @@ -125,7 +125,7 @@ void SetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, } } -FieldAccess::FieldAccess(uint32 field_id, FieldType field_type) +FieldAccess::FieldAccess(uint32_t field_id, FieldType field_type) : field_id_(field_id), field_type_(field_type) {} absl::Status FieldAccess::SetMessage(const std::string& message) { @@ -397,11 +397,11 @@ static absl::Status DeserializeValue(const FieldValue& bytes, case W::TYPE_UINT64: return ReadPrimitive(&input, result); case W::TYPE_INT32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_FIXED64: return ReadPrimitive(&input, result); case W::TYPE_FIXED32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_BOOL: return ReadPrimitive(&input, result); case W::TYPE_BYTES: @@ -413,15 +413,15 @@ static absl::Status DeserializeValue(const FieldValue& bytes, case W::TYPE_MESSAGE: CHECK(false) << "DeserializeValue cannot deserialize a Message."; case W::TYPE_UINT32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_ENUM: return ReadPrimitive(&input, result); case W::TYPE_SFIXED32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_SFIXED64: return ReadPrimitive(&input, result); case W::TYPE_SINT32: - return ReadPrimitive(&input, result); + return ReadPrimitive(&input, result); case W::TYPE_SINT64: return ReadPrimitive(&input, result); } @@ -523,27 +523,27 @@ absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, switch (field_type) { case WireFormatLite::TYPE_INT32: result->set_int32_value( - ReadValue(field_bytes, &status)); + ReadValue(field_bytes, &status)); break; case WireFormatLite::TYPE_SINT32: - result->set_int32_value( - ReadValue(field_bytes, &status)); + result->set_int32_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_INT64: result->set_int64_value( - ReadValue(field_bytes, &status)); + ReadValue(field_bytes, &status)); break; case WireFormatLite::TYPE_SINT64: - result->set_int64_value( 
- ReadValue(field_bytes, &status)); + result->set_int64_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_UINT32: - result->set_uint32_value( - ReadValue(field_bytes, &status)); + result->set_uint32_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_UINT64: - result->set_uint64_value( - ReadValue(field_bytes, &status)); + result->set_uint64_value(ReadValue( + field_bytes, &status)); break; case WireFormatLite::TYPE_DOUBLE: result->set_double_value( @@ -559,7 +559,7 @@ absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, break; case WireFormatLite::TYPE_ENUM: result->set_enum_value( - ReadValue(field_bytes, &status)); + ReadValue(field_bytes, &status)); break; case WireFormatLite::TYPE_STRING: result->set_string_value(std::string(field_bytes)); diff --git a/mediapipe/framework/tool/simulation_clock_test.cc b/mediapipe/framework/tool/simulation_clock_test.cc index 3f2c3615c..c4c76e37e 100644 --- a/mediapipe/framework/tool/simulation_clock_test.cc +++ b/mediapipe/framework/tool/simulation_clock_test.cc @@ -99,17 +99,17 @@ class SimulationClockTest : public ::testing::Test { void SetupRealClock() { clock_ = mediapipe::Clock::RealClock(); } // Return the values of the timestamps of a vector of Packets. - static std::vector TimestampValues( + static std::vector TimestampValues( const std::vector& packets) { - std::vector result; + std::vector result; for (const Packet& p : packets) { result.push_back(p.Timestamp().Value()); } return result; } - static std::vector TimeValues(const std::vector& times) { - std::vector result; + static std::vector TimeValues(const std::vector& times) { + std::vector result; for (const absl::Time& t : times) { result.push_back(absl::ToUnixMicros(t)); } @@ -225,9 +225,9 @@ TEST_F(SimulationClockTest, InFlight) { // Add 10 input packets to the graph, one each 10 ms, starting after 11 ms // of clock time. Timestamps lag clock times by 1 ms. clock_->Sleep(absl::Microseconds(11000)); - for (uint64 ts = 10000; ts <= 100000; ts += 10000) { + for (uint64_t ts = 10000; ts <= 100000; ts += 10000) { MP_EXPECT_OK(graph_.AddPacketToInputStream( - "input_packets_0", MakePacket(ts).At(Timestamp(ts)))); + "input_packets_0", MakePacket(ts).At(Timestamp(ts)))); clock_->Sleep(absl::Microseconds(10000)); } @@ -266,7 +266,7 @@ TEST_F(SimulationClockTest, DestroyClock) { clock_->Sleep(absl::Microseconds(20000)); if (++input_count < 4) { outputs->Index(0).AddPacket( - MakePacket(input_count).At(Timestamp(input_count))); + MakePacket(input_count).At(Timestamp(input_count))); return absl::OkStatus(); } else { return tool::StatusStop(); diff --git a/mediapipe/framework/tool/sink.h b/mediapipe/framework/tool/sink.h index d659115ee..f786e60a7 100644 --- a/mediapipe/framework/tool/sink.h +++ b/mediapipe/framework/tool/sink.h @@ -62,9 +62,9 @@ namespace tool { // Example usage: // CalculatorGraphConfig config = tool::ParseGraphFromFileOrDie("config.txt"); // std::vector packet_dump; -// tool::AddVectorSink("output_samples", &config, &packet_dump, -// /*use_std_function=*/true); -// // Call tool::AddVectorSink() more times if you wish. +// tool::AddVectorSink("output_samples", &config, &packet_dump); +// // Call tool::AddVectorSink() more times if you wish. Note that each stream +// // needs to get its own packet vector. // CalculatorGraph graph; // CHECK_OK(graph.Initialize(config)); // // Set other input side packets. 
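// Illustrative aside, not part of the patch: the sink.h comment above now
// shows AddVectorSink() without the use_std_function argument and notes that
// every tapped stream needs its own packet vector. A hedged fragment
// mirroring that comment (headers, the "output_labels" stream name, and the
// graph run itself are assumed for illustration):
CalculatorGraphConfig config = tool::ParseGraphFromFileOrDie("config.txt");
std::vector<Packet> samples_dump;
std::vector<Packet> labels_dump;  // One vector per tapped stream.
tool::AddVectorSink("output_samples", &config, &samples_dump);
tool::AddVectorSink("output_labels", &config, &labels_dump);
CalculatorGraph graph;
CHECK_OK(graph.Initialize(config));
// ... add inputs and run the graph, then inspect samples_dump / labels_dump.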
diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index b20979b10..08cc4ab5a 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -144,7 +144,7 @@ void RunTestContainer(CalculatorGraphConfig supergraph, if (!send_bounds) { // Send enable == true signal at 5000 us. - const int64 enable_ts = 5000; + const int64_t enable_ts = 5000; MP_EXPECT_OK(graph.AddPacketToInputStream( "enable", MakePacket(true).At(Timestamp(enable_ts)))); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -152,7 +152,7 @@ void RunTestContainer(CalculatorGraphConfig supergraph, const int packet_count = 10; // Send int value packets at {10K, 20K, 30K, ..., 100K}. - for (uint64 t = 1; t <= packet_count; ++t) { + for (uint64_t t = 1; t <= packet_count; ++t) { if (send_bounds) { MP_EXPECT_OK(graph.AddPacketToInputStream( "enable", MakePacket(true).At(Timestamp(t * 10000)))); @@ -180,7 +180,7 @@ void RunTestContainer(CalculatorGraphConfig supergraph, } // Send int value packets at {110K, 120K, ..., 200K}. - for (uint64 t = 11; t <= packet_count * 2; ++t) { + for (uint64_t t = 11; t <= packet_count * 2; ++t) { if (send_bounds) { MP_EXPECT_OK(graph.AddPacketToInputStream( "enable", MakePacket(false).At(Timestamp(t * 10000)))); diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index e26275387..f012ac418 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -511,7 +511,7 @@ class TemplateParser::Parser::ParserImpl { DO(ConsumeIdentifier(&field_name)); if (allow_field_number_) { - int32 field_number = std::atoi(field_name.c_str()); // NOLINT + int32_t field_number = std::atoi(field_name.c_str()); // NOLINT if (descriptor->IsExtensionNumber(field_number)) { field = reflection->FindKnownExtensionByNumber(field_number); } else if (descriptor->IsReservedNumber(field_number)) { @@ -765,28 +765,28 @@ class TemplateParser::Parser::ParserImpl { switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: { - int64 value; + int64_t value; DO(ConsumeSignedInteger(&value, kint32max)); - SET_FIELD(Int32, static_cast(value)); + SET_FIELD(Int32, static_cast(value)); break; } case FieldDescriptor::CPPTYPE_UINT32: { - uint64 value; + uint64_t value; DO(ConsumeUnsignedInteger(&value, kuint32max)); - SET_FIELD(UInt32, static_cast(value)); + SET_FIELD(UInt32, static_cast(value)); break; } case FieldDescriptor::CPPTYPE_INT64: { - int64 value; + int64_t value; DO(ConsumeSignedInteger(&value, kint64max)); SET_FIELD(Int64, value); break; } case FieldDescriptor::CPPTYPE_UINT64: { - uint64 value; + uint64_t value; DO(ConsumeUnsignedInteger(&value, kuint64max)); SET_FIELD(UInt64, value); break; @@ -815,7 +815,7 @@ class TemplateParser::Parser::ParserImpl { case FieldDescriptor::CPPTYPE_BOOL: { if (LookingAtType(io::Tokenizer::TYPE_INTEGER)) { - uint64 value; + uint64_t value; DO(ConsumeUnsignedInteger(&value, 1)); SET_FIELD(Bool, value); } else { @@ -836,7 +836,7 @@ class TemplateParser::Parser::ParserImpl { case FieldDescriptor::CPPTYPE_ENUM: { std::string value; - int64 int_value = kint64max; + int64_t int_value = kint64max; const EnumDescriptor* enum_type = field->enum_type(); const EnumValueDescriptor* enum_value = NULL; @@ -1037,7 +1037,7 @@ class TemplateParser::Parser::ParserImpl { // Consumes a uint64 and saves its value in the value parameter. // Returns false if the token is not of type INTEGER. 
- bool ConsumeUnsignedInteger(uint64* value, uint64 max_value) { + bool ConsumeUnsignedInteger(uint64_t* value, uint64_t max_value) { if (!LookingAtType(io::Tokenizer::TYPE_INTEGER)) { ReportError("Expected integer, got: " + tokenizer_.current().text); return false; @@ -1058,7 +1058,7 @@ class TemplateParser::Parser::ParserImpl { // we actually may consume an additional token (for the minus sign) in this // method. Returns false if the token is not an integer // (signed or otherwise). - bool ConsumeSignedInteger(int64* value, uint64 max_value) { + bool ConsumeSignedInteger(int64_t* value, uint64_t max_value) { bool negative = false; #ifndef PROTO2_OPENSOURCE if (absl::StartsWith(tokenizer_.current().text, "0x")) { @@ -1075,18 +1075,18 @@ class TemplateParser::Parser::ParserImpl { ++max_value; } - uint64 unsigned_value; + uint64_t unsigned_value; DO(ConsumeUnsignedInteger(&unsigned_value, max_value)); if (negative) { - if ((static_cast(kint64max) + 1) == unsigned_value) { + if ((static_cast(kint64max) + 1) == unsigned_value) { *value = kint64min; } else { - *value = -static_cast(unsigned_value); + *value = -static_cast(unsigned_value); } } else { - *value = static_cast(unsigned_value); + *value = static_cast(unsigned_value); } return true; @@ -1094,7 +1094,7 @@ class TemplateParser::Parser::ParserImpl { // Consumes a uint64 and saves its value in the value parameter. // Accepts decimal numbers only, rejects hex or oct numbers. - bool ConsumeUnsignedDecimalInteger(uint64* value, uint64 max_value) { + bool ConsumeUnsignedDecimalInteger(uint64_t* value, uint64_t max_value) { if (!LookingAtType(io::Tokenizer::TYPE_INTEGER)) { ReportError("Expected integer, got: " + tokenizer_.current().text); return false; @@ -1131,7 +1131,7 @@ class TemplateParser::Parser::ParserImpl { // Therefore, we must check both cases here. if (LookingAtType(io::Tokenizer::TYPE_INTEGER)) { // We have found an integer value for the double. 
- uint64 integer_value; + uint64_t integer_value; DO(ConsumeUnsignedDecimalInteger(&integer_value, kuint64max)); *value = static_cast(integer_value); diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index e8b02084b..5642941e9 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -182,15 +182,16 @@ absl::Status CompareImageFrames(const ImageFrame& image1, case ImageFormat::SRGB: case ImageFormat::SRGBA: case ImageFormat::LAB8: - return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, - max_avg_diff, diff_image); + return CompareDiff(image1, image2, max_color_diff, + max_alpha_diff, max_avg_diff, diff_image); case ImageFormat::GRAY16: case ImageFormat::SRGB48: case ImageFormat::SRGBA64: - return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, - max_avg_diff, diff_image); + return CompareDiff(image1, image2, max_color_diff, + max_alpha_diff, max_avg_diff, diff_image); case ImageFormat::VEC32F1: case ImageFormat::VEC32F2: + case ImageFormat::VEC32F4: return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, max_avg_diff, diff_image); default: @@ -350,17 +351,17 @@ std::unique_ptr GenerateLuminanceImage( auto luminance_image = absl::make_unique(original_image.Format(), width, height, ImageFrame::kGlDefaultAlignmentBoundary); - const uint8* pixel1 = original_image.PixelData(); - uint8* pixel2 = luminance_image->MutablePixelData(); + const uint8_t* pixel1 = original_image.PixelData(); + uint8_t* pixel2 = luminance_image->MutablePixelData(); const int width_padding1 = original_image.WidthStep() - width * channels; const int width_padding2 = luminance_image->WidthStep() - width * channels; for (int row = 0; row < height; ++row) { for (int col = 0; col < width; ++col) { float luminance = pixel1[0] * 0.2125f + pixel1[1] * 0.7154f + pixel1[2] * 0.0721f; - uint8 luminance_byte = 255; + uint8_t luminance_byte = 255; if (luminance < 255.0f) { - luminance_byte = static_cast(luminance); + luminance_byte = static_cast(luminance); } pixel2[0] = luminance_byte; pixel2[1] = luminance_byte; diff --git a/mediapipe/framework/tool/validate_name.cc b/mediapipe/framework/tool/validate_name.cc index bea857dd4..ad66b43d8 100644 --- a/mediapipe/framework/tool/validate_name.cc +++ b/mediapipe/framework/tool/validate_name.cc @@ -185,7 +185,7 @@ absl::Status ParseTagIndexName(const std::string& tag_index_name, tag_status = ValidateTag(v[0]); number_status = ValidateNumber(v[1]); if (number_status.ok()) { - int64 index64; + int64_t index64; RET_CHECK(absl::SimpleAtoi(v[1], &index64)); RET_CHECK_LE(index64, internal::kMaxCollectionItemId); the_index = index64; @@ -227,7 +227,7 @@ absl::Status ParseTagIndex(const std::string& tag_index, std::string* tag, } number_status = ValidateNumber(v[1]); if (number_status.ok()) { - int64 index64; + int64_t index64; RET_CHECK(absl::SimpleAtoi(v[1], &index64)); RET_CHECK_LE(index64, internal::kMaxCollectionItemId); the_index = index64; diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 2f7f7ec33..c785e5624 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -15,7 +15,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("//mediapipe/gpu:metal.bzl", "metal_library") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") 
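The test_util.cc hunk above migrates GenerateLuminanceImage to uint8_t while keeping the Rec. 709-style luma weights (0.2125, 0.7154, 0.0721). A standalone sketch of the same per-pixel computation, assuming tightly packed 3-channel RGB with no row padding (the real code also walks the WidthStep() padding):

#include <cstddef>
#include <cstdint>
#include <vector>

// Writes the luminance of each RGB pixel into all three output channels,
// mirroring the arithmetic in GenerateLuminanceImage.
std::vector<uint8_t> ToLuminance(const std::vector<uint8_t>& rgb) {
  std::vector<uint8_t> out(rgb.size());
  for (size_t i = 0; i + 2 < rgb.size(); i += 3) {
    const float luminance =
        rgb[i] * 0.2125f + rgb[i + 1] * 0.7154f + rgb[i + 2] * 0.0721f;
    const uint8_t luminance_byte =
        luminance < 255.0f ? static_cast<uint8_t>(luminance) : 255;
    out[i] = out[i + 1] = out[i + 2] = luminance_byte;
  }
  return out;
}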
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") load("//mediapipe/framework:more_selects.bzl", "more_selects") @@ -423,12 +423,15 @@ cc_library( cc_library( name = "gpu_buffer_storage_image_frame", + srcs = ["gpu_buffer_storage_image_frame.cc"], hdrs = ["gpu_buffer_storage_image_frame.h"], visibility = ["//visibility:public"], deps = [ + ":frame_buffer_view", ":gpu_buffer_format", ":gpu_buffer_storage", ":image_frame_view", + "//mediapipe/framework/formats:frame_buffer", "//mediapipe/framework/formats:image_frame", ], ) @@ -555,7 +558,10 @@ mediapipe_proto_library( name = "gl_context_options_proto", srcs = ["gl_context_options.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) # This is a hack needed to work around some issues with strict hdrs_check. @@ -929,6 +935,7 @@ mediapipe_proto_library( srcs = ["gl_animation_overlay_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) @@ -939,6 +946,7 @@ mediapipe_proto_library( visibility = ["//visibility:public"], deps = [ ":scale_mode_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) @@ -982,26 +990,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "gl_surface_sink_calculator_proto", srcs = ["gl_surface_sink_calculator.proto"], deps = [ ":scale_mode_proto", + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "gl_surface_sink_calculator_cc_proto", - srcs = ["gl_surface_sink_calculator.proto"], - cc_deps = [ - ":scale_mode_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":gl_surface_sink_calculator_proto"], -) - ### Metal calculators metal_library( @@ -1017,21 +1015,14 @@ objc_library( deps = [":simple_shaders_mtl"], ) -proto_library( +mediapipe_proto_library( name = "copy_calculator_proto", srcs = ["copy_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "copy_calculator_cc_proto", - srcs = ["copy_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", ], - visibility = ["//visibility:public"], - deps = [":copy_calculator_proto"], ) objc_library( diff --git a/mediapipe/gpu/frame_buffer_view.h b/mediapipe/gpu/frame_buffer_view.h new file mode 100644 index 000000000..76d773a5e --- /dev/null +++ b/mediapipe/gpu/frame_buffer_view.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_GPU_FRAME_BUFFER_VIEW_H_ +#define MEDIAPIPE_GPU_FRAME_BUFFER_VIEW_H_ + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" + +namespace mediapipe { +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual std::shared_ptr GetReadView( + types) const = 0; + virtual std::shared_ptr GetWriteView(types) = 0; +}; + +} // namespace internal +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_FRAME_BUFFER_VIEW_H_ diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 8e2e3858e..a820f04d6 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -204,6 +204,8 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { return ImageFormat::SRGB; case GpuBufferFormat::kTwoComponentFloat32: return ImageFormat::VEC32F2; + case GpuBufferFormat::kRGBAFloat128: + return ImageFormat::VEC32F4; case GpuBufferFormat::kRGBA32: // TODO: this likely maps to ImageFormat::SRGBA case GpuBufferFormat::kGrayHalf16: @@ -211,7 +213,6 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kRGBAHalf64: - case GpuBufferFormat::kRGBAFloat128: case GpuBufferFormat::kNV12: case GpuBufferFormat::kNV21: case GpuBufferFormat::kI420: @@ -232,6 +233,8 @@ GpuBufferFormat GpuBufferFormatForImageFormat(ImageFormat::Format format) { return GpuBufferFormat::kGrayFloat32; case ImageFormat::VEC32F2: return GpuBufferFormat::kTwoComponentFloat32; + case ImageFormat::VEC32F4: + return GpuBufferFormat::kRGBAFloat128; case ImageFormat::GRAY8: return GpuBufferFormat::kOneComponent8; case ImageFormat::YCBCR420P: diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.cc b/mediapipe/gpu/gpu_buffer_storage_image_frame.cc new file mode 100644 index 000000000..1cd661d37 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.cc @@ -0,0 +1,71 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" + +#include +#include + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/formats/image_frame.h" + +namespace mediapipe { + +namespace { + +FrameBuffer::Format FrameBufferFormatForImageFrameFormat( + ImageFormat::Format format) { + switch (format) { + case ImageFormat::SRGB: + return FrameBuffer::Format::kRGB; + case ImageFormat::SRGBA: + return FrameBuffer::Format::kRGBA; + case ImageFormat::GRAY8: + return FrameBuffer::Format::kGRAY; + default: + return FrameBuffer::Format::kUNKNOWN; + } +} + +std::shared_ptr ImageFrameToFrameBuffer( + std::shared_ptr image_frame) { + FrameBuffer::Format format = + FrameBufferFormatForImageFrameFormat(image_frame->Format()); + CHECK(format != FrameBuffer::Format::kUNKNOWN) + << "Invalid format. Only SRGB, SRGBA and GRAY8 are supported."; + const FrameBuffer::Dimension dimension{/*width=*/image_frame->Width(), + /*height=*/image_frame->Height()}; + const FrameBuffer::Stride stride{ + /*row_stride_bytes=*/image_frame->WidthStep(), + /*pixel_stride_bytes=*/image_frame->ByteDepth() * + image_frame->NumberOfChannels()}; + const std::vector planes{ + {image_frame->MutablePixelData(), stride}}; + return std::make_shared(planes, dimension, format); +} + +} // namespace + +std::shared_ptr GpuBufferStorageImageFrame::GetReadView( + internal::types) const { + return ImageFrameToFrameBuffer(image_frame_); +} + +std::shared_ptr GpuBufferStorageImageFrame::GetWriteView( + internal::types) { + return ImageFrameToFrameBuffer(image_frame_); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index ab547b9ea..542791f98 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -1,9 +1,26 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_IMAGE_FRAME_H_ #include +#include "mediapipe/framework/formats/frame_buffer.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/frame_buffer_view.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_storage.h" #include "mediapipe/gpu/image_frame_view.h" @@ -13,7 +30,8 @@ namespace mediapipe { // Implements support for ImageFrame as a backing storage of GpuBuffer. 
class GpuBufferStorageImageFrame : public internal::GpuBufferStorageImpl< - GpuBufferStorageImageFrame, internal::ViewProvider> { + GpuBufferStorageImageFrame, internal::ViewProvider, + internal::ViewProvider> { public: explicit GpuBufferStorageImageFrame(std::shared_ptr image_frame) : image_frame_(image_frame) {} @@ -36,6 +54,10 @@ class GpuBufferStorageImageFrame internal::types) override { return image_frame_; } + std::shared_ptr GetReadView( + internal::types) const override; + std::shared_ptr GetWriteView( + internal::types) override; private: std::shared_ptr image_frame_; diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc new file mode 100644 index 000000000..c7acd1340 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc @@ -0,0 +1,228 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/gpu/gpu_buffer_storage_yuv_image.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "libyuv/video_common.h" +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/yuv_image.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/util/frame_buffer/frame_buffer_util.h" + +namespace mediapipe { + +namespace { + +// Default data alignment. +constexpr int kDefaultDataAligment = 16; + +GpuBufferFormat GpuBufferFormatForFourCC(libyuv::FourCC fourcc) { + switch (fourcc) { + case libyuv::FOURCC_NV12: + return GpuBufferFormat::kNV12; + case libyuv::FOURCC_NV21: + return GpuBufferFormat::kNV21; + case libyuv::FOURCC_YV12: + return GpuBufferFormat::kYV12; + case libyuv::FOURCC_I420: + return GpuBufferFormat::kI420; + default: + return GpuBufferFormat::kUnknown; + } +} + +libyuv::FourCC FourCCForGpuBufferFormat(GpuBufferFormat format) { + switch (format) { + case GpuBufferFormat::kNV12: + return libyuv::FOURCC_NV12; + case GpuBufferFormat::kNV21: + return libyuv::FOURCC_NV21; + case GpuBufferFormat::kYV12: + return libyuv::FOURCC_YV12; + case GpuBufferFormat::kI420: + return libyuv::FOURCC_I420; + default: + return libyuv::FOURCC_ANY; + } +} + +FrameBuffer::Format FrameBufferFormatForFourCC(libyuv::FourCC fourcc) { + switch (fourcc) { + case libyuv::FOURCC_NV12: + return FrameBuffer::Format::kNV12; + case libyuv::FOURCC_NV21: + return FrameBuffer::Format::kNV21; + case libyuv::FOURCC_YV12: + return FrameBuffer::Format::kYV12; + case libyuv::FOURCC_I420: + return FrameBuffer::Format::kYV21; + default: + return FrameBuffer::Format::kUNKNOWN; + } +} + +// Converts a YuvImage into a FrameBuffer that shares the same data buffers. 
+std::shared_ptr YuvImageToFrameBuffer( + std::shared_ptr yuv_image) { + FrameBuffer::Format format = FrameBufferFormatForFourCC(yuv_image->fourcc()); + FrameBuffer::Dimension dimension{/*width=*/yuv_image->width(), + /*height=*/yuv_image->height()}; + std::vector planes; + CHECK(yuv_image->mutable_data(0) != nullptr && yuv_image->stride(0) > 0) + << "Invalid YuvImage. Expected plane at index 0 to be non-null and have " + "stride > 0."; + planes.emplace_back( + yuv_image->mutable_data(0), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(0), + /*pixel_stride_bytes=*/1}); + switch (format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: { + CHECK(yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0) + << "Invalid YuvImage. Expected plane at index 1 to be non-null and " + "have stride > 0."; + planes.emplace_back( + yuv_image->mutable_data(1), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(1), + /*pixel_stride_bytes=*/2}); + break; + } + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + CHECK(yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0 && + yuv_image->mutable_data(2) != nullptr && yuv_image->stride(2) > 0) + << "Invalid YuvImage. Expected planes at indices 1 and 2 to be " + "non-null and have stride > 0."; + planes.emplace_back( + yuv_image->mutable_data(1), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(1), + /*pixel_stride_bytes=*/1}); + planes.emplace_back( + yuv_image->mutable_data(2), + FrameBuffer::Stride{/*row_stride_bytes=*/yuv_image->stride(2), + /*pixel_stride_bytes=*/1}); + break; + } + default: + LOG(FATAL) + << "Invalid format. Only FOURCC_NV12, FOURCC_NV21, FOURCC_YV12 and " + "FOURCC_I420 are supported."; + } + return std::make_shared(planes, dimension, format); +} + +// Converts a YUVImage into an ImageFrame with ImageFormat::SRGB format. +// Note that this requires YUV -> RGB conversion. +std::shared_ptr YuvImageToImageFrame( + std::shared_ptr yuv_image) { + auto yuv_buffer = YuvImageToFrameBuffer(yuv_image); + // Allocate the RGB ImageFrame to return. + auto image_frame = std::make_shared( + ImageFormat::SRGB, yuv_buffer->dimension().width, + yuv_buffer->dimension().height); + // Wrap it into a FrameBuffer + std::vector planes{ + {image_frame->MutablePixelData(), + {/*row_stride_bytes=*/image_frame->WidthStep(), + /*pixel_stride_bytes=*/image_frame->NumberOfChannels() * + image_frame->ChannelSize()}}}; + auto rgb_buffer = + FrameBuffer(planes, yuv_buffer->dimension(), FrameBuffer::Format::kRGB); + // Convert. + CHECK_OK(frame_buffer::Convert(*yuv_buffer, &rgb_buffer)); + return image_frame; +} + +} // namespace + +GpuBufferStorageYuvImage::GpuBufferStorageYuvImage( + std::shared_ptr yuv_image) { + CHECK(GpuBufferFormatForFourCC(yuv_image->fourcc()) != + GpuBufferFormat::kUnknown) + << "Invalid format. Only FOURCC_NV12, FOURCC_NV21, FOURCC_YV12 and " + "FOURCC_I420 are supported."; + yuv_image_ = yuv_image; +} + +GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, + GpuBufferFormat format) { + libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format); + int y_stride = std::ceil(1.0f * width / kDefaultDataAligment); + auto y_data = std::make_unique(y_stride * height); + switch (fourcc) { + case libyuv::FOURCC_NV12: + case libyuv::FOURCC_NV21: { + // Interleaved U/V planes, 2x2 downsampling. 
+ int uv_width = 2 * std::ceil(0.5f * width); + int uv_height = std::ceil(0.5f * height); + int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); + auto uv_data = std::make_unique(uv_stride * uv_height); + yuv_image_ = std::make_shared( + fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride, + nullptr, 0, width, height); + break; + } + case libyuv::FOURCC_YV12: + case libyuv::FOURCC_I420: { + // Non-interleaved U/V planes, 2x2 downsampling. + int uv_width = std::ceil(0.5f * width); + int uv_height = std::ceil(0.5f * height); + int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); + auto u_data = std::make_unique(uv_stride * uv_height); + auto v_data = std::make_unique(uv_stride * uv_height); + yuv_image_ = std::make_shared( + fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride, + std::move(v_data), uv_stride, width, height); + break; + } + default: + LOG(FATAL) + << "Invalid format. Only kNV12, kNV21, kYV12 and kYV21 are supported"; + } +} + +GpuBufferFormat GpuBufferStorageYuvImage::format() const { + return GpuBufferFormatForFourCC(yuv_image_->fourcc()); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetReadView( + internal::types) const { + return YuvImageToFrameBuffer(yuv_image_); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetWriteView( + internal::types) { + return YuvImageToFrameBuffer(yuv_image_); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetReadView( + internal::types) const { + return YuvImageToImageFrame(yuv_image_); +} + +std::shared_ptr GpuBufferStorageYuvImage::GetWriteView( + internal::types) { + // Not supported on purpose: writes into the resulting ImageFrame cannot + // easily be ported back to the original YUV image. + LOG(FATAL) << "GetWriteView is not supported."; +} +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.h b/mediapipe/gpu/gpu_buffer_storage_yuv_image.h new file mode 100644 index 000000000..6b34f4948 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.h @@ -0,0 +1,84 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/yuv_image.h" +#include "mediapipe/gpu/frame_buffer_view.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/gpu_buffer_storage.h" +#include "mediapipe/gpu/image_frame_view.h" + +#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_YUV_IMAGE_H_ +#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_YUV_IMAGE_H_ + +namespace mediapipe { + +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual std::shared_ptr GetReadView( + types) const = 0; + virtual std::shared_ptr GetWriteView(types) = 0; +}; + +} // namespace internal + +// TODO: add support for I444. 
+class GpuBufferStorageYuvImage + : public internal::GpuBufferStorageImpl< + GpuBufferStorageYuvImage, internal::ViewProvider, + internal::ViewProvider, + internal::ViewProvider> { + public: + // Constructor from an existing YUVImage with FOURCC_NV12, FOURCC_NV21, + // FOURCC_YV12 or FOURCC_I420 format. + explicit GpuBufferStorageYuvImage(std::shared_ptr yuv_image); + // Constructor. Supported formats are kNV12, kNV21, kYV12 and kI420. + // Stride is set by default so that row boundaries align to 16 bytes. + GpuBufferStorageYuvImage(int width, int height, GpuBufferFormat format); + + int width() const override { return yuv_image_->width(); } + int height() const override { return yuv_image_->height(); } + GpuBufferFormat format() const override; + + std::shared_ptr GetReadView( + internal::types) const override { + return yuv_image_; + } + std::shared_ptr GetWriteView(internal::types) override { + return yuv_image_; + } + + std::shared_ptr GetReadView( + internal::types) const override; + std::shared_ptr GetWriteView( + internal::types) override; + std::shared_ptr GetReadView( + internal::types) const override; + std::shared_ptr GetWriteView( + internal::types) override; + + private: + std::shared_ptr yuv_image_; +}; +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_YUV_IMAGE_H_ diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 49e9cf22a..f542f0bb2 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -119,7 +119,7 @@ GpuResources::~GpuResources() { extern const GraphService kGpuService; absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { - CHECK(ContainsKey(node->Contract().ServiceRequests(), kGpuService.key)); + CHECK(node->Contract().ServiceRequests().contains(kGpuService.key)); std::string node_id = node->GetCalculatorState().NodeName(); std::string node_type = node->GetCalculatorState().CalculatorType(); std::string context_key; diff --git a/mediapipe/graphs/iris_tracking/calculators/BUILD b/mediapipe/graphs/iris_tracking/calculators/BUILD index f5124b464..9ddce7f36 100644 --- a/mediapipe/graphs/iris_tracking/calculators/BUILD +++ b/mediapipe/graphs/iris_tracking/calculators/BUILD @@ -12,33 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
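The GpuBufferStorageYuvImage constructor above allocates either a semi-planar (NV12/NV21) or planar (YV12/I420) layout with 2x2 chroma subsampling, and its header comment states that row strides default to 16-byte alignment. A small sketch of the plane-size arithmetic this implies; the struct and function names here are illustrative, not part of MediaPipe:

#include <cstdint>

constexpr int kRowAlignment = 16;

// Rounds a row width up to the next multiple of the alignment.
constexpr int AlignUp(int value, int alignment) {
  return (value + alignment - 1) / alignment * alignment;
}

struct YuvPlaneSizes {
  int y_stride, y_bytes;
  int uv_stride, uv_bytes;  // Per chroma plane; NV12/NV21 have a single one.
};

YuvPlaneSizes ComputePlaneSizes(int width, int height, bool interleaved_uv) {
  YuvPlaneSizes sizes;
  sizes.y_stride = AlignUp(width, kRowAlignment);
  sizes.y_bytes = sizes.y_stride * height;
  const int uv_height = (height + 1) / 2;
  // Interleaved U/V doubles the row width; planar U and V planes are each
  // half-width.
  const int uv_width =
      interleaved_uv ? 2 * ((width + 1) / 2) : (width + 1) / 2;
  sizes.uv_stride = AlignUp(uv_width, kRowAlignment);
  sizes.uv_bytes = sizes.uv_stride * uv_height;
  return sizes;
}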
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -proto_library( +mediapipe_proto_library( name = "iris_to_render_data_calculator_proto", srcs = ["iris_to_render_data_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/util:color_proto", "//mediapipe/util:render_data_proto", ], ) -mediapipe_cc_proto_library( - name = "iris_to_render_data_calculator_cc_proto", - srcs = ["iris_to_render_data_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/util:color_cc_proto", - "//mediapipe/util:render_data_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":iris_to_render_data_calculator_proto"], -) - cc_library( name = "iris_to_render_data_calculator", srcs = ["iris_to_render_data_calculator.cc"], @@ -56,25 +45,16 @@ cc_library( alwayslink = 1, ) -proto_library( +mediapipe_proto_library( name = "iris_to_depth_calculator_proto", srcs = ["iris_to_depth_calculator.proto"], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], ) -mediapipe_cc_proto_library( - name = "iris_to_depth_calculator_cc_proto", - srcs = ["iris_to_depth_calculator.proto"], - cc_deps = [ - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":iris_to_depth_calculator_proto"], -) - cc_library( name = "iris_to_depth_calculator", srcs = ["iris_to_depth_calculator.cc"], diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 23bd553af..d565187d9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -578,7 +578,7 @@ mediapipe::GpuResources* Graph::GetGpuResources() const { } #endif // !MEDIAPIPE_DISABLE_GPU -absl::Status Graph::SetParentGlContext(int64 java_gl_context) { +absl::Status Graph::SetParentGlContext(int64_t java_gl_context) { #if MEDIAPIPE_DISABLE_GPU LOG(FATAL) << "GPU support has been disabled in this build!"; #else diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 46ea1ce41..f7430e6e8 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -132,7 +132,8 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, // code might expect to be able to overwrite the buffer after creating an // ImageFrame from it. 
image_frame->CopyPixelData( - format, width, height, width_step, static_cast(buffer_data), + format, width, height, width_step, + static_cast(buffer_data), mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); return image_frame; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index d5bd773f3..cc273bca4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -65,12 +65,12 @@ bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image, switch (image.ByteDepth()) { case 1: { - uint8* data = static_cast(buffer_data); + uint8_t* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 2: { - uint16* data = static_cast(buffer_data); + uint16_t* data = static_cast(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } @@ -503,8 +503,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)( int offset = 0; for (int sample = 0; sample < num_samples; ++sample) { for (int channel = 0; channel < num_channels; ++channel) { - int16 value = - static_cast(audio_mat(channel, sample) * kMultiplier); + int16_t value = + static_cast(audio_mat(channel, sample) * kMultiplier); // The java and native has the same byte order, by default is little // Endian, we can safely copy data directly, we have tests to cover // this. diff --git a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py index 113969384..3b4182c14 100644 --- a/mediapipe/model_maker/python/core/data/dataset.py +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -84,7 +84,7 @@ class Dataset(object): create randomness during model training. preprocess: A function taking three arguments in order, feature, label and boolean is_training. - drop_remainder: boolean, whether the finaly batch drops remainder. + drop_remainder: boolean, whether the finally batch drops remainder. Returns: A TF dataset ready to be consumed by Keras model. diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py index 5cff30930..e6848e0de 100644 --- a/mediapipe/model_maker/python/core/hyperparameters.py +++ b/mediapipe/model_maker/python/core/hyperparameters.py @@ -32,7 +32,7 @@ class BaseHParams: epochs: Number of training iterations over the dataset. steps_per_epoch: An optional integer indicate the number of training steps per epoch. If not set, the training pipeline calculates the default steps - per epoch as the training dataset size devided by batch size. + per epoch as the training dataset size divided by batch size. shuffle: True if the dataset is shuffled before training. export_dir: The location of the model checkpoint files. 
distribution_strategy: A string specifying which Distribution Strategy to diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index 7871d90cb..221df94fd 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -94,4 +94,6 @@ class DownloadedFiles: pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True) with open(absolute_path, 'wb') as f: f.write(r.content) + else: + print(f'Using existing files at {absolute_path}') return str(absolute_path) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index 77ed2e016..e96421593 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -21,7 +21,7 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) -# TODO: Remove the unncessary test data once the demo data are moved to an open-sourced +# TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory. filegroup( name = "testdata", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index ad2f211f5..11b4f9759 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -15,7 +15,6 @@ import io import os import tempfile -import unittest from unittest import mock as unittest_mock import zipfile @@ -32,7 +31,6 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat tf.keras.backend.experimental.enable_tf_random_generator() -@unittest.skip('b/273818271') class GestureRecognizerTest(tf.test.TestCase): def _load_data(self): @@ -47,9 +45,6 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() tf.keras.utils.set_random_seed(87654321) - all_data = self._load_data() - # Splits data, 90% data for training, 10% for validation - self._train_data, self._validation_data = all_data.split(0.9) # Mock tempfile.gettempdir() to be unique for each test to avoid race # condition when downloading model since these tests may run in parallel. 
mock_gettempdir = unittest_mock.patch.object( @@ -60,6 +55,10 @@ class GestureRecognizerTest(tf.test.TestCase): ) self.mock_gettempdir = mock_gettempdir.start() self.addCleanup(mock_gettempdir.stop) + # Load dataset used by tests + all_data = self._load_data() + # Splits data, 90% data for training, 10% for validation + self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): mo = gesture_recognizer.ModelOptions() @@ -74,7 +73,6 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) - @unittest.skip('b/273818271') @unittest_mock.patch.object( tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense ) diff --git a/mediapipe/model_maker/python/vision/object_detector/__init__.py b/mediapipe/model_maker/python/vision/object_detector/__init__.py index 6b60760d4..ef7a92010 100644 --- a/mediapipe/model_maker/python/vision/object_detector/__init__.py +++ b/mediapipe/model_maker/python/vision/object_detector/__init__.py @@ -28,3 +28,14 @@ HParams = hyperparameters.HParams QATHParams = hyperparameters.QATHParams Dataset = dataset.Dataset ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions + +# Remove duplicated and non-public API +del dataset +del dataset_util # pylint: disable=undefined-variable +del hyperparameters +del model # pylint: disable=undefined-variable +del model_options +del model_spec +del object_detector +del object_detector_options +del preprocessor # pylint: disable=undefined-variable diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset.py b/mediapipe/model_maker/python/vision/object_detector/dataset.py index 741263129..f260c82c5 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset.py @@ -155,8 +155,8 @@ class Dataset(classification_dataset.ClassificationDataset): ObjectDetectorDataset object. """ # Get TFRecord Files - tfrecord_file_patten = cache_prefix + '*.tfrecord' - matched_files = tf.io.gfile.glob(tfrecord_file_patten) + tfrecord_file_pattern = cache_prefix + '*.tfrecord' + matched_files = tf.io.gfile.glob(tfrecord_file_pattern) if not matched_files: raise ValueError('TFRecord files are empty.') diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py index 020c94501..440d45945 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py @@ -345,7 +345,7 @@ def _coco_annotations_to_lists( Args: bbox_annotations: List of dicts with keys ['bbox', 'category_id'] image_height: Height of image - image_width: Width of iamge + image_width: Width of image Returns: (data, num_annotations_skipped) tuple where data contains the keys: diff --git a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py index 435dd9745..241104cf8 100644 --- a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py @@ -29,9 +29,9 @@ class HParams(hp.BaseHParams): epochs: Number of training iterations over the dataset. do_fine_tuning: If true, the base module is trained together with the classification layer on top. 
- learning_rate_boundaries: List of epoch boundaries where - learning_rate_boundaries[i] is the epoch where the learning rate will - decay to learning_rate * learning_rate_decay_multipliers[i]. + learning_rate_epoch_boundaries: List of epoch boundaries where + learning_rate_epoch_boundaries[i] is the epoch where the learning rate + will decay to learning_rate * learning_rate_decay_multipliers[i]. learning_rate_decay_multipliers: List of learning rate multipliers which calculates the learning rate at the ith boundary as learning_rate * learning_rate_decay_multipliers[i]. @@ -43,35 +43,39 @@ class HParams(hp.BaseHParams): epochs: int = 10 # Parameters for learning rate decay - learning_rate_boundaries: List[int] = dataclasses.field( - default_factory=lambda: [5, 8] + learning_rate_epoch_boundaries: List[int] = dataclasses.field( + default_factory=lambda: [] ) learning_rate_decay_multipliers: List[float] = dataclasses.field( - default_factory=lambda: [0.1, 0.01] + default_factory=lambda: [] ) def __post_init__(self): # Validate stepwise learning rate parameters - lr_boundary_len = len(self.learning_rate_boundaries) + lr_boundary_len = len(self.learning_rate_epoch_boundaries) lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers) if lr_boundary_len != lr_decay_multipliers_len: raise ValueError( - "Length of learning_rate_boundaries and ", + "Length of learning_rate_epoch_boundaries and ", "learning_rate_decay_multipliers do not match: ", f"{lr_boundary_len}!={lr_decay_multipliers_len}", ) - # Validate learning_rate_boundaries - if sorted(self.learning_rate_boundaries) != self.learning_rate_boundaries: - raise ValueError( - "learning_rate_boundaries is not in ascending order: ", - self.learning_rate_boundaries, - ) + # Validate learning_rate_epoch_boundaries if ( - self.learning_rate_boundaries - and self.learning_rate_boundaries[-1] > self.epochs + sorted(self.learning_rate_epoch_boundaries) + != self.learning_rate_epoch_boundaries ): raise ValueError( - "Values in learning_rate_boundaries cannot be greater ", "than epochs" + "learning_rate_epoch_boundaries is not in ascending order: ", + self.learning_rate_epoch_boundaries, + ) + if ( + self.learning_rate_epoch_boundaries + and self.learning_rate_epoch_boundaries[-1] > self.epochs + ): + raise ValueError( + "Values in learning_rate_epoch_boundaries cannot be greater ", + "than epochs", ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector.py b/mediapipe/model_maker/python/vision/object_detector/object_detector.py index a6f678cd9..2d1d92ef3 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py @@ -57,7 +57,6 @@ class ObjectDetector(classifier.Classifier): self._preprocessor = preprocessor.Preprocessor(model_spec) self._hparams = hparams self._model_options = model_options - self._optimizer = self._create_optimizer() self._is_qat = False @classmethod @@ -104,6 +103,13 @@ class ObjectDetector(classifier.Classifier): train_data: Training data. validation_data: Validation data. 
""" + self._optimizer = self._create_optimizer( + model_util.get_steps_per_epoch( + steps_per_epoch=self._hparams.steps_per_epoch, + batch_size=self._hparams.batch_size, + train_data=train_data, + ) + ) self._create_model() self._train_model( train_data, validation_data, preprocessor=self._preprocessor @@ -333,21 +339,34 @@ class ObjectDetector(classifier.Classifier): with open(metadata_file, 'w') as f: f.write(metadata_json) - def _create_optimizer(self) -> tf.keras.optimizers.Optimizer: + def _create_optimizer( + self, steps_per_epoch: int + ) -> tf.keras.optimizers.Optimizer: """Creates an optimizer with learning rate schedule for regular training. Uses Keras PiecewiseConstantDecay schedule by default. + Args: + steps_per_epoch: Steps per epoch to calculate the step boundaries from the + learning_rate_epoch_boundaries + Returns: A tf.keras.optimizer.Optimizer for model training. """ init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 - lr_values = [init_lr] + [ - init_lr * m for m in self._hparams.learning_rate_decay_multipliers - ] - learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay( - self._hparams.learning_rate_boundaries, lr_values - ) + if self._hparams.learning_rate_epoch_boundaries: + lr_values = [init_lr] + [ + init_lr * m for m in self._hparams.learning_rate_decay_multipliers + ] + lr_step_boundaries = [ + steps_per_epoch * epoch_boundary + for epoch_boundary in self._hparams.learning_rate_epoch_boundaries + ] + learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay( + lr_step_boundaries, lr_values + ) + else: + learning_rate = init_lr return tf.keras.optimizers.experimental.SGD( - learning_rate=learning_rate_fn, momentum=0.9 + learning_rate=learning_rate, momentum=0.9 ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py index 3feb75f2e..df6b58a07 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py @@ -14,7 +14,6 @@ import os import tempfile -import unittest # pylint:disable=unused-import from unittest import mock as unittest_mock from absl.testing import parameterized @@ -28,7 +27,6 @@ from mediapipe.model_maker.python.vision.object_detector import object_detector_ from mediapipe.tasks.python.test import test_utils as task_test_utils -@unittest.skip('b/275624089') class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): @@ -51,7 +49,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): def test_object_detector(self): hparams = hyperparameters.HParams( - epochs=10, + epochs=1, batch_size=2, learning_rate=0.9, shuffle=False, @@ -75,7 +73,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): output_tflite_file = os.path.join( options.hparams.export_dir, 'model.tflite' ) - print('ASDF float', os.path.getsize(output_tflite_file)) self.assertTrue(os.path.exists(output_tflite_file)) self.assertGreater(os.path.getsize(output_tflite_file), 0) self.assertTrue(os.path.exists(output_metadata_file)) @@ -85,7 +82,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): qat_hparams = hyperparameters.QATHParams( learning_rate=0.9, batch_size=2, - epochs=5, + epochs=1, decay_steps=6, decay_rate=0.96, ) @@ -101,7 +98,6 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): output_tflite_file = os.path.join( 
options.hparams.export_dir, 'model_qat.tflite' ) - print('ASDF qat', os.path.getsize(output_tflite_file)) self.assertTrue(os.path.exists(output_tflite_file)) self.assertGreater(os.path.getsize(output_tflite_file), 0) self.assertLess(os.path.getsize(output_tflite_file), 3500000) diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 7df6c8027..83567a4d8 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -89,6 +89,7 @@ objc_library( "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", + "//third_party/apple_frameworks:AVFoundation", "//third_party/apple_frameworks:Accelerate", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -124,8 +125,10 @@ objc_library( visibility = ["//visibility:public"], deps = [ "//third_party/apple_frameworks:AVFoundation", + "//third_party/apple_frameworks:CoreAudio", "//third_party/apple_frameworks:CoreVideo", "//third_party/apple_frameworks:Foundation", + "//third_party/apple_frameworks:MediaToolbox", ], ) diff --git a/mediapipe/objc/MPPInputSource.h b/mediapipe/objc/MPPInputSource.h index 2c518fdc4..121261c59 100644 --- a/mediapipe/objc/MPPInputSource.h +++ b/mediapipe/objc/MPPInputSource.h @@ -13,8 +13,11 @@ // limitations under the License. #import +#import #import +NS_ASSUME_NONNULL_BEGIN + @class MPPInputSource; /// A delegate that can receive frames from a source. @@ -31,7 +34,7 @@ timestamp:(CMTime)timestamp fromSource:(MPPInputSource*)source; -// Provides the delegate with a new depth frame data +// Provides the delegate with new depth frame data. @optional - (void)processDepthData:(AVDepthData*)depthData timestamp:(CMTime)timestamp @@ -40,6 +43,23 @@ @optional - (void)videoDidPlayToEnd:(CMTime)timestamp; +// Provides the delegate with the format of the audio track to be played. +@optional +- (void)willStartPlayingAudioWithFormat:(const AudioStreamBasicDescription*)format + fromSource:(MPPInputSource*)source; + +// Lets the delegate know that there is no audio track despite audio playback +// having been requested (or that audio playback failed for other reasons). +@optional +- (void)noAudioAvailableFromSource:(MPPInputSource*)source; + +// Provides the delegate with a new audio packet. +@optional +- (void)processAudioPacket:(const AudioBufferList*)audioPacket + numFrames:(CMItemCount)numFrames + timestamp:(CMTime)timestamp + fromSource:(MPPInputSource*)source; + @end /// Abstract class for a video source. @@ -68,3 +88,5 @@ - (void)stop; @end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/objc/MPPPlayerInputSource.h b/mediapipe/objc/MPPPlayerInputSource.h index e1516abe9..87054d953 100644 --- a/mediapipe/objc/MPPPlayerInputSource.h +++ b/mediapipe/objc/MPPPlayerInputSource.h @@ -18,7 +18,15 @@ /// Not meant for batch processing of video. @interface MPPPlayerInputSource : MPPInputSource -/// Designated initializer. +/// Initializes the video source with optional audio processing. +/// +/// @param video The video asset to play. +/// @param audioProcessingEnabled If set, indicates that the (first) audio track +/// should be processed if it exists, and the corresponding methods for +/// audio will be invoked on the @c delegate. +- (instancetype)initWithAVAsset:(AVAsset*)video audioProcessingEnabled:(BOOL)audioProcessingEnabled; + +/// Initializes the video source to process @c video with audio processing disabled. 
- (instancetype)initWithAVAsset:(AVAsset*)video; /// Skip into video @c time from beginning (time 0), within error of +/- tolerance to closest time. diff --git a/mediapipe/objc/MPPPlayerInputSource.m b/mediapipe/objc/MPPPlayerInputSource.m index f5741f8af..6cd489ff7 100644 --- a/mediapipe/objc/MPPPlayerInputSource.m +++ b/mediapipe/objc/MPPPlayerInputSource.m @@ -13,11 +13,13 @@ // limitations under the License. #import +#import #import "MPPPlayerInputSource.h" #if !TARGET_OS_OSX #import "mediapipe/objc/MPPDisplayLinkWeakTarget.h" #endif +#import "mediapipe/objc/MPPInputSource.h" @implementation MPPPlayerInputSource { AVAsset* _video; @@ -35,7 +37,53 @@ BOOL _playing; } +void InitAudio(MTAudioProcessingTapRef tap, void* clientInfo, void** tapStorageOut) { + // `clientInfo` comes as a user-defined argument through + // `MTAudioProcessingTapCallbacks`; we pass our `MPPPlayerInputSource` + // there. Tap processing functions allow for user-defined "storage" - we just + // treat our input source as such. + *tapStorageOut = clientInfo; +} + +void PrepareAudio(MTAudioProcessingTapRef tap, CMItemCount maxFrames, + const AudioStreamBasicDescription* audioFormat) { + // See `InitAudio`. + MPPPlayerInputSource* source = + (__bridge MPPPlayerInputSource*)MTAudioProcessingTapGetStorage(tap); + if ([source.delegate respondsToSelector:@selector(willStartPlayingAudioWithFormat:fromSource:)]) { + [source.delegate willStartPlayingAudioWithFormat:audioFormat fromSource:source]; + } +} + +void ProcessAudio(MTAudioProcessingTapRef tap, CMItemCount numberFrames, + MTAudioProcessingTapFlags flags, AudioBufferList* bufferListInOut, + CMItemCount* numberFramesOut, MTAudioProcessingTapFlags* flagsOut) { + CMTimeRange timeRange; + OSStatus status = MTAudioProcessingTapGetSourceAudio(tap, numberFrames, bufferListInOut, flagsOut, + &timeRange, numberFramesOut); + if (status != 0) { + NSLog(@"Error from GetSourceAudio: %ld", (long)status); + return; + } + + // See `InitAudio`. + MPPPlayerInputSource* source = + (__bridge MPPPlayerInputSource*)MTAudioProcessingTapGetStorage(tap); + if ([source.delegate respondsToSelector:@selector(processAudioPacket: + numFrames:timestamp:fromSource:)]) { + [source.delegate processAudioPacket:bufferListInOut + numFrames:numberFrames + timestamp:timeRange.start + fromSource:source]; + } +} + - (instancetype)initWithAVAsset:(AVAsset*)video { + return [self initWithAVAsset:video audioProcessingEnabled:NO]; +} + +- (instancetype)initWithAVAsset:(AVAsset*)video + audioProcessingEnabled:(BOOL)audioProcessingEnabled { self = [super init]; if (self) { _video = video; @@ -67,6 +115,11 @@ CVDisplayLinkStop(_videoDisplayLink); CVDisplayLinkSetOutputCallback(_videoDisplayLink, renderCallback, (__bridge void*)self); #endif // TARGET_OS_OSX + + if (audioProcessingEnabled) { + [self setupAudioPlayback]; + } + _videoPlayer = [AVPlayer playerWithPlayerItem:_videoItem]; _videoPlayer.actionAtItemEnd = AVPlayerActionAtItemEndNone; NSNotificationCenter* center = [NSNotificationCenter defaultCenter]; @@ -88,6 +141,47 @@ return self; } +- (void)setupAudioPlayback { + bool have_audio = false; + NSArray* audioTracks = + [_video tracksWithMediaCharacteristic:AVMediaCharacteristicAudible]; + if (audioTracks.count != 0) { + // We always limit ourselves to the first audio track if there are + // multiple (which is a rarity) - note that it can still be e.g. stereo. 
+ AVAssetTrack* audioTrack = audioTracks[0]; + MTAudioProcessingTapCallbacks audioCallbacks; + audioCallbacks.version = kMTAudioProcessingTapCallbacksVersion_0; + audioCallbacks.clientInfo = (__bridge void*)(self); + audioCallbacks.init = InitAudio; + audioCallbacks.prepare = PrepareAudio; + audioCallbacks.process = ProcessAudio; + audioCallbacks.unprepare = NULL; + audioCallbacks.finalize = NULL; + + MTAudioProcessingTapRef audioTap; + OSStatus status = + MTAudioProcessingTapCreate(kCFAllocatorDefault, &audioCallbacks, + kMTAudioProcessingTapCreationFlag_PreEffects, &audioTap); + if (status == noErr && audioTap != NULL) { + AVMutableAudioMixInputParameters* audioMixInputParams = + [AVMutableAudioMixInputParameters audioMixInputParametersWithTrack:audioTrack]; + audioMixInputParams.audioTapProcessor = audioTap; + CFRelease(audioTap); + + AVMutableAudioMix* audioMix = [AVMutableAudioMix audioMix]; + + audioMix.inputParameters = @[ audioMixInputParams ]; + _videoItem.audioMix = audioMix; + have_audio = true; + } else { + NSLog(@"Error %ld when trying to create the audio processing tap", (long)status); + } + } + if (!have_audio && [self.delegate respondsToSelector:@selector(noAudioAvailableFromSource:)]) { + [self.delegate noAudioAvailableFromSource:self]; + } +} + - (void)start { [_videoPlayer play]; _playing = YES; diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc index 895463060..36ad4e195 100644 --- a/mediapipe/objc/util.cc +++ b/mediapipe/objc/util.cc @@ -365,6 +365,10 @@ absl::StatusOr> CreateCVPixelBufferForImageFrame( pixel_format = kCVPixelFormatType_TwoComponent32Float; break; + case mediapipe::ImageFormat::VEC32F4: + pixel_format = kCVPixelFormatType_128RGBAFloat; + break; + default: return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "unsupported ImageFrame format: " << image_format; @@ -440,6 +444,10 @@ absl::StatusOr> CreateCVPixelBufferCopyingImageFrame( pixel_format = kCVPixelFormatType_TwoComponent32Float; break; + case mediapipe::ImageFormat::VEC32F4: + pixel_format = kCVPixelFormatType_128RGBAFloat; + break; + default: return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "unsupported ImageFrame format: " << image_format; diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index e5d33fb31..a7e777039 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -94,6 +94,7 @@ cc_library( "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", + "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 6049abfae..5e5ba7530 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -120,16 +120,17 @@ void ImageSubmodule(pybind11::module* module) { py::init([](mediapipe::ImageFormat::Format format, const py::array_t& data) { if (format != mediapipe::ImageFormat::VEC32F1 && - format != mediapipe::ImageFormat::VEC32F2) { + format != mediapipe::ImageFormat::VEC32F2 && + format != mediapipe::ImageFormat::VEC32F4) { throw RaisePyError( PyExc_RuntimeError, - "float image data should be either VEC32F1 or VEC32F2 " - "MediaPipe image formats."); + "float image data should be either 
VEC32F1, VEC32F2, or " + "VEC32F4 MediaPipe image formats."); } return Image(std::shared_ptr( CreateImageFrame(format, data))); }), - R"doc(For float data type, valid ImageFormat are VEC32F1 and VEC32F2.)doc", + R"doc(For float data type, valid ImageFormat are VEC32F1, VEC32F2, and VEC32F4.)doc", py::arg("image_format"), py::arg("data").noconvert()); image.def( diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index 92e695020..c8ae7c259 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -42,7 +42,8 @@ Packet CreateImageFramePacket(mediapipe::ImageFormat::Format format, format == mediapipe::ImageFormat::SRGBA64) { return Adopt(CreateImageFrame(format, data, copy).release()); } else if (format == mediapipe::ImageFormat::VEC32F1 || - format == mediapipe::ImageFormat::VEC32F2) { + format == mediapipe::ImageFormat::VEC32F2 || + format == mediapipe::ImageFormat::VEC32F4) { return Adopt(CreateImageFrame(format, data, copy).release()); } throw RaisePyError(PyExc_RuntimeError, @@ -63,7 +64,8 @@ Packet CreateImagePacket(mediapipe::ImageFormat::Format format, return MakePacket(std::shared_ptr( CreateImageFrame(format, data, copy))); } else if (format == mediapipe::ImageFormat::VEC32F1 || - format == mediapipe::ImageFormat::VEC32F2) { + format == mediapipe::ImageFormat::VEC32F2 || + format == mediapipe::ImageFormat::VEC32F4) { return MakePacket(std::shared_ptr( CreateImageFrame(format, data, copy))); } diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 145076cd3..01e1292c3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -111,7 +111,7 @@ class ClassificationAggregationCalculator : public Node { private: std::vector head_names_; bool time_aggregation_enabled_; - std::unordered_map> + std::unordered_map> cached_classifications_; ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc index 6e06c4e32..94e0fcb36 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc @@ -83,7 +83,7 @@ class EmbeddingAggregationCalculator : public Node { private: bool time_aggregation_enabled_; - std::unordered_map cached_embeddings_; + std::unordered_map cached_embeddings_; }; absl::Status EmbeddingAggregationCalculator::UpdateContract( diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index c271f3dac..a56f03d55 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -66,13 +66,13 @@ using ::absl::StatusCode; // Gets the offset aligned to page size for mapping given files into memory by // file descriptor correctly, as according to mmap(2), the offset used in mmap // must be a multiple of sysconf(_SC_PAGE_SIZE). 
-int64 GetPageSizeAlignedOffset(int64 offset) { +int64_t GetPageSizeAlignedOffset(int64_t offset) { #ifdef _WIN32 // mmap is not used on Windows return 0; #else - int64 aligned_offset = offset; - int64 page_size = sysconf(_SC_PAGE_SIZE); + int64_t aligned_offset = offset; + int64_t page_size = sysconf(_SC_PAGE_SIZE); if (offset % page_size != 0) { aligned_offset = offset / page_size * page_size; } diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc index 32ff51482..63cd2ff9c 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc @@ -111,7 +111,7 @@ TEST(MetadataVersionTest, TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForModelMetadataVocabAssociatedFiles) { // Creates a metadata flatbuffer with the field, - // ModelMetadata.associated_fiels, populated with the vocabulary file type. + // ModelMetadata.associated_fields, populated with the vocabulary file type. FlatBufferBuilder builder(1024); AssociatedFileBuilder associated_file_builder(builder); associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY); @@ -159,8 +159,8 @@ TEST(MetadataVersionTest, TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForInputMetadataVocabAssociatedFiles) { // Creates a metadata flatbuffer with the field, - // SubGraphMetadata.input_tensor_metadata.associated_fiels, populated with the - // vocabulary file type. + // SubGraphMetadata.input_tensor_metadata.associated_fields, populated with + // the vocabulary file type. FlatBufferBuilder builder(1024); AssociatedFileBuilder associated_file_builder(builder); associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY); @@ -184,7 +184,7 @@ TEST(MetadataVersionTest, TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOutputMetadataVocabAssociatedFiles) { // Creates a metadata flatbuffer with the field, - // SubGraphMetadata.output_tensor_metadata.associated_fiels, populated with + // SubGraphMetadata.output_tensor_metadata.associated_fields, populated with // the vocabulary file type. FlatBufferBuilder builder(1024); AssociatedFileBuilder associated_file_builder(builder); diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index c56f350b2..8e6105e18 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -188,7 +188,7 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { // For 90° and 270° rotations, we need to swap width and height. // This is due to the internal behavior of ImageToTensorCalculator, which: // - first denormalizes the provided rect by multiplying the rect width or - // height by the image width or height, repectively. + // height by the image width or height, respectively. // - then rotates this by denormalized rect by the provided rotation, and // uses this for cropping, // - then finally rotates this back. 
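The rotation note above is easier to see with a concrete sketch: for 90° and 270° rotations the normalized RoI's width and height trade places, and each is rescaled by the image aspect ratio so the denormalized crop keeps its intended size. This is an illustrative reconstruction (NormRect and AdjustRoiForRotation are hypothetical names, not part of this patch):

#include <utility>  // std::swap

struct NormRect { float x_center, y_center, width, height; };

// Swap width/height for 90/270-degree rotations and compensate for the fact
// that width is later denormalized by image_width and height by image_height.
NormRect AdjustRoiForRotation(NormRect roi, int rotation_degrees,
                              int image_width, int image_height) {
  if (rotation_degrees == 90 || rotation_degrees == 270) {
    std::swap(roi.width, roi.height);
    roi.width *= static_cast<float>(image_height) / image_width;
    roi.height *= static_cast<float>(image_width) / image_height;
  }
  return roi;
}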
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 3fe999937..527363d1f 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -374,22 +374,22 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { // Inference for custom gesture classifier if it exists. if (has_custom_gesture_classifier) { ASSIGN_OR_RETURN( - auto gesture_clasification_list, + auto gesture_classification_list, GetGestureClassificationList( sub_task_model_resources.custom_gesture_classifier_model_resource, graph_options.custom_gesture_classifier_graph_options(), embedding_tensors, graph)); - gesture_clasification_list >> combine_predictions.In(classifier_nums++); + gesture_classification_list >> combine_predictions.In(classifier_nums++); } // Inference for canned gesture classifier. ASSIGN_OR_RETURN( - auto gesture_clasification_list, + auto gesture_classification_list, GetGestureClassificationList( sub_task_model_resources.canned_gesture_classifier_model_resource, graph_options.canned_gesture_classifier_graph_options(), embedding_tensors, graph)); - gesture_clasification_list >> combine_predictions.In(classifier_nums++); + gesture_classification_list >> combine_predictions.In(classifier_nums++); auto combined_classification_list = combine_predictions.Out(kPredictionTag).Cast(); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 69833a5f6..ee1cd3693 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,6 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +cc_library( + name = "image_segmenter_result", + hdrs = ["image_segmenter_result.h"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework/formats:image"], +) + # Docs for Mediapipe Tasks Image Segmenter # https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( @@ -25,6 +32,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", + ":image_segmenter_result", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", @@ -82,6 +90,7 @@ cc_library( "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:graph_builder_utils", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc index 5a09d3a8d..da5dcacae 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc @@ -22,7 +22,21 @@ using mediapipe::kBasicVertexShader; using ::mediapipe::tasks::vision::Shape; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +// TODO: This part of the setup code is so common, we should really +// refactor to a helper utility. 
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; +const GLint attr_location[NUM_ATTRIBUTES] = { + ATTRIB_VERTEX, + ATTRIB_TEXTURE_POSITION, +}; +const GLchar* attr_name[NUM_ATTRIBUTES] = { + "position", + "texture_coordinate", +}; + +// We assume ES3.0+ for some of our shaders here so we can make liberal use of +// MRT easily. +static constexpr char kEs30RequirementHeader[] = "#version 300 es\n"; static constexpr char kActivationFragmentShader[] = R"( DEFAULT_PRECISION(mediump, float) @@ -140,55 +154,93 @@ void main() { gl_FragColor = vec4(out_value, out_value, out_value, out_value); })"; -// Quick softmax shader hardcoded to max of N=12 classes. Performs softmax -// calculations, but renders to one chunk at a time. -// TODO: For more efficiency, should at least use MRT to render all -// chunks simultaneously. -static constexpr char kSoftmaxShader[] = R"( +// Softmax is in 3 steps: +// - First we find max over all masks +// - Then we transform all masks to be exp(val - maxval), and also add to +// cumulative-sum image with MRT +// - Then we normalize all masks by cumulative-sum image + +// Part one: max shader +// To start with, we just do this chunk by chunk, using GL_MAX blend mode so we +// don't need to tap into the max-so-far texture. +static constexpr char kMaxShader[] = R"( DEFAULT_PRECISION(mediump, float) in vec2 sample_coordinate; -uniform sampler2D input_texture0; -uniform sampler2D input_texture1; -uniform sampler2D input_texture2; -uniform int chunk_select; +uniform sampler2D current_chunk; +uniform int num_channels; // how many channels from current chunk to use (1-4) float max4(vec4 vec) { return max(max(vec.x, vec.y), max(vec.z, vec.w)); } - -vec4 expTransform(vec4 vec, float maxval) { - return exp(vec - maxval); +float max3(vec4 vec) { + return max(max(vec.x, vec.y), vec.z); } +float max2(vec4 vec) { + return max(vec.x, vec.y); +} +void main() { + vec4 chunk_pixel = texture2D(current_chunk, sample_coordinate); + float new_max; + if (num_channels == 1) { + new_max = chunk_pixel.x; + } else if (num_channels == 2) { + new_max = max2(chunk_pixel); + } else if (num_channels == 3) { + new_max = max3(chunk_pixel); + } else { + new_max = max4(chunk_pixel); + } + gl_FragColor = vec4(new_max, 0.0, 0.0, 1.0); +})"; + +// Part two: transform-and-sum shader +// We use GL blending so we can more easily render a cumulative sum texture, and +// this only costs us a glClear for the output chunk (needed since using MRT). 
+static constexpr char kTransformAndSumShader[] = R"( +DEFAULT_PRECISION(highp, float) +in vec2 sample_coordinate; +uniform sampler2D max_value_texture; +uniform sampler2D current_chunk; +uniform int num_channels; // how many channels from current chunk to use (1-4) + +layout(location = 0) out vec4 cumulative_sum_texture; +layout(location = 1) out vec4 out_chunk_texture; void main() { - // Grab all vecs - vec4 pixel0 = texture2D(input_texture0, sample_coordinate); - vec4 pixel1 = texture2D(input_texture1, sample_coordinate); - vec4 pixel2 = texture2D(input_texture2, sample_coordinate); + float max_pixel = texture(max_value_texture, sample_coordinate).r; + vec4 chunk_pixel = texture(current_chunk, sample_coordinate); + vec4 new_chunk_pixel = exp(chunk_pixel - max_pixel); - // Find maxval amongst all vectors - float max0 = max4(pixel0); - float max1 = max4(pixel1); - float max2 = max4(pixel2); - float maxval = max(max(max0, max1), max2); + float sum_so_far; + if (num_channels == 1) { + sum_so_far = new_chunk_pixel.x; + } else if (num_channels == 2) { + sum_so_far = dot(vec2(1.0, 1.0), new_chunk_pixel.xy); + } else if (num_channels == 3) { + sum_so_far = dot(vec3(1.0, 1.0, 1.0), new_chunk_pixel.xyz); + } else { + sum_so_far = dot(vec4(1.0, 1.0, 1.0, 1.0), new_chunk_pixel); + } - vec4 outPixel0 = expTransform(pixel0, maxval); - vec4 outPixel1 = expTransform(pixel1, maxval); - vec4 outPixel2 = expTransform(pixel2, maxval); + cumulative_sum_texture = vec4(sum_so_far, 0.0, 0.0, 1.0); + out_chunk_texture = new_chunk_pixel; +})"; - // Quick hack to sum all components in vec4: dot with <1, 1, 1, 1> - vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); - float weightSum = dot(ones, outPixel0) + dot(ones, outPixel1) + dot(ones, outPixel2); +// Part three: normalization shader +static constexpr char kNormalizationShader[] = R"( +DEFAULT_PRECISION(mediump, float) +in vec2 sample_coordinate; +uniform sampler2D sum_texture; // cumulative summation value (to normalize by) +uniform sampler2D current_chunk; // current chunk - vec4 outPixel; - if (chunk_select == 0) { - outPixel = outPixel0 / weightSum; - } else if (chunk_select == 1) { - outPixel = outPixel1 / weightSum; - } else { - outPixel = outPixel2 / weightSum; - } - gl_FragColor = outPixel; +void main() { + float sum_pixel = texture2D(sum_texture, sample_coordinate).r; + vec4 chunk_pixel = texture2D(current_chunk, sample_coordinate); + + // NOTE: We assume non-zero sum_pixel here, which is a safe assumption for + // result of an exp transform, but not if this shader is extended to other + // uses. + gl_FragColor = chunk_pixel / sum_pixel; })"; } // namespace @@ -208,19 +260,38 @@ absl::Status SegmentationPostprocessorGl::Initialize( return absl::OkStatus(); } +absl::Status SegmentationPostprocessorGl::CreateBasicFragmentShaderProgram( + std::string const& program_name, std::string const& fragment_shader_source, + std::vector const& uniform_names, GlShader* shader_struct_ptr, + bool is_es30_only = false) { + // Format source and create basic ES3.0+ fragment-shader-only program + const std::string frag_shader_source = + absl::StrCat(is_es30_only ? std::string(kEs30RequirementHeader) : "", + std::string(mediapipe::kMediaPipeFragmentShaderPreamble), + std::string(fragment_shader_source)); + const std::string vert_shader_source = + absl::StrCat(is_es30_only ? 
std::string(kEs30RequirementHeader) : "", + std::string(kBasicVertexShader)); + mediapipe::GlhCreateProgram( + vert_shader_source.c_str(), frag_shader_source.c_str(), NUM_ATTRIBUTES, + &attr_name[0], attr_location, &shader_struct_ptr->program, + /* force_log_errors */ true); + RET_CHECK(shader_struct_ptr->program) + << "Problem initializing the " << program_name << " program."; + + // Hook up all desired uniforms + for (const auto& uniform_name : uniform_names) { + shader_struct_ptr->uniforms[uniform_name] = + glGetUniformLocation(shader_struct_ptr->program, uniform_name.c_str()); + RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > 0) + << uniform_name << " uniform not found for " << program_name + << " program"; + } + return absl::OkStatus(); +} + absl::Status SegmentationPostprocessorGl::GlInit() { return helper_.RunInGlContext([this]() -> absl::Status { - // TODO: This part of the setup code is so common, we should really - // refactor to a helper utility. - const GLint attr_location[NUM_ATTRIBUTES] = { - ATTRIB_VERTEX, - ATTRIB_TEXTURE_POSITION, - }; - const GLchar* attr_name[NUM_ATTRIBUTES] = { - "position", - "texture_coordinate", - }; - // Default to passthrough/NONE std::string activation_fn = "vec4 out_value = in_value;"; switch (options_.segmenter_options().activation()) { @@ -263,9 +334,17 @@ absl::Status SegmentationPostprocessorGl::GlInit() { absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble), std::string(kArgmaxShader)); - const std::string softmax_shader_source = - absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble), - std::string(kSoftmaxShader)); + // Softmax shaders (Max, Transform+Sum, and Normalization) + MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( + "softmax max", kMaxShader, {"current_chunk", "num_channels"}, + &softmax_max_shader_)); + MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( + "softmax transform-and-sum", kTransformAndSumShader, + {"max_value_texture", "current_chunk", "num_channels"}, + &softmax_transform_and_sum_shader_, true /* is_es30_only */)); + MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( + "softmax normalization", kNormalizationShader, + {"sum_texture", "current_chunk"}, &softmax_normalization_shader_)); // Compile all our shader programs. // Note: we enable `force_log_errors` so that we get full debugging error @@ -299,12 +378,6 @@ absl::Status SegmentationPostprocessorGl::GlInit() { /* force_log_errors */ true); RET_CHECK(argmax_program_) << "Problem initializing the argmax program."; - mediapipe::GlhCreateProgram(kBasicVertexShader, - softmax_shader_source.c_str(), NUM_ATTRIBUTES, - &attr_name[0], attr_location, &softmax_program_, - /* force_log_errors */ true); - RET_CHECK(softmax_program_) << "Problem initializing the softmax program."; - // Get uniform locations. 
activation_texture_uniform_ = glGetUniformLocation(activation_program_, "input_texture"); @@ -341,23 +414,6 @@ absl::Status SegmentationPostprocessorGl::GlInit() { RET_CHECK(argmax_texture2_uniform_ > 0) << "argmax input_texture2 uniform not found."; - softmax_texture0_uniform_ = - glGetUniformLocation(softmax_program_, "input_texture0"); - RET_CHECK(softmax_texture0_uniform_ > 0) - << "softmax input_texture0 uniform not found."; - softmax_texture1_uniform_ = - glGetUniformLocation(softmax_program_, "input_texture1"); - RET_CHECK(softmax_texture1_uniform_ > 0) - << "softmax input_texture1 uniform not found."; - softmax_texture2_uniform_ = - glGetUniformLocation(softmax_program_, "input_texture2"); - RET_CHECK(softmax_texture2_uniform_ > 0) - << "softmax input_texture2 uniform not found."; - softmax_chunk_select_uniform_ = - glGetUniformLocation(softmax_program_, "chunk_select"); - RET_CHECK(softmax_chunk_select_uniform_ > 0) - << "softmax chunk select uniform not found."; - // TODO: If ES3.0+ only, switch to VAO for handling attributes. glGenBuffers(1, &square_vertices_); glBindBuffer(GL_ARRAY_BUFFER, square_vertices_); @@ -408,6 +464,9 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, // Uint8 pipeline and conversions are lacking, so for now we just use F32 // textures even for category masks. + // TODO: Also, some platforms (like certain iOS devices) do not + // allow for rendering to RGBAF32 textures, so we should switch to using + // F16 textures in those instances. const GpuBufferFormat final_output_format = GpuBufferFormat::kGrayFloat32; const Tensor::OpenGlTexture2dView read_view = tensor.GetOpenGlTexture2dReadView(); @@ -467,7 +526,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, ((float)i + tex_offset) / (float)(input_width)); // Technically duplicated, but fine for now; we want this after the bind glBindTexture(GL_TEXTURE_2D, activated_texture.name()); - // Disable HW interpolation + // Disable hardware GPU interpolation glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); // Render @@ -477,45 +536,126 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, std::vector softmax_chunks; if (is_softmax) { - // Step 2.5: For SOFTMAX, apply softmax shader with up to 3 textures to - // create softmax-transformed chunks before channel extraction. - RET_CHECK(num_chunks <= 3) - << "Cannot handle more than 12 classes in softmax shader."; + // Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and + // normalization) to create softmax-transformed chunks before channel + // extraction. + // NOTE: exp(x-C) / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So + // theoretically we can skip the max shader step entirely. However, + // applying it does bring all our values into a nice (0, 1] range, so it + // will likely be better for precision, especially when dealing with an + // exponential function on arbitrary values. Therefore, we keep it, but + // this is potentially a skippable step for known "good" models, if we + // ever want to provide that as an option. + // TODO: For a tiny bit more efficiency, could combine channel + // extraction into last step of this via MRT. 
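The shift-invariance mentioned in the NOTE above is what makes the max pass optional; the per-pixel math the three shader passes implement is the usual numerically stable softmax. A small CPU reference, useful for validating the GL path (illustrative only, not part of this patch):

#include <algorithm>
#include <cmath>
#include <vector>

// Pass 1: max over all class logits; pass 2: exp(x - max) with a running sum;
// pass 3: divide by the sum. Subtracting the max does not change the result,
// but keeps the exp() arguments in a well-behaved range.
std::vector<float> SoftmaxReference(const std::vector<float>& logits) {
  const float max_val = *std::max_element(logits.begin(), logits.end());
  std::vector<float> out(logits.size());
  float sum = 0.f;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_val);
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}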
- glUseProgram(softmax_program_); - glUniform1i(softmax_texture0_uniform_, 1); - glUniform1i(softmax_texture1_uniform_, 2); - glUniform1i(softmax_texture2_uniform_, 3); + // Max + glUseProgram(softmax_max_shader_.program); + glUniform1i(softmax_max_shader_.uniforms["current_chunk"], 1); + + // We just need one channel, so format will match final output confidence + // masks + auto max_texture = + helper_.CreateDestinationTexture(width, height, final_output_format); + helper_.BindFramebuffer(max_texture); + + // We clear our newly-created destination texture to a reasonable minimum. + glClearColor(0.0, 0.0, 0.0, 0.0); + glClear(GL_COLOR_BUFFER_BIT); + + // We will use hardware GPU blending to apply max to all our writes. + glEnable(GL_BLEND); + glBlendEquation(GL_MAX); + + glActiveTexture(GL_TEXTURE1); + for (int i = 0; i < num_chunks; i++) { + int num_channels = 4; + if ((i + 1) * 4 > num_outputs) num_channels = num_outputs % 4; + glUniform1i(softmax_max_shader_.uniforms["num_channels"], num_channels); + glBindTexture(GL_TEXTURE_2D, chunks[i].name()); + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + } + + // Transform & Sum + std::vector unnormalized_softmax_chunks; + glUseProgram(softmax_transform_and_sum_shader_.program); + glUniform1i(softmax_transform_and_sum_shader_.uniforms["current_chunk"], + 1); + glUniform1i( + softmax_transform_and_sum_shader_.uniforms["max_value_texture"], 2); + + auto sum_texture = + helper_.CreateDestinationTexture(width, height, final_output_format); + helper_.BindFramebuffer(sum_texture); + glClear(GL_COLOR_BUFFER_BIT); + + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, max_texture.name()); + + glBlendEquation(GL_FUNC_ADD); + glBlendFunc(GL_ONE, GL_ONE); + glActiveTexture(GL_TEXTURE1); + + // We use glDrawBuffers to clear only the new texture, then again to + // draw to both textures simultaneously for rendering. + GLuint both_attachments[2] = {GL_COLOR_ATTACHMENT0, GL_COLOR_ATTACHMENT1}; + GLuint one_attachment[2] = {GL_NONE, GL_COLOR_ATTACHMENT1}; + for (int i = 0; i < num_chunks; i++) { + int num_channels = 4; + if ((i + 1) * 4 > num_outputs) num_channels = num_outputs % 4; + glUniform1i(softmax_transform_and_sum_shader_.uniforms["num_channels"], + num_channels); + unnormalized_softmax_chunks.push_back(helper_.CreateDestinationTexture( + width, height, chunk_output_format)); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT1, + GL_TEXTURE_2D, + unnormalized_softmax_chunks.back().name(), 0); + + // Note that we must bind AFTER the CreateDestinationTexture, or else we + // end up with (0, 0, 0, 1) data being read from an unbound texture + // unit. 
+ glBindTexture(GL_TEXTURE_2D, chunks[i].name()); + + // Clear *only* the new chunk + glDrawBuffers(2, one_attachment); + glClear(GL_COLOR_BUFFER_BIT); + + // Then draw into both + glDrawBuffers(2, both_attachments); + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + } + + // Turn off MRT and blending, and unbind second color attachment + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT1, + GL_TEXTURE_2D, 0, 0); + glDrawBuffers(1, both_attachments); + glDisable(GL_BLEND); + + // Normalize each chunk into a new chunk as our final step + glUseProgram(softmax_normalization_shader_.program); + glUniform1i(softmax_normalization_shader_.uniforms["current_chunk"], 1); + glUniform1i(softmax_normalization_shader_.uniforms["sum_texture"], 2); + + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, sum_texture.name()); + glActiveTexture(GL_TEXTURE1); for (int i = 0; i < num_chunks; i++) { - glUniform1i(softmax_chunk_select_uniform_, i); softmax_chunks.push_back(helper_.CreateDestinationTexture( - output_width, output_height, chunk_output_format)); + width, height, chunk_output_format)); helper_.BindFramebuffer(softmax_chunks.back()); - - // Bind however many chunks we have - for (int j = 0; j < num_chunks; ++j) { - glActiveTexture(GL_TEXTURE1 + j); - glBindTexture(GL_TEXTURE_2D, chunks[j].name()); - } - - for (int j = num_chunks; j < 3; ++j) { // 3 is hard-coded max chunks - glActiveTexture(GL_TEXTURE1 + j); - // If texture is unbound, sampling from it should always give zeros. - // This is not ideal, but is ok for now for not polluting the argmax - // shader results too much. - glBindTexture(GL_TEXTURE_2D, 0); - } - + glBindTexture(GL_TEXTURE_2D, unnormalized_softmax_chunks[i].name()); glClear(GL_COLOR_BUFFER_BIT); glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); } - // Unbind the extra textures here. - for (int i = 0; i < num_chunks; ++i) { - glActiveTexture(GL_TEXTURE1 + i); - glBindTexture(GL_TEXTURE_2D, 0); - } + // Unbind textures here + glActiveTexture(GL_TEXTURE2); + glBindTexture(GL_TEXTURE_2D, 0); + // We make sure to switch back to texture unit 1, since our confidence + // mask extraction code assumes that's our default. 
+ glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, 0); } std::vector outputs; @@ -607,17 +747,19 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() { glDeleteProgram(activation_program_); glDeleteProgram(argmax_program_); glDeleteProgram(channel_select_program_); - glDeleteProgram(softmax_program_); glDeleteProgram(split_program_); glDeleteBuffers(1, &square_vertices_); glDeleteBuffers(1, &texture_vertices_); activation_program_ = 0; argmax_program_ = 0; channel_select_program_ = 0; - softmax_program_ = 0; split_program_ = 0; square_vertices_ = 0; texture_vertices_ = 0; + + glDeleteProgram(softmax_max_shader_.program); + glDeleteProgram(softmax_transform_and_sum_shader_.program); + glDeleteProgram(softmax_normalization_shader_.program); }); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h index aceb3c8d6..c50f93077 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h @@ -38,7 +38,17 @@ class SegmentationPostprocessorGl { const Tensor& tensor); private: + struct GlShader { + GLuint program = 0; + absl::flat_hash_map uniforms; + }; + absl::Status GlInit(); + absl::Status CreateBasicFragmentShaderProgram( + std::string const& program_name, + std::string const& fragment_shader_source, + std::vector const& uniform_names, + GlShader* shader_struct_ptr, bool is_es30_only); TensorsToSegmentationCalculatorOptions options_; GlCalculatorHelper helper_; @@ -47,7 +57,6 @@ class SegmentationPostprocessorGl { GLuint activation_program_ = 0; GLuint argmax_program_ = 0; GLuint channel_select_program_ = 0; - GLuint softmax_program_ = 0; GLuint split_program_ = 0; GLuint square_vertices_ = 0; GLuint texture_vertices_ = 0; @@ -57,12 +66,12 @@ class SegmentationPostprocessorGl { GLint argmax_texture2_uniform_; GLint channel_select_texture_uniform_; GLint channel_select_index_uniform_; - GLint softmax_texture0_uniform_; - GLint softmax_texture1_uniform_; - GLint softmax_texture2_uniform_; - GLint softmax_chunk_select_uniform_; GLint split_texture_uniform_; GLint split_x_offset_uniform_; + + GlShader softmax_max_shader_; + GlShader softmax_transform_and_sum_shader_; + GlShader softmax_normalization_shader_; }; } // namespace tasks diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 0cdc8fe0f..790285546 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_macros.h" @@ -80,10 +81,10 @@ void Sigmoid(absl::Span values, [](float value) { return 1. 
/ (1 + std::exp(-value)); }); } -std::vector ProcessForCategoryMaskCpu(const Shape& input_shape, - const Shape& output_shape, - const SegmenterOptions& options, - const float* tensors_buffer) { +Image ProcessForCategoryMaskCpu(const Shape& input_shape, + const Shape& output_shape, + const SegmenterOptions& options, + const float* tensors_buffer) { cv::Mat resized_tensors_mat; cv::Mat tensors_mat_view( input_shape.height, input_shape.width, CV_32FC(input_shape.channels), @@ -135,7 +136,7 @@ std::vector ProcessForCategoryMaskCpu(const Shape& input_shape, pixel = maximum_category_idx; } }); - return {category_mask}; + return category_mask; } std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, @@ -209,7 +210,10 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, } // namespace -// Converts Tensors from a vector of Tensor to Segmentation. +// Converts Tensors from a vector of Tensor to Segmentation masks. The +// calculator can output optional confidence masks if CONFIDENCE_MASK is +// connected, and an optional category mask if CATEGORY_MASK is connected. At +// least one of CONFIDENCE_MASK and CATEGORY_MASK must be connected. // // Performs optional resizing to OUTPUT_SIZE dimension if provided, // otherwise the segmented masks is the same size as input tensor. @@ -221,7 +225,12 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, // the size to resize masks to. // // Output: -// Segmentation: Segmentation proto. +// CONFIDENCE_MASK @Multiple: Multiple masks of float image where, for each +// mask, each pixel represents the prediction confidence, usually in the [0, +// 1] range. +// CATEGORY_MASK @Optional: A category mask of uint8 image where each pixel +// represents the class which the pixel in the original image was predicted to +// belong to. // // Options: // See tensors_to_segmentation_calculator.proto @@ -231,13 +240,13 @@ std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, // calculator: "TensorsToSegmentationCalculator" // input_stream: "TENSORS:tensors" // input_stream: "OUTPUT_SIZE:size" -// output_stream: "SEGMENTATION:0:segmentation" -// output_stream: "SEGMENTATION:1:segmentation" +// output_stream: "CONFIDENCE_MASK:0:confidence_mask" +// output_stream: "CONFIDENCE_MASK:1:confidence_mask" +// output_stream: "CATEGORY_MASK:category_mask" // options { // [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { // segmenter_options { // activation: SOFTMAX -// output_type: CONFIDENCE_MASK // } // } // } @@ -248,7 +257,11 @@ class TensorsToSegmentationCalculator : public Node { static constexpr Input>::Optional kOutputSizeIn{ "OUTPUT_SIZE"}; static constexpr Output::Multiple kSegmentationOut{"SEGMENTATION"}; - MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut); + static constexpr Output::Multiple kConfidenceMaskOut{ + "CONFIDENCE_MASK"}; + static constexpr Output::Optional kCategoryMaskOut{"CATEGORY_MASK"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut, + kConfidenceMaskOut, kCategoryMaskOut); static absl::Status UpdateContract(CalculatorContract* cc); @@ -279,9 +292,20 @@ absl::Status TensorsToSegmentationCalculator::UpdateContract( absl::Status TensorsToSegmentationCalculator::Open( mediapipe::CalculatorContext* cc) { options_ = cc->Options(); - RET_CHECK_NE(options_.segmenter_options().output_type(), - SegmenterOptions::UNSPECIFIED) - << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK]."; + // TODO: remove deprecated output type support. 
+ if (options_.segmenter_options().has_output_type()) { + RET_CHECK_NE(options_.segmenter_options().output_type(), + SegmenterOptions::UNSPECIFIED) + << "Must specify output_type as one of " + "[CONFIDENCE_MASK|CATEGORY_MASK]."; + } else { + if (!cc->Outputs().HasTag("CONFIDENCE_MASK") && + !cc->Outputs().HasTag("CATEGORY_MASK")) { + return absl::InvalidArgumentError( + "At least one of CONFIDENCE_MASK and CATEGORY_MASK must be " + "connected."); + } + } #ifdef __EMSCRIPTEN__ MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); #endif // __EMSCRIPTEN__ @@ -309,6 +333,10 @@ absl::Status TensorsToSegmentationCalculator::Process( if (cc->Inputs().HasTag("OUTPUT_SIZE")) { std::tie(output_width, output_height) = kOutputSizeIn(cc).Get(); } + + // Use GPU postprocessing on web when Tensor is there already and has <= 12 + // categories. +#ifdef __EMSCRIPTEN__ Shape output_shape = { /* height= */ output_height, /* width= */ output_width, @@ -316,10 +344,6 @@ absl::Status TensorsToSegmentationCalculator::Process( SegmenterOptions::CATEGORY_MASK ? 1 : input_shape.channels}; - - // Use GPU postprocessing on web when Tensor is there already and has <= 12 - // categories. -#ifdef __EMSCRIPTEN__ if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) { std::vector> segmented_masks = postprocessor_.GetSegmentationResultGpu(input_shape, output_shape, @@ -332,10 +356,43 @@ absl::Status TensorsToSegmentationCalculator::Process( #endif // __EMSCRIPTEN__ // Otherwise, use CPU postprocessing. - std::vector segmented_masks = GetSegmentationResultCpu( - input_shape, output_shape, input_tensor.GetCpuReadView().buffer()); - for (int i = 0; i < segmented_masks.size(); ++i) { - kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); + const float* tensors_buffer = input_tensor.GetCpuReadView().buffer(); + + // TODO: remove deprecated output type support. + if (options_.segmenter_options().has_output_type()) { + std::vector segmented_masks = GetSegmentationResultCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ options_.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK + ? 
1 + : input_shape.channels}, + input_tensor.GetCpuReadView().buffer()); + for (int i = 0; i < segmented_masks.size(); ++i) { + kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i])); + } + return absl::OkStatus(); + } + + if (cc->Outputs().HasTag("CONFIDENCE_MASK")) { + std::vector confidence_masks = ProcessForConfidenceMaskCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ input_shape.channels}, + options_.segmenter_options(), tensors_buffer); + for (int i = 0; i < confidence_masks.size(); ++i) { + kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i])); + } + } + if (cc->Outputs().HasTag("CATEGORY_MASK")) { + kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu( + input_shape, + {/* height= */ output_height, + /* width= */ output_width, + /* channels= */ 1}, + options_.segmenter_options(), tensors_buffer)); } return absl::OkStatus(); } @@ -345,9 +402,9 @@ std::vector TensorsToSegmentationCalculator::GetSegmentationResultCpu( const float* tensors_buffer) { if (options_.segmenter_options().output_type() == SegmenterOptions::CATEGORY_MASK) { - return ProcessForCategoryMaskCpu(input_shape, output_shape, - options_.segmenter_options(), - tensors_buffer); + return {ProcessForCategoryMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer)}; } else { return ProcessForConfidenceMaskCpu(input_shape, output_shape, options_.segmenter_options(), diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index 54fb9b816..d6a2f3fd9 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -79,8 +79,9 @@ void PushTensorsToRunner(int tensor_height, int tensor_width, std::vector GetPackets(const CalculatorRunner& runner) { std::vector mask_packets; for (int i = 0; i < runner.Outputs().NumEntries(); ++i) { - EXPECT_EQ(runner.Outputs().Get("SEGMENTATION", i).packets.size(), 1); - mask_packets.push_back(runner.Outputs().Get("SEGMENTATION", i).packets[0]); + EXPECT_EQ(runner.Outputs().Get("CONFIDENCE_MASK", i).packets.size(), 1); + mask_packets.push_back( + runner.Outputs().Get("CONFIDENCE_MASK", i).packets[0]); } return mask_packets; } @@ -118,13 +119,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -145,13 +143,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -173,16 +168,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) { R"pb( 
calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SOFTMAX - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SOFTMAX } } } )pb")); @@ -218,16 +210,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -259,16 +248,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:0:segmented_mask_0" - output_stream: "SEGMENTATION:1:segmented_mask_1" - output_stream: "SEGMENTATION:2:segmented_mask_2" - output_stream: "SEGMENTATION:3:segmented_mask_3" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: SIGMOID - output_type: CONFIDENCE_MASK - } + segmenter_options { activation: SIGMOID } } } )pb")); @@ -301,13 +287,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { R"pb( calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" + output_stream: "CATEGORY_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CATEGORY_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -318,11 +305,11 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { tensor_height, tensor_width, std::vector(kTestValues.begin(), kTestValues.end()), &runner); MP_ASSERT_OK(runner.Run()); - ASSERT_EQ(runner.Outputs().NumEntries(), 1); + ASSERT_EQ(runner.Outputs().NumEntries(), 5); // Largest element index is 3. 
const int expected_index = 3; const std::vector buffer_indices = {0}; - std::vector packets = GetPackets(runner); + std::vector packets = runner.Outputs().Tag("CATEGORY_MASK").packets; EXPECT_THAT(packets, testing::ElementsAre( Uint8ImagePacket(tensor_height, tensor_width, expected_index, buffer_indices))); @@ -335,13 +322,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" input_stream: "TENSORS:tensors" input_stream: "OUTPUT_SIZE:size" - output_stream: "SEGMENTATION:segmentation" + output_stream: "CONFIDENCE_MASK:0:segmented_mask_0" + output_stream: "CONFIDENCE_MASK:1:segmented_mask_1" + output_stream: "CONFIDENCE_MASK:2:segmented_mask_2" + output_stream: "CONFIDENCE_MASK:3:segmented_mask_3" + output_stream: "CATEGORY_MASK:segmentation" options { [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] { - segmenter_options { - activation: NONE - output_type: CATEGORY_MASK - } + segmenter_options { activation: NONE } } } )pb")); @@ -367,7 +355,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { const std::vector buffer_indices = { 0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0, 1 * output_width + 1}; - std::vector packets = GetPackets(runner); + std::vector packets = runner.Outputs().Tag("CATEGORY_MASK").packets; EXPECT_THAT(packets, testing::ElementsAre( Uint8ImagePacket(output_height, output_width, expected_index, buffer_indices))); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index ab1d3c84b..33c868e05 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -37,8 +37,10 @@ namespace vision { namespace image_segmenter { namespace { -constexpr char kSegmentationStreamName[] = "segmented_mask_out"; -constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; +constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; +constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; +constexpr char kCategoryMaskStreamName[] = "category_mask"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -51,7 +53,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; -using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; @@ -59,6 +60,7 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". 
CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, + bool output_confidence_masks, bool output_category_mask, bool enable_flow_limiting) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); @@ -66,14 +68,20 @@ CalculatorGraphConfig CreateGraphConfig( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> - graph.Out(kGroupedSegmentationTag); + if (output_confidence_masks) { + task_subgraph.Out(kConfidenceMasksTag) + .SetName(kConfidenceMasksStreamName) >> + graph.Out(kConfidenceMasksTag); + } + if (output_category_mask) { + task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> + graph.Out(kCategoryMaskTag); + } task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, - {kImageTag, kNormRectTag}, - kGroupedSegmentationTag); + return tasks::core::AddFlowLimiterCalculator( + graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); @@ -91,16 +99,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); options_proto->set_display_names_locale(options->display_names_locale); - switch (options->output_type) { - case ImageSegmenterOptions::OutputType::CATEGORY_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CATEGORY_MASK); - break; - case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK: - options_proto->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - break; - } return options_proto; } @@ -141,10 +139,17 @@ absl::StatusOr> GetLabelsFromGraphConfig( absl::StatusOr> ImageSegmenter::Create( std::unique_ptr options) { + if (!options->output_confidence_masks && !options->output_category_mask) { + return absl::InvalidArgumentError( + "At least one of `output_confidence_masks` and `output_category_mask` " + "must be set."); + } auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); tasks::core::PacketsCallback packets_callback = nullptr; if (options->result_callback) { auto result_callback = options->result_callback; + bool output_category_mask = options->output_category_mask; + bool output_confidence_masks = options->output_confidence_masks; packets_callback = [=](absl::StatusOr status_or_packets) { if (!status_or_packets.ok()) { @@ -156,34 +161,46 @@ absl::StatusOr> ImageSegmenter::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } - Packet segmented_masks = - status_or_packets.value()[kSegmentationStreamName]; + std::optional> confidence_masks; + if (output_confidence_masks) { + confidence_masks = + status_or_packets.value()[kConfidenceMasksStreamName] + .Get>(); + } + std::optional category_mask; + if (output_category_mask) { + category_mask = + status_or_packets.value()[kCategoryMaskStreamName].Get(); + } Packet image_packet = status_or_packets.value()[kImageOutStreamName]; - result_callback(segmented_masks.Get>(), - image_packet.Get(), - segmented_masks.Timestamp().Value() / - kMicroSecondsPerMilliSecond); + result_callback( + {{confidence_masks, category_mask}}, image_packet.Get(), + 
image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); }; } - auto image_segmenter = core::VisionTaskApiFactory::Create( CreateGraphConfig( - std::move(options_proto), + std::move(options_proto), options->output_confidence_masks, + options->output_category_mask, options->running_mode == core::RunningMode::LIVE_STREAM), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)); if (!image_segmenter.ok()) { return image_segmenter.status(); } + image_segmenter.value()->output_confidence_masks_ = + options->output_confidence_masks; + image_segmenter.value()->output_category_mask_ = + options->output_category_mask; ASSIGN_OR_RETURN( (*image_segmenter)->labels_, GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig())); return image_segmenter; } -absl::StatusOr> ImageSegmenter::Segment( +absl::StatusOr ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -201,11 +218,20 @@ absl::StatusOr> ImageSegmenter::Segment( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, MakePacket(std::move(norm_rect))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::optional> confidence_masks; + if (output_confidence_masks_) { + confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + } + std::optional category_mask; + if (output_category_mask_) { + category_mask = output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } -absl::StatusOr> ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms, +absl::StatusOr ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( @@ -225,11 +251,20 @@ absl::StatusOr> ImageSegmenter::SegmentForVideo( {kNormRectStreamName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kSegmentationStreamName].Get>(); + std::optional> confidence_masks; + if (output_confidence_masks_) { + confidence_masks = + output_packets[kConfidenceMasksStreamName].Get>(); + } + std::optional category_mask; + if (output_category_mask_) { + category_mask = output_packets[kCategoryMaskStreamName].Get(); + } + return {{confidence_masks, category_mask}}; } absl::Status ImageSegmenter::SegmentAsync( - Image image, int64 timestamp_ms, + Image image, int64_t timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 076a5016c..352d6b273 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,6 +26,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { @@ -52,23 +53,17 @@ struct ImageSegmenterOptions { // Metadata, if any. Defaults to English. std::string display_names_locale = "en"; - // The output type of segmentation results. 
- enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK = 0, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. - CONFIDENCE_MASK = 1, - }; + // Whether to output confidence masks. + bool output_confidence_masks = true; - OutputType output_type = OutputType::CATEGORY_MASK; + // Whether to output category mask. + bool output_category_mask = false; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function>, - const Image&, int64)> + std::function, const Image&, + int64_t)> result_callback = nullptr; }; @@ -84,13 +79,11 @@ struct ImageSegmenterOptions { // 1 or 3). // - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. -// Output tensors: -// (kTfLiteUInt8/kTfLiteFloat32) -// - list of segmented masks. -// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. -// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `channels`. -// - batch is always 1 +// Output ImageSegmenterResult: +// Provides optional confidence masks if `output_confidence_masks` is set +// true, and an optional category mask if `output_category_mask` is set +// true. At least one of `output_confidence_masks` and `output_category_mask` +// must be set to true. // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { @@ -114,12 +107,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> Segment( + + absl::StatusOr Segment( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -137,13 +126,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - // - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. - // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. - absl::StatusOr> SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms, + absl::StatusOr SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -164,17 +148,13 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // and will result in an invalid argument error being returned. // // The "result_callback" prvoides - // - A vector of segmented image masks. - // If the output_type is CATEGORY_MASK, the returned vector of images is - // per-category segmented image mask. 
- // If the output_type is CONFIDENCE_MASK, the returned vector of images - // contains only one confidence image mask. + // - An ImageSegmenterResult. // - The const reference to the corresponding input image that the image // segmentation runs on. Note that the const reference to the image will // no longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms, + absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); @@ -182,9 +162,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { absl::Status Close() { return runner_->Close(); } // Get the category label list of the ImageSegmenter can recognize. For - // CATEGORY_MASK type, the index in the category mask corresponds to the - // category in the label list. For CONFIDENCE_MASK type, the output mask list - // at index corresponds to the category in the label list. + // CATEGORY_MASK, the index in the category mask corresponds to the category + // in the label list. For CONFIDENCE_MASK, the output mask list at index + // corresponds to the category in the label list. // // If there is no labelmap provided in the model file, empty label list is // returned. @@ -192,6 +172,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { private: std::vector labels_; + bool output_confidence_masks_; + bool output_category_mask_; }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index fe6265b73..840e7933a 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" +#include "mediapipe/util/graph_builder_utils.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -65,10 +67,13 @@ using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::tflite::TensorMetadata; -using LabelItems = mediapipe::proto_ns::Map; +using LabelItems = mediapipe::proto_ns::Map; constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kConfidenceMaskTag[] = "CONFIDENCE_MASK"; +constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; +constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageCpuTag[] = "IMAGE_CPU"; constexpr char kImageGpuTag[] = "IMAGE_GPU"; @@ -80,7 +85,9 @@ constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter // subgraph. 
struct ImageSegmenterOutputs { - std::vector> segmented_masks; + std::optional>> segmented_masks; + std::optional>> confidence_masks; + std::optional> category_mask; // The same as the input image, mainly used for live stream mode. Source image; }; @@ -95,8 +102,10 @@ struct ImageAndTensorsOnDevice { } // namespace absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { - if (options.segmenter_options().output_type() == - SegmenterOptions::UNSPECIFIED) { + // TODO: remove deprecated output type support. + if (options.segmenter_options().has_output_type() && + options.segmenter_options().output_type() == + SegmenterOptions::UNSPECIFIED) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, "`output_type` must not be UNSPECIFIED", MediaPipeTasksStatus::kInvalidArgumentError); @@ -133,9 +142,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator( const core::ModelResources& model_resources, TensorsToSegmentationCalculatorOptions* options) { // Set default activation function NONE - options->mutable_segmenter_options()->set_output_type( - segmenter_option.segmenter_options().output_type()); - options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE); + options->mutable_segmenter_options()->CopyFrom( + segmenter_option.segmenter_options()); // Find the custom metadata of ImageSegmenterOptions type in model metadata. const auto* metadata_extractor = model_resources.GetMetadataExtractor(); bool found_activation_in_metadata = false; @@ -317,12 +325,16 @@ absl::StatusOr ConvertImageToTensors( } } -// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic -// segmentation. -// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. -// Users can retrieve segmented mask of only particular category/channel from -// SEGMENTATION, and users can also get all segmented masks from -// GROUPED_SEGMENTATION. +// An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs +// semantic segmentation. The graph can output optional confidence masks if +// CONFIDENCE_MASKS is connected, and an optional category mask if CATEGORY_MASK +// is connected. At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and +// CATEGORY_MASK must be connected. +// +// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and +// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular +// category/channel from CONFIDENCE_MASK, and users can also get all segmented +// confidence masks from CONFIDENCE_MASKS. // - Accepts CPU input images and outputs segmented masks on CPU. // // Inputs: @@ -334,17 +346,19 @@ absl::StatusOr ConvertImageToTensors( // @Optional: rect covering the whole image is used if not specified. // // Outputs: -// SEGMENTATION - mediapipe::Image @Multiple -// Segmented masks for individual category. Segmented mask of single +// CONFIDENCE_MASK - mediapipe::Image @Multiple +// Confidence masks for individual category. Confidence mask of single // category can be accessed by index based output stream. -// GROUPED_SEGMENTATION - std::vector -// The output segmented masks grouped in a vector. +// CONFIDENCE_MASKS - std::vector @Optional +// The output confidence masks grouped in a vector. +// CATEGORY_MASK - mediapipe::Image @Optional +// Optional Category mask. // IMAGE - mediapipe::Image // The image that image segmenter runs on. 
// // Example: // node { -// calculator: "mediapipe.tasks.vision.ImageSegmenterGraph" +// calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" // input_stream: "IMAGE:image" // output_stream: "SEGMENTATION:segmented_masks" // options { @@ -369,28 +383,64 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; + const auto& options = sc->Options(); + // TODO: remove deprecated output type support. + if (!options.segmenter_options().has_output_type()) { + MP_RETURN_IF_ERROR(SanityCheck(sc)); + } ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( - sc->Options(), *model_resources, - graph[Input(kImageTag)], + options, *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - auto& merge_images_to_vector = - graph.AddNode("MergeImagesToVectorCalculator"); - for (int i = 0; i < output_streams.segmented_masks.size(); ++i) { - output_streams.segmented_masks[i] >> - merge_images_to_vector[Input::Multiple("")][i]; - output_streams.segmented_masks[i] >> - graph[Output::Multiple(kSegmentationTag)][i]; + // TODO: remove deprecated output type support. + if (options.segmenter_options().has_output_type()) { + auto& merge_images_to_vector = + graph.AddNode("MergeImagesToVectorCalculator"); + for (int i = 0; i < output_streams.segmented_masks->size(); ++i) { + output_streams.segmented_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.segmented_masks->at(i) >> + graph[Output::Multiple(kSegmentationTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>(kGroupedSegmentationTag)]; + } else { + if (output_streams.confidence_masks) { + auto& merge_images_to_vector = + graph.AddNode("MergeImagesToVectorCalculator"); + for (int i = 0; i < output_streams.confidence_masks->size(); ++i) { + output_streams.confidence_masks->at(i) >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.confidence_masks->at(i) >> + graph[Output::Multiple(kConfidenceMaskTag)][i]; + } + merge_images_to_vector.Out("") >> + graph[Output>::Optional(kConfidenceMasksTag)]; + } + if (output_streams.category_mask) { + *output_streams.category_mask >> graph[Output(kCategoryMaskTag)]; + } } - merge_images_to_vector.Out("") >> - graph[Output>(kGroupedSegmentationTag)]; output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } private: + absl::Status SanityCheck(mediapipe::SubgraphContext* sc) { + const auto& node = sc->OriginalNode(); + output_confidence_masks_ = HasOutput(node, kConfidenceMaskTag) || + HasOutput(node, kConfidenceMasksTag); + output_category_mask_ = HasOutput(node, kCategoryMaskTag); + if (!output_confidence_masks_ && !output_category_mask_) { + return absl::InvalidArgumentError( + "At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK " + "must be connected."); + } + return absl::OkStatus(); + } + // Adds a mediapipe image segmentation task pipeline graph into the provided // builder::Graph instance. The segmentation pipeline takes images // (mediapipe::Image) as the input and returns segmented image mask as output. @@ -435,23 +485,53 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); // Exports multiple segmented masks. 
- std::vector> segmented_masks; - if (task_options.segmenter_options().output_type() == - SegmenterOptions::CATEGORY_MASK) { - segmented_masks.push_back( - Source(tensor_to_images[Output(kSegmentationTag)])); - } else { - ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, - GetOutputTensor(model_resources)); - int segmentation_streams_num = *output_tensor->shape()->rbegin(); - for (int i = 0; i < segmentation_streams_num; ++i) { - segmented_masks.push_back(Source( - tensor_to_images[Output::Multiple(kSegmentationTag)][i])); + // TODO: remove deprecated output type support. + if (task_options.segmenter_options().has_output_type()) { + std::vector> segmented_masks; + if (task_options.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK) { + segmented_masks.push_back( + Source(tensor_to_images[Output(kSegmentationTag)])); + } else { + ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, + GetOutputTensor(model_resources)); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); + for (int i = 0; i < segmentation_streams_num; ++i) { + segmented_masks.push_back(Source( + tensor_to_images[Output::Multiple(kSegmentationTag)][i])); + } } + return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, + /*confidence_masks=*/std::nullopt, + /*category_mask=*/std::nullopt, + /*image=*/image_and_tensors.image}; + } else { + std::optional>> confidence_masks; + if (output_confidence_masks_) { + ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, + GetOutputTensor(model_resources)); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); + confidence_masks = std::vector>(); + confidence_masks->reserve(segmentation_streams_num); + for (int i = 0; i < segmentation_streams_num; ++i) { + confidence_masks->push_back(Source( + tensor_to_images[Output::Multiple(kConfidenceMaskTag)] + [i])); + } + } + std::optional> category_mask; + if (output_category_mask_) { + category_mask = tensor_to_images[Output(kCategoryMaskTag)]; + } + return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, + /*confidence_masks=*/confidence_masks, + /*category_mask=*/category_mask, + /*image=*/image_and_tensors.image}; } - return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, - /*image=*/image_and_tensors.image}; } + + bool output_confidence_masks_ = false; + bool output_category_mask_ = false; }; REGISTER_MEDIAPIPE_GRAPH( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h new file mode 100644 index 000000000..f14ee7a90 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ + +#include + +#include "mediapipe/framework/formats/image.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_segmenter { + +// The output result of ImageSegmenter +struct ImageSegmenterResult { + // Multiple masks of float image in VEC32F1 format where, for each mask, each + // pixel represents the prediction confidence, usually in the [0, 1] range. + std::optional> confidence_masks; + // A category mask of uint8 image in GRAY8 format where each pixel represents + // the class which the pixel in the original image was predicted to belong to. + std::optional category_mask; +}; + +} // namespace image_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 1d75a3fb7..0c5a61486 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -256,7 +257,6 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -278,15 +278,15 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_confidence_masks = false; + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image)); - EXPECT_EQ(category_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), @@ -303,12 +303,11 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = 
ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 21); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks->size(), 21); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE); @@ -317,7 +316,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -331,15 +330,14 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); ImageProcessingOptions image_processing_options; image_processing_options.rotation_degrees = -90; - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, image_processing_options)); - EXPECT_EQ(confidence_masks.size(), 21); + EXPECT_EQ(result.confidence_masks->size(), 21); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), @@ -349,7 +347,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { // Cat category index 8. cv::Mat cat_mask = mediapipe::formats::MatView( - confidence_masks[8].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(8).GetImageFrameSharedPtr().get()); EXPECT_THAT(cat_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -361,7 +359,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -384,12 +381,11 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 2); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks->size(), 2); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -400,7 +396,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { // Selfie category index 1. 
cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -411,11 +407,10 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks->size(), 1); cv::Mat expected_mask = cv::imread(JoinPath("./", kTestDataDirectory, @@ -425,7 +420,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -436,12 +431,11 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks->size(), 1); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( @@ -452,7 +446,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); cv::Mat selfie_mask = mediapipe::formats::MatView( - confidence_masks[0].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(0).GetImageFrameSharedPtr().get()); EXPECT_THAT(selfie_mask, SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } @@ -463,16 +457,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentation); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); - EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); cv::Mat selfie_mask = mediapipe::formats::MatView( - category_mask[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_selfie_segmentation_expected_category_mask.jpg"), @@ -487,16 +480,15 @@ TEST_F(ImageModeTest, 
SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; - + options->output_category_mask = true; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); - EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); cv::Mat selfie_mask = mediapipe::formats::MatView( - category_mask[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( JoinPath( "./", kTestDataDirectory, @@ -512,14 +504,13 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); - EXPECT_EQ(confidence_masks.size(), 2); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image)); + EXPECT_EQ(result.confidence_masks->size(), 2); cv::Mat hair_mask = mediapipe::formats::MatView( - confidence_masks[1].GetImageFrameSharedPtr().get()); + result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); MP_ASSERT_OK(segmenter->Close()); cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"), @@ -540,7 +531,6 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::VIDEO; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, @@ -572,7 +562,7 @@ TEST_F(VideoModeTest, Succeeds) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_category_mask = true; options->running_mode = core::RunningMode::VIDEO; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -580,11 +570,10 @@ TEST_F(VideoModeTest, Succeeds) { JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), cv::IMREAD_GRAYSCALE); for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto category_masks, - segmenter->SegmentForVideo(image, i)); - EXPECT_EQ(category_masks.size(), 1); + MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->SegmentForVideo(image, i)); + EXPECT_TRUE(result.category_mask.has_value()); cv::Mat actual_mask = mediapipe::formats::MatView( - category_masks[0].GetImageFrameSharedPtr().get()); + result.category_mask->GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, kGoldenMaskMagnificationFactor)); @@ -601,11 +590,10 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { auto options = std::make_unique(); options->base_options.model_asset_path = 
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::LIVE_STREAM; options->result_callback = - [](absl::StatusOr> segmented_masks, const Image& image, - int64 timestamp_ms) {}; + [](absl::StatusOr segmented_masks, + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -634,11 +622,9 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [](absl::StatusOr> segmented_masks, const Image& image, - int64 timestamp_ms) {}; + options->result_callback = [](absl::StatusOr result, + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK(segmenter->SegmentAsync(image, 1)); @@ -660,23 +646,23 @@ TEST_F(LiveStreamModeTest, Succeeds) { Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "segmentation_input_rotation0.jpg"))); - std::vector> segmented_masks_results; + std::vector segmented_masks_results; std::vector> image_sizes; - std::vector timestamps; + std::vector timestamps; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); - options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + options->output_category_mask = true; options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = - [&segmented_masks_results, &image_sizes, ×tamps]( - absl::StatusOr> segmented_masks, - const Image& image, int64 timestamp_ms) { - MP_ASSERT_OK(segmented_masks.status()); - segmented_masks_results.push_back(std::move(segmented_masks).value()); - image_sizes.push_back({image.width(), image.height()}); - timestamps.push_back(timestamp_ms); - }; + options->result_callback = [&segmented_masks_results, &image_sizes, + ×tamps]( + absl::StatusOr result, + const Image& image, int64_t timestamp_ms) { + MP_ASSERT_OK(result.status()); + segmented_masks_results.push_back(std::move(*result->category_mask)); + image_sizes.push_back({image.width(), image.height()}); + timestamps.push_back(timestamp_ms); + }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); for (int i = 0; i < iterations; ++i) { @@ -690,10 +676,9 @@ TEST_F(LiveStreamModeTest, Succeeds) { cv::Mat expected_mask = cv::imread( JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"), cv::IMREAD_GRAYSCALE); - for (const auto& segmented_masks : segmented_masks_results) { - EXPECT_EQ(segmented_masks.size(), 1); + for (const auto& category_mask : segmented_masks_results) { cv::Mat actual_mask = mediapipe::formats::MatView( - segmented_masks[0].GetImageFrameSharedPtr().get()); + category_mask.GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, kGoldenMaskMagnificationFactor)); @@ -702,7 +687,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.second, image.height()); } - int64 timestamp_ms = -1; + int64_t timestamp_ms = -1; for (const auto& 
timestamp : timestamps) { EXPECT_GT(timestamp, timestamp_ms); timestamp_ms = timestamp; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index be2b8a589..b1ec529d0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -33,7 +33,7 @@ message SegmenterOptions { CONFIDENCE_MASK = 2; } // Optional output mask type. - optional OutputType output_type = 1 [default = CATEGORY_MASK]; + optional OutputType output_type = 1 [deprecated = true]; // Supported activation functions for filtering. enum Activation { diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index 86b6c67b2..c5964c648 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -59,3 +59,44 @@ cc_library( "@com_google_absl//absl/status:statusor", ], ) + +cc_library( + name = "pose_landmarker_graph", + srcs = ["pose_landmarker_graph.cc"], + deps = [ + ":pose_landmarks_detector_graph", + "//mediapipe/calculators/core:clip_vector_size_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/util:association_calculator_cc_proto", + "//mediapipe/calculators/util:association_norm_rect_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc new file mode 100644 index 000000000..ae3a7482e --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -0,0 +1,384 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/calculators/util/association_calculator.pb.h" +#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/util/graph_builder_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::SidePacket; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::DisallowIf; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; +using ::mediapipe::tasks::vision::pose_detector::proto:: + PoseDetectorGraphOptions; +using ::mediapipe::tasks::vision::pose_landmarker::proto:: + PoseLandmarkerGraphOptions; +using ::mediapipe::tasks::vision::pose_landmarker::proto:: + PoseLandmarksDetectorGraphOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kAuxiliaryLandmarksTag[] = "AUXILIARY_LANDMARKS"; +constexpr char kPoseRectsNextFrameTag[] = "POSE_RECTS_NEXT_FRAME"; +constexpr char kExpandedPoseRectsTag[] = "EXPANDED_POSE_RECTS"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kLoopTag[] = "LOOP"; +constexpr char kPrevLoopTag[] = "PREV_LOOP"; +constexpr char kMainTag[] = "MAIN"; +constexpr char kIterableTag[] = "ITERABLE"; +constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK"; + +constexpr char kPoseDetectorTFLiteName[] = "pose_detector.tflite"; +constexpr char kPoseLandmarksDetectorTFLiteName[] = + 
"pose_landmarks_detector.tflite"; + +struct PoseLandmarkerOutputs { + Source> landmark_lists; + Source> world_landmark_lists; + Source> auxiliary_landmark_lists; + Source> pose_rects_next_frame; + Source> pose_detections; + Source> segmentation_masks; + Source image; +}; + +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + PoseLandmarkerGraphOptions* options, + bool is_copy) { + auto* pose_detector_graph_options = + options->mutable_pose_detector_graph_options(); + if (!pose_detector_graph_options->base_options().has_model_asset()) { + ASSIGN_OR_RETURN(const auto pose_detector_file, + resources.GetFile(kPoseDetectorTFLiteName)); + SetExternalFile(pose_detector_file, + pose_detector_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + } + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + pose_detector_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + auto* pose_landmarks_detector_graph_options = + options->mutable_pose_landmarks_detector_graph_options(); + if (!pose_landmarks_detector_graph_options->base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN(const auto pose_landmarks_detector_file, + resources.GetFile(kPoseLandmarksDetectorTFLiteName)); + SetExternalFile( + pose_landmarks_detector_file, + pose_landmarks_detector_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + } + pose_landmarks_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + pose_landmarks_detector_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + + return absl::OkStatus(); +} + +} // namespace + +// A "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph" performs pose +// landmarks detection. The PoseLandmarkerGraph consists of two subgraphs: +// PoseDetectorGraph, MultiplePoseLandmarksDetectorGraph +// +// MultiplePoseLandmarksDetectorGraph detects landmarks from bounding boxes +// produced by PoseDetectorGraph. PoseLandmarkerGraph tracks the landmarks over +// time, and skips the PoseDetectorGraph. If the tracking is lost or the +// detected poses are less than configured max number poses, PoseDetectorGraph +// would be triggered to detect poses. +// +// +// Inputs: +// IMAGE - Image +// Image to perform pose landmarks detection on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform landmarks +// detection on. If not provided, whole image is used for pose landmarks +// detection. +// +// +// Outputs: +// NORM_LANDMARKS: - std::vector +// Vector of detected pose landmarks. +// WORLD_LANDMARKS: std::vector +// Vector of detected world pose landmarks. +// AUXILIARY_LANDMARKS: - std::vector +// Vector of detected auxiliary landmarks. +// POSE_RECTS_NEXT_FRAME - std::vector +// Vector of the expanded rects enclosing the whole pose RoI for landmark +// detection on the next frame. +// POSE_RECTS - std::vector +// Detected pose bounding boxes in normalized coordinates from pose +// detection. +// SEGMENTATION_MASK - std::vector +// Segmentation masks. +// IMAGE - Image +// The input image that the pose landmarker runs on and has the pixel data +// stored on the target storage (CPU vs GPU). 
+// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph" +// input_stream: "IMAGE:image_in" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "NORM_LANDMARKS:pose_landmarks" +// output_stream: "LANDMARKS:world_landmarks" +// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks" +// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame" +// output_stream: "POSE_RECTS:pose_rects" +// output_stream: "SEGMENTATION_MASK:segmentation_masks" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.pose_landmarker.proto.PoseLandmarkerGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "pose_landmarker.task" +// } +// } +// pose_detector_graph_options { +// min_detection_confidence: 0.5 +// num_poses: 2 +// } +// pose_landmarks_detector_graph_options { +// min_detection_confidence: 0.5 +// } +// } +// } +// } +class PoseLandmarkerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // Copies the file content instead of passing the pointer of file in + // memory if the subgraph model resource service is not available. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN( + auto outs, + BuildPoseLandmarkerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); + outs.landmark_lists >> + graph[Output>(kNormLandmarksTag)]; + outs.world_landmark_lists >> + graph[Output>(kWorldLandmarksTag)]; + outs.auxiliary_landmark_lists >> + graph[Output>( + kAuxiliaryLandmarksTag)]; + outs.pose_rects_next_frame >> + graph[Output>(kPoseRectsNextFrameTag)]; + outs.segmentation_masks >> + graph[Output>(kSegmentationMaskTag)]; + outs.pose_detections >> + graph[Output>(kDetectionsTag)]; + outs.image >> graph[Output(kImageTag)]; + + // TODO remove when support is fixed. + // As mediapipe GraphBuilder currently doesn't support configuring + // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. + CalculatorGraphConfig config = graph.GetConfig(); + for (int i = 0; i < config.node_size(); ++i) { + if (config.node(i).calculator() == "PreviousLoopbackCalculator") { + auto* info = config.mutable_node(i)->add_input_stream_info(); + info->set_tag_index(kLoopTag); + info->set_back_edge(true); + break; + } + } + + return config; + } + + private: + // Adds a mediapipe pose landmarker graph into the provided builder::Graph + // instance. + // + // tasks_options: the mediapipe tasks module PoseLandmarkerGraphOptions. + // image_in: (mediapipe::Image) stream to run pose landmark detection on. + // graph: the mediapipe graph instance to be updated. 
+ absl::StatusOr BuildPoseLandmarkerGraph( + PoseLandmarkerGraphOptions& tasks_options, Source image_in, + Source norm_rect_in, Graph& graph) { + const int max_num_poses = + tasks_options.pose_detector_graph_options().num_poses(); + + auto& pose_detector = + graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"); + pose_detector.GetOptions().Swap( + tasks_options.mutable_pose_detector_graph_options()); + auto& clip_pose_rects = + graph.AddNode("ClipNormalizedRectVectorSizeCalculator"); + clip_pose_rects.GetOptions() + .set_max_vec_size(max_num_poses); + auto clipped_pose_rects = clip_pose_rects.Out(""); + + auto& pose_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.pose_landmarker." + "MultiplePoseLandmarksDetectorGraph"); + pose_landmarks_detector_graph + .GetOptions() + .Swap(tasks_options.mutable_pose_landmarks_detector_graph_options()); + image_in >> pose_landmarks_detector_graph.In(kImageTag); + clipped_pose_rects >> pose_landmarks_detector_graph.In(kNormRectTag); + + // TODO: Add landmarks smoothing calculators to + // PoseLandmarkerGraph + auto landmarks = pose_landmarks_detector_graph.Out("LANDMARKS") + .Cast>(); + auto world_landmarks = pose_landmarks_detector_graph.Out(kWorldLandmarksTag) + .Cast>(); + auto aux_landmarks = + pose_landmarks_detector_graph.Out(kAuxiliaryLandmarksTag) + .Cast>(); + auto pose_rects_for_next_frame = + pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag) + .Cast>(); + auto segmentation_masks = + pose_landmarks_detector_graph.Out(kSegmentationMaskTag) + .Cast>(); + + if (tasks_options.base_options().use_stream_mode()) { + auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator"); + image_in >> previous_loopback.In(kMainTag); + auto prev_pose_rects_from_landmarks = + previous_loopback[Output>(kPrevLoopTag)]; + + auto& min_size_node = + graph.AddNode("NormalizedRectVectorHasMinSizeCalculator"); + prev_pose_rects_from_landmarks >> min_size_node.In(kIterableTag); + min_size_node.GetOptions() + .set_min_size(max_num_poses); + auto has_enough_poses = min_size_node.Out("").Cast(); + + // While in stream mode, skip pose detector graph when we successfully + // track the poses from the last frame. + auto image_for_pose_detector = + DisallowIf(image_in, has_enough_poses, graph); + auto norm_rect_in_for_pose_detector = + DisallowIf(norm_rect_in, has_enough_poses, graph); + image_for_pose_detector >> pose_detector.In(kImageTag); + norm_rect_in_for_pose_detector >> pose_detector.In(kNormRectTag); + auto expanded_pose_rects_from_pose_detector = + pose_detector.Out(kExpandedPoseRectsTag); + auto& pose_association = graph.AddNode("AssociationNormRectCalculator"); + pose_association.GetOptions() + .set_min_similarity_threshold( + tasks_options.min_tracking_confidence()); + prev_pose_rects_from_landmarks >> + pose_association[Input>::Multiple("")][0]; + expanded_pose_rects_from_pose_detector >> + pose_association[Input>::Multiple("")][1]; + auto pose_rects = pose_association.Out(""); + pose_rects >> clip_pose_rects.In(""); + // Back edge. + pose_rects_for_next_frame >> previous_loopback.In(kLoopTag); + } else { + // While not in stream mode, the input images are not guaranteed to be in + // series, and we don't want to enable the tracking and rect associations + // between input images. Always use the pose detector graph. 
+ image_in >> pose_detector.In(kImageTag); + norm_rect_in >> pose_detector.In(kNormRectTag); + auto pose_rects = pose_detector.Out(kExpandedPoseRectsTag); + pose_rects >> clip_pose_rects.In(""); + } + + // TODO: Replace PassThroughCalculator with a calculator that + // converts the pixel data to be stored on the target storage (CPU vs GPU). + auto& pass_through = graph.AddNode("PassThroughCalculator"); + image_in >> pass_through.In(""); + + return {{ + /* landmark_lists= */ landmarks, + /* world_landmarks= */ world_landmarks, + /* aux_landmarks= */ aux_landmarks, + /* pose_rects_next_frame= */ pose_rects_for_next_frame, + /* pose_detections */ + pose_detector.Out(kDetectionsTag).Cast>(), + /* segmentation_masks= */ segmentation_masks, + /* image= */ + pass_through[Output("")], + }}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::pose_landmarker::PoseLandmarkerGraph); + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc new file mode 100644 index 000000000..6a4cc93b6 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc @@ -0,0 +1,190 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::mediapipe::tasks::vision::pose_landmarker::proto:: + PoseLandmarkerGraphOptions; +using ::testing::EqualsProto; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPoseLandmarkerModelBundleName[] = "pose_landmarker.task"; +constexpr char kPoseImageName[] = "pose.jpg"; +constexpr char kExpectedPoseLandmarksName[] = + "expected_pose_landmarks.prototxt"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kNormLandmarksName[] = "norm_landmarks"; + +constexpr float kLiteModelFractionDiff = 0.05; // percentage + +template +ProtoT GetExpectedProto(absl::string_view filename) { + ProtoT expected_proto; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &expected_proto, Defaults())); + return expected_proto; +} + +// Struct holding the parameters for parameterized PoseLandmarkerGraphTest +// class. +struct PoseLandmarkerGraphTestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // The filename of the test image. + std::string test_image_name; + // The expected output landmarks positions. + std::optional> expected_landmarks_list; + // The max value difference between expected_positions and detected positions. 
+ float landmarks_diff_threshold; +}; + +// Helper function to create a PoseLandmarkerGraph TaskRunner. +absl::StatusOr> CreatePoseLandmarkerGraphTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& pose_landmarker = graph.AddNode( + "mediapipe.tasks.vision.pose_landmarker." + "PoseLandmarkerGraph"); + + auto* options = &pose_landmarker.GetOptions(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + options->mutable_pose_detector_graph_options()->set_num_poses(1); + options->mutable_base_options()->set_use_stream_mode(true); + + graph[Input(kImageTag)].SetName(kImageName) >> + pose_landmarker.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + pose_landmarker.In(kNormRectTag); + + pose_landmarker.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >> + graph[Output>(kNormLandmarksTag)]; + + return TaskRunner::Create( + graph.GetConfig(), + absl::make_unique()); +} + +// Helper function to construct NormalizeRect proto. +NormalizedRect MakeNormRect(float x_center, float y_center, float width, + float height, float rotation) { + NormalizedRect pose_rect; + pose_rect.set_x_center(x_center); + pose_rect.set_y_center(y_center); + pose_rect.set_width(width); + pose_rect.set_height(height); + pose_rect.set_rotation(rotation); + return pose_rect; +} + +class PoseLandmarkerGraphTest + : public testing::TestWithParam {}; + +TEST_P(PoseLandmarkerGraphTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreatePoseLandmarkerGraphTaskRunner( + GetParam().input_model_name)); + + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(MakeNormRect(0.5, 0.5, 1.0, 1.0, 0))}}); + MP_ASSERT_OK(output_packets); + + if (GetParam().expected_landmarks_list) { + const std::vector& landmarks_lists = + (*output_packets)[kNormLandmarksName] + .Get>(); + EXPECT_THAT(landmarks_lists, + Pointwise(Approximately(Partially(EqualsProto()), + GetParam().landmarks_diff_threshold), + *GetParam().expected_landmarks_list)); + } +} + +INSTANTIATE_TEST_SUITE_P( + PoseLandmarkerGraphTests, PoseLandmarkerGraphTest, + Values(PoseLandmarkerGraphTestParams{ + /* test_name= */ "PoseLandmarkerLite", + /* input_model_name= */ kPoseLandmarkerModelBundleName, + /* test_image_name= */ kPoseImageName, + /* expected_landmarks_list= */ + {{GetExpectedProto( + kExpectedPoseLandmarksName)}}, + /* landmarks_diff_threshold= */ kLiteModelFractionDiff}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD index 6a95762c0..a2ad7b0b1 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD @@ -29,3 +29,15 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) + +mediapipe_proto_library( + name = "pose_landmarker_graph_options_proto", + srcs = ["pose_landmarker_graph_options.proto"], + deps = [ + ":pose_landmarks_detector_graph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + 
"//mediapipe/tasks/cc/core/proto:base_options_proto", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto new file mode 100644 index 000000000..bde314bad --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto @@ -0,0 +1,48 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.pose_landmarker.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.poselandmarker.proto"; +option java_outer_classname = "PoseLandmarkerGraphOptionsProto"; + +message PoseLandmarkerGraphOptions { + extend mediapipe.CalculatorOptions { + optional PoseLandmarkerGraphOptions ext = 516587230; + } + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for pose detector graph. + optional pose_detector.proto.PoseDetectorGraphOptions + pose_detector_graph_options = 2; + + // Options for pose landmarks detector graph. + optional PoseLandmarksDetectorGraphOptions + pose_landmarks_detector_graph_options = 3; + + // Minimum confidence for pose landmarks tracking to be considered + // successfully. + optional float min_tracking_confidence = 4 [default = 0.5]; +} diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc index 4d3d2cb96..3f0425a69 100644 --- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc +++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc @@ -191,8 +191,9 @@ absl::StatusOr BuildInputImageTensorSpecs( MediaPipeTasksStatus::kInvalidInputTensorDimensionsError); } - size_t byte_depth = - tensor_type == tflite::TensorType_FLOAT32 ? sizeof(float) : sizeof(uint8); + size_t byte_depth = tensor_type == tflite::TensorType_FLOAT32 + ? sizeof(float) + : sizeof(uint8_t); int bytes_size = byte_depth * batch * height * width * depth; // Sanity checks. 
if (tensor_type == tflite::TensorType_FLOAT32) { diff --git a/mediapipe/tasks/ios/common/utils/BUILD b/mediapipe/tasks/ios/common/utils/BUILD index a29c700da..da093a775 100644 --- a/mediapipe/tasks/ios/common/utils/BUILD +++ b/mediapipe/tasks/ios/common/utils/BUILD @@ -24,7 +24,6 @@ objc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/ios/common:MPPCommon", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", ], ) diff --git a/mediapipe/tasks/ios/components/containers/BUILD b/mediapipe/tasks/ios/components/containers/BUILD index 06df9576a..3ed7669cb 100644 --- a/mediapipe/tasks/ios/components/containers/BUILD +++ b/mediapipe/tasks/ios/components/containers/BUILD @@ -49,5 +49,8 @@ objc_library( name = "MPPDetection", srcs = ["sources/MPPDetection.m"], hdrs = ["sources/MPPDetection.h"], - deps = [":MPPCategory"], + deps = [ + ":MPPCategory", + "//third_party/apple_frameworks:UIKit", + ], ) diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h index cc7c2ebeb..e085007a6 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.h @@ -13,6 +13,7 @@ // limitations under the License. #import +#import #import "mediapipe/tasks/ios/components/containers/sources/MPPCategory.h" NS_ASSUME_NONNULL_BEGIN @@ -65,10 +66,12 @@ NS_SWIFT_NAME(Detection) /** The bounding box of the detected object. */ @property(nonatomic, readonly) CGRect boundingBox; -/** An optional array of `MPPNormalizedKeypoint` objects associated with the detection. Keypoints +/** + * An optional array of `MPPNormalizedKeypoint` objects associated with the detection. Keypoints * represent interesting points related to the detection. For example, the keypoints represent the - * eyes, ear and mouth from face detection model. Or in the template matching detection, e.g. KNIFT, - * they can represent the feature points for template matching. */ + * eyes, ear and mouth from the from detection model. In template matching detection, e.g. KNIFT, + * they can instead represent the feature points for template matching. + */ @property(nonatomic, readonly, nullable) NSArray *keypoints; /** @@ -80,8 +83,8 @@ NS_SWIFT_NAME(Detection) * @param boundingBox A `CGRect` that represents the bounding box. * @param keypoints: An optional array of `MPPNormalizedKeypoint` objects associated with the * detection. Keypoints represent interesting points related to the detection. For example, the - * keypoints represent the eyes, ear and mouth from face detection model. Or in the template - * matching detection, e.g. KNIFT, they can represent the feature points for template matching. + * keypoints represent the eyes, ear and mouth from the face detection model. In template matching + * detection, e.g. KNIFT, they can instead represent the feature points for template matching. * * @return An instance of `MPPDetection` initialized with the given array of categories, bounding * box and `nil` keypoints. 
diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m index 42259ffde..c245478db 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m @@ -45,10 +45,8 @@ MPPNormalizedKeypoint *otherKeypoint = (MPPNormalizedKeypoint *)object; - if (CGPointEqualToPoint(self.location, otherKeypoint.location) && - (self.label == otherKeypoint.label) && (self.score == otherKeypoint.score)) { - return YES; - } + return CGPointEqualToPoint(self.location, otherKeypoint.location) && + (self.label == otherKeypoint.label) && (self.score == otherKeypoint.score); } @end @@ -67,4 +65,4 @@ return self; } -@end \ No newline at end of file +@end diff --git a/mediapipe/tasks/ios/components/containers/utils/BUILD b/mediapipe/tasks/ios/components/containers/utils/BUILD index 3520740b0..3f93c0f36 100644 --- a/mediapipe/tasks/ios/components/containers/utils/BUILD +++ b/mediapipe/tasks/ios/components/containers/utils/BUILD @@ -68,6 +68,7 @@ objc_library( hdrs = ["sources/MPPDetection+Helpers.h"], deps = [ "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", "//mediapipe/tasks/ios/components/containers:MPPDetection", ], diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm index e5cc8dc03..f485b9110 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.mm @@ -13,6 +13,7 @@ // limitations under the License. #import "mediapipe/tasks/ios/components/containers/utils/sources/MPPDetection+Helpers.h" +#import "mediapipe/framework/formats/location_data.pb.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" diff --git a/mediapipe/tasks/ios/test/vision/utils/BUILD b/mediapipe/tasks/ios/test/vision/utils/BUILD index cf7626dbd..c5acf7a46 100644 --- a/mediapipe/tasks/ios/test/vision/utils/BUILD +++ b/mediapipe/tasks/ios/test/vision/utils/BUILD @@ -7,7 +7,5 @@ objc_library( srcs = ["sources/MPPImage+TestUtils.m"], hdrs = ["sources/MPPImage+TestUtils.h"], module_name = "MPPImageTestUtils", - deps = [ - "//mediapipe/tasks/ios/vision/core:MPPImage", - ], + deps = ["//mediapipe/tasks/ios/vision/core:MPPImage"], ) diff --git a/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h b/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h index 9dfe29fd3..8cd1c6a67 100644 --- a/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h +++ b/mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h @@ -29,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN * @param classObject The specified class associated with the bundle containing the file to be * loaded. * @param name Name of the image file. - * @param type Extenstion of the image file. + * @param type Extension of the image file. * * @return The `MPPImage` object contains the loaded image. This method returns * nil if it cannot load the image. @@ -46,7 +46,7 @@ NS_ASSUME_NONNULL_BEGIN * @param classObject The specified class associated with the bundle containing the file to be * loaded. * @param name Name of the image file. - * @param type Extenstion of the image file. + * @param type Extension of the image file. 
* @param orientation Orientation of the image. * * @return The `MPPImage` object contains the loaded image. This method returns diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h index 6e4921efc..590867bf8 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h @@ -22,9 +22,11 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(ObjectDetectionResult) @interface MPPObjectDetectionResult : MPPTaskResult -/** The array of `MPPDetection` objects each of which has a bounding box that is expressed in the +/** + * The array of `MPPDetection` objects each of which has a bounding box that is expressed in the * unrotated input frame of reference coordinates system, i.e. in `[0,image_width) x - * [0,image_height)`, which are the dimensions of the underlying image data. */ + * [0,image_height)`, which are the dimensions of the underlying image data. + */ @property(nonatomic, readonly) NSArray *detections; /** diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java index 49c459ef1..c330b1a56 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java @@ -33,7 +33,7 @@ public class OutputHandler { /** * Interface for the customizable MediaPipe task result listener that can reteive both task result - * objects and the correpsonding input data. + * objects and the corresponding input data. */ public interface ResultListener { void run(OutputT result, InputT input); @@ -90,8 +90,8 @@ public class OutputHandler { } /** - * Sets whether the output handler should react to the timestamp bound changes that are reprsented - * as empty output {@link Packet}s. + * Sets whether the output handler should react to the timestamp bound changes that are + * represented as empty output {@link Packet}s. * * @param handleTimestampBoundChanges A boolean value. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 310f5739c..31af80f5c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -24,7 +24,7 @@ import java.util.ArrayList; import java.util.List; /** - * {@link TaskInfo} contains all needed informaton to initialize a MediaPipe Task {@link + * {@link TaskInfo} contains all needed information to initialize a MediaPipe Task {@link * com.google.mediapipe.framework.Graph}. */ @AutoValue diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index 51735ff76..155536a4e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -94,7 +94,7 @@ public class TaskRunner implements AutoCloseable { * *

Note: This method is designed for processing batch data such as unrelated images and texts. * The call blocks the current thread until a failure status or a successful result is returned. - * An internal timestamp will be assigend per invocation. This method is thread-safe and allows + * An internal timestamp will be assigned per invocation. This method is thread-safe and allows * clients to call it from different threads. * * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java index bafa40e19..7054856fc 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -108,7 +108,7 @@ public abstract class FaceLandmarkerResult implements TaskResult { public abstract Optional>> faceBlendshapes(); /** - * Optional facial transformation matrix list from cannonical face to the detected face landmarks. + * Optional facial transformation matrix list from canonical face to the detected face landmarks. * The 4x4 facial transformation matrix is represetned as a flat column-major float array. */ public abstract Optional> facialTransformationMatrixes(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java index 9a52d114d..a6e246f1d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java @@ -254,7 +254,7 @@ public final class FaceStylizer extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not - * created wtih {@link ResultListener} set in {@link FaceStylizerOptions}. + * created with {@link ResultListener} set in {@link FaceStylizerOptions}. */ public void stylizeWithResultListener(MPImage image) { stylizeWithResultListener(image, ImageProcessingOptions.builder().build()); @@ -283,7 +283,7 @@ public final class FaceStylizer extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not - * created wtih {@link ResultListener} set in {@link FaceStylizerOptions}. + * created with {@link ResultListener} set in {@link FaceStylizerOptions}. */ public void stylizeWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions) { @@ -384,7 +384,7 @@ public final class FaceStylizer extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not - * created wtih {@link ResultListener} set in {@link FaceStylizerOptions}. + * created with {@link ResultListener} set in {@link FaceStylizerOptions}. 
*/ public void stylizeForVideoWithResultListener(MPImage image, long timestampMs) { stylizeForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs); @@ -411,7 +411,7 @@ public final class FaceStylizer extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is not - * created wtih {@link ResultListener} set in {@link FaceStylizerOptions}. + * created with {@link ResultListener} set in {@link FaceStylizerOptions}. */ public void stylizeForVideoWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index a933d2f65..5b2d7191f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -403,10 +403,10 @@ public final class GestureRecognizer extends BaseVisionTaskApi { public abstract Builder setMinTrackingConfidence(Float value); /** - * Sets the optional {@link ClassifierOptions} controling the canned gestures classifier, such - * as score threshold, allow list and deny list of gestures. The categories for canned gesture - * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down", - * "Thumb_Up", "Victory", "ILoveYou"] + * Sets the optional {@link ClassifierOptions} controlling the canned gestures classifier, + * such as score threshold, allow list and deny list of gestures. The categories + * for canned gesture classifiers are: ["None", "Closed_Fist", "Open_Palm", + * "Pointing_Up", "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"] * *

TODO Note this option is subject to change, after scoring merging * calculator is implemented. @@ -415,8 +415,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { ClassifierOptions classifierOptions); /** - * Sets the optional {@link ClassifierOptions} controling the custom gestures classifier, such - * as score threshold, allow list and deny list of gestures. + * Sets the optional {@link ClassifierOptions} controlling the custom gestures classifier, + * such as score threshold, allow list and deny list of gestures. * *

TODO Note this option is subject to change, after scoring merging * calculator is implemented. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index f1a08d425..b809ab963 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -302,7 +302,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentWithResultListener(MPImage image) { segmentWithResultListener(image, ImageProcessingOptions.builder().build()); @@ -329,7 +329,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions) { @@ -421,7 +421,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentForVideoWithResultListener(MPImage image, long timestampMs) { segmentForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs); @@ -444,7 +444,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. Or if {@link ImageSegmenter} is not - * created wtih {@link ResultListener} set in {@link ImageSegmenterOptions}. + * created with {@link ResultListener} set in {@link ImageSegmenterOptions}. */ public void segmentForVideoWithResultListener( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 8ee6951f8..657716b6b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -327,7 +327,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. 
* @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is - * not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}. + * not created with {@link ResultListener} set in {@link InteractiveSegmenterOptions}. */ public void segmentWithResultListener(MPImage image, RegionOfInterest roi) { segmentWithResultListener(image, roi, ImageProcessingOptions.builder().build()); @@ -357,7 +357,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * region-of-interest. * @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is - * not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}. + * not created with {@link ResultListener} set in {@link InteractiveSegmenterOptions}. */ public void segmentWithResultListener( MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) { diff --git a/mediapipe/tasks/metadata/metadata_schema.fbs b/mediapipe/tasks/metadata/metadata_schema.fbs index 8fe7a08fa..8660ba38c 100644 --- a/mediapipe/tasks/metadata/metadata_schema.fbs +++ b/mediapipe/tasks/metadata/metadata_schema.fbs @@ -142,7 +142,7 @@ enum AssociatedFileType : byte { // TODO: introduce the ScaNN index file with links once the code // is released. - // Contains on-devide ScaNN index file with LevelDB format. + // Contains on-device ScaNN index file with LevelDB format. // Added in: 1.4.0 SCANN_INDEX_FILE = 6, } diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 5b4203d7b..ad22faa98 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -23,12 +23,18 @@ py_library( srcs = ["audio_task_running_mode.py"], ) +py_library( + name = "audio_record", + srcs = ["audio_record.py"], +) + py_library( name = "base_audio_task_api", srcs = [ "base_audio_task_api.py", ], deps = [ + ":audio_record", ":audio_task_running_mode", "//mediapipe/framework:calculator_py_pb2", "//mediapipe/python:_framework_bindings", diff --git a/mediapipe/tasks/python/audio/core/audio_record.py b/mediapipe/tasks/python/audio/core/audio_record.py new file mode 100644 index 000000000..91e394584 --- /dev/null +++ b/mediapipe/tasks/python/audio/core/audio_record.py @@ -0,0 +1,125 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A module to record audio in a streaming basis.""" +import threading +import numpy as np + +try: + import sounddevice as sd +except OSError as oe: + sd = None + sd_error = oe +except ImportError as ie: + sd = None + sd_error = ie + + +class AudioRecord(object): + """A class to record audio in a streaming basis.""" + + def __init__( + self, channels: int, sampling_rate: int, buffer_size: int + ) -> None: + """Creates an AudioRecord instance. + + Args: + channels: Number of input channels. 
+ sampling_rate: Sampling rate in Hertz. + buffer_size: Size of the ring buffer in number of samples. + + Raises: + ValueError: if any of the arguments is non-positive. + ImportError: if failed to import `sounddevice`. + OSError: if failed to load `PortAudio`. + """ + if sd is None: + raise sd_error + + if channels <= 0: + raise ValueError('channels must be postive.') + if sampling_rate <= 0: + raise ValueError('sampling_rate must be postive.') + if buffer_size <= 0: + raise ValueError('buffer_size must be postive.') + + self._audio_buffer = [] + self._buffer_size = buffer_size + self._channels = channels + self._sampling_rate = sampling_rate + + # Create a ring buffer to store the input audio. + self._buffer = np.zeros([buffer_size, channels], dtype=float) + self._lock = threading.Lock() + + def audio_callback(data, *_): + """A callback to receive recorded audio data from sounddevice.""" + self._lock.acquire() + shift = len(data) + if shift > buffer_size: + self._buffer = np.copy(data[:buffer_size]) + else: + self._buffer = np.roll(self._buffer, -shift, axis=0) + self._buffer[-shift:, :] = np.copy(data) + self._lock.release() + + # Create an input stream to continuously capture the audio data. + self._stream = sd.InputStream( + channels=channels, + samplerate=sampling_rate, + callback=audio_callback, + ) + + @property + def channels(self) -> int: + return self._channels + + @property + def sampling_rate(self) -> int: + return self._sampling_rate + + @property + def buffer_size(self) -> int: + return self._buffer_size + + def start_recording(self) -> None: + """Starts the audio recording.""" + # Clear the internal ring buffer. + self._buffer.fill(0) + + # Start recording using sounddevice's InputStream. + self._stream.start() + + def stop(self) -> None: + """Stops the audio recording.""" + self._stream.stop() + + def read(self, size: int) -> np.ndarray: + """Reads the latest audio data captured in the buffer. + + Args: + size: Number of samples to read from the buffer. + + Returns: + A NumPy array containing the audio data. + + Raises: + ValueError: Raised if `size` is larger than the buffer size. + """ + if size > self._buffer_size: + raise ValueError('Cannot read more samples than the size of the buffer.') + elif size <= 0: + raise ValueError('Size must be positive.') + + start_index = self._buffer_size - size + return np.copy(self._buffer[start_index:]) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index 5b08a2b76..be8ff9324 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -20,6 +20,7 @@ from mediapipe.python import packet_creator from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import timestamp as timestamp_module +from mediapipe.tasks.python.audio.core import audio_record from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -83,12 +84,15 @@ class BaseAudioTaskApi(object): """ if self._running_mode != _RunningMode.AUDIO_CLIPS: raise ValueError( - 'Task is not initialized with the audio clips mode. Current running mode:' - + self._running_mode.name) + 'Task is not initialized with the audio clips mode. 
Current running' + ' mode:' + + self._running_mode.name + ) return self._runner.process(inputs) - def _set_sample_rate(self, sample_rate_stream_name: str, - sample_rate: float) -> None: + def _set_sample_rate( + self, sample_rate_stream_name: str, sample_rate: float + ) -> None: """An asynchronous method to set audio sample rate in the audio stream mode. Args: @@ -122,10 +126,40 @@ class BaseAudioTaskApi(object): """ if self._running_mode != _RunningMode.AUDIO_STREAM: raise ValueError( - 'Task is not initialized with the audio stream mode. Current running mode:' - + self._running_mode.name) + 'Task is not initialized with the audio stream mode. Current running' + ' mode:' + + self._running_mode.name + ) self._runner.send(inputs) + def create_audio_record( + self, num_channels: int, sample_rate: int, required_input_buffer_size: int + ) -> audio_record.AudioRecord: + """Creates an AudioRecord instance to record audio stream. + + The returned AudioRecord instance is initialized and client needs to call + the appropriate method to start recording. + + Note that MediaPipe Audio tasks will up/down sample automatically to fit the + sample rate required by the model. The default sample rate of the MediaPipe + pretrained audio model, Yamnet is 16kHz. + + Args: + num_channels: The number of audio channels. + sample_rate: The audio sample rate. + required_input_buffer_size: The required input buffer size in number of + float elements. + + Returns: + An AudioRecord instance. + + Raises: + ValueError: If there's a problem creating the AudioRecord instance. + """ + return audio_record.AudioRecord( + num_channels, sample_rate, required_input_buffer_size + ) + def close(self) -> None: """Shuts down the mediapipe audio task instance. diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 25d83cae8..6a107c8d8 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -121,7 +121,7 @@ class MetadataPopulator(object): Then, pack the metadata and label file into the model as follows. ```python - # Populating a metadata file (or a metadta buffer) and associated files to + # Populating a metadata file (or a metadata buffer) and associated files to a model file: populator = MetadataPopulator.with_model_file(model_file) # For metadata buffer (bytearray read from the metadata file), use: @@ -332,7 +332,7 @@ class MetadataPopulator(object): Raises: IOError: File not found. ValueError: The metadata to be populated is empty. - ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: The metadata does not have the expected flatbuffer identifier. ValueError: Cannot get minimum metadata parser version. ValueError: The number of SubgraphMetadata is not 1. ValueError: The number of input/output tensors does not match the number diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py index f201ab7e0..10b66ff18 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py @@ -559,7 +559,7 @@ class InputTextTensorMd(TensorMd): name: name of the tensor. description: description of what the tensor is. tokenizer_md: information of the tokenizer in the input text tensor, if - any. Only `RegexTokenizer` [1] is currenly supported. If the tokenizer + any. Only `RegexTokenizer` [1] is currently supported. 
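The ring-buffer update inside `AudioRecord`'s `audio_callback` above (roll the buffer left by the chunk length, then overwrite the tail) is what makes `read(size)` return the most recent `size` samples. A minimal standalone sketch of that behavior, using only NumPy and hypothetical sizes:

```python
import numpy as np

# Hypothetical sizes for illustration; AudioRecord shapes its buffer as
# (buffer_size, channels).
buffer_size, channels = 8, 1
buffer = np.zeros([buffer_size, channels], dtype=float)


def on_chunk(data: np.ndarray) -> None:
  """Mirrors AudioRecord's audio_callback: keep only the newest samples."""
  global buffer
  shift = len(data)
  if shift > buffer_size:
    buffer = np.copy(data[:buffer_size])
  else:
    buffer = np.roll(buffer, -shift, axis=0)
    buffer[-shift:, :] = np.copy(data)


on_chunk(np.arange(1, 5, dtype=float).reshape(4, 1))  # newest samples: 1..4
on_chunk(np.arange(5, 8, dtype=float).reshape(3, 1))  # newest samples: 1..7
print(buffer[buffer_size - 5:].ravel())  # read(5) -> [3. 4. 5. 6. 7.]
```

Because `read(size)` simply copies the tail of the buffer, callers always get a contiguous window of the newest audio without tracking a write cursor.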
If the tokenizer is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to `BertInputTensorsMd` class. [1]: diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 43f1d417c..3df783180 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -27,6 +27,7 @@ py_test( ], deps = [ "//mediapipe/tasks/python/audio:audio_classifier", + "//mediapipe/tasks/python/audio/core:audio_record", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", "//mediapipe/tasks/python/components/containers:classification_result", @@ -44,9 +45,9 @@ py_test( ], deps = [ "//mediapipe/tasks/python/audio:audio_embedder", + "//mediapipe/tasks/python/audio/core:audio_record", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", - "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", ], diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 75146547c..fbd96ad3e 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -19,11 +19,11 @@ from unittest import mock from absl.testing import absltest from absl.testing import parameterized - import numpy as np from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_classifier +from mediapipe.tasks.python.audio.core import audio_record from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module @@ -34,6 +34,7 @@ _AudioClassifier = audio_classifier.AudioClassifier _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData +_AudioRecord = audio_record.AudioRecord _BaseOptions = base_options_module.BaseOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode @@ -204,6 +205,19 @@ class AudioClassifierTest(parameterized.TestCase): self._read_wav_file(audio_file)) self._check_yamnet_result(classification_result_list) + @mock.patch('sounddevice.InputStream', return_value=mock.MagicMock()) + def test_create_audio_record_from_classifier_succeeds(self, _): + # Creates AudioRecord instance using the classifier successfully. 
+ with _AudioClassifier.create_from_model_path( + self.yamnet_model_path + ) as classifier: + self.assertIsInstance(classifier, _AudioClassifier) + record = classifier.create_audio_record(1, 16000, 16000) + self.assertIsInstance(record, _AudioRecord) + self.assertEqual(record.channels, 1) + self.assertEqual(record.sampling_rate, 16000) + self.assertEqual(record.buffer_size, 16000) + def test_max_result_options(self): with _AudioClassifier.create_from_options( _AudioClassifierOptions( diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 934cdc8db..0fc01ee7b 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -24,6 +24,7 @@ import numpy as np from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder +from mediapipe.tasks.python.audio.core import audio_record from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.core import base_options as base_options_module @@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData +_AudioRecord = audio_record.AudioRecord _BaseOptions = base_options_module.BaseOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode @@ -165,6 +167,19 @@ class AudioEmbedderTest(parameterized.TestCase): self.assertLen(embedding_result0_list, 5) self.assertLen(embedding_result1_list, 5) + @mock.patch('sounddevice.InputStream', return_value=mock.MagicMock()) + def test_create_audio_record_from_embedder_succeeds(self, _): + # Creates AudioRecord instance using the embedder successfully. + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path + ) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + record = embedder.create_audio_record(1, 16000, 16000) + self.assertIsInstance(record, _AudioRecord) + self.assertEqual(record.channels, 1) + self.assertEqual(record.sampling_rate, 16000) + self.assertEqual(record.buffer_size, 16000) + def test_embed_with_yamnet_model_and_different_inputs(self): with _AudioEmbedder.create_from_model_path( self.yamnet_model_path) as embedder: diff --git a/mediapipe/tasks/python/test/audio/core/BUILD b/mediapipe/tasks/python/test/audio/core/BUILD new file mode 100644 index 000000000..2f9c66be3 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/core/BUILD @@ -0,0 +1,25 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. 
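Putting the new pieces together, a usage sketch for `create_audio_record` in the style of the tests above; the model path is a placeholder for the bundled YAMNet model, and `sounddevice`/PortAudio must be available at runtime:

```python
import time

from mediapipe.tasks.python.audio import audio_classifier

# Placeholder model path; the tests use the bundled YAMNet model.
classifier = audio_classifier.AudioClassifier.create_from_model_path(
    'yamnet.tflite')

# Mono audio at 16 kHz, buffering 15600 samples (~0.975 s), the window size
# the tests use for YAMNet.
record = classifier.create_audio_record(
    num_channels=1, sample_rate=16000, required_input_buffer_size=15600)

record.start_recording()
time.sleep(1.0)               # let the ring buffer fill
samples = record.read(15600)  # (15600, 1) float array of the latest audio
record.stop()
classifier.close()
```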
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_test( + name = "audio_record_test", + srcs = ["audio_record_test.py"], + deps = ["//mediapipe/tasks/python/audio/core:audio_record"], +) diff --git a/mediapipe/tasks/python/test/audio/core/audio_record_test.py b/mediapipe/tasks/python/test/audio/core/audio_record_test.py new file mode 100644 index 000000000..ac804b894 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/core/audio_record_test.py @@ -0,0 +1,104 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for audio_record.""" + +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +from mediapipe.tasks.python.audio.core import audio_record + + +_mock = unittest.mock + +_CHANNELS = 2 +_SAMPLING_RATE = 16000 +_BUFFER_SIZE = 15600 + + +class AudioRecordTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + # Mock sounddevice.InputStream + with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method: + self.mock_input_stream = _mock.MagicMock() + mock_input_stream_new_method.return_value = self.mock_input_stream + self.record = audio_record.AudioRecord( + _CHANNELS, _SAMPLING_RATE, _BUFFER_SIZE + ) + + # Save the initialization arguments of InputStream for later assertion. + _, self.init_args = mock_input_stream_new_method.call_args + + def test_init_args(self): + # Assert parameters of InputStream initialization + self.assertEqual( + self.init_args["channels"], + _CHANNELS, + "InputStream's channels doesn't match the initialization argument.", + ) + self.assertEqual( + self.init_args["samplerate"], + _SAMPLING_RATE, + "InputStream's samplerate doesn't match the initialization argument.", + ) + + def test_life_cycle(self): + # Assert start recording routine. + self.record.start_recording() + self.mock_input_stream.start.assert_called_once() + + # Assert stop recording routine. + self.record.stop() + self.mock_input_stream.stop.assert_called_once() + + def test_read_succeeds_with_valid_sample_size(self): + callback_fn = self.init_args["callback"] + + # Create dummy data to feed to the AudioRecord instance. + chunk_size = int(_BUFFER_SIZE * 0.5) + input_data = [] + for _ in range(3): + dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float) + input_data.append(dummy_data) + callback_fn(dummy_data) + + # Assert read data of a single chunk. + recorded_audio_data = self.record.read(chunk_size) + self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1])) + + # Assert read all data in buffer. + recorded_audio_data = self.record.read(chunk_size * 2) + print(input_data[-2].shape) + expected_data = np.concatenate(input_data[-2:]) + self.assertTrue(np.array_equal(recorded_audio_data, expected_data)) + + def test_read_fails_with_invalid_sample_size(self): + callback_fn = self.init_args["callback"] + + # Create dummy data to feed to the AudioRecord instance. 
+ dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float) + callback_fn(dummy_data) + + # Assert exception if read too much data. + with self.assertRaises(ValueError): + self.record.read(_BUFFER_SIZE + 1) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py index d892f1b61..c91bcce6e 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -388,7 +388,7 @@ class MetadataPopulatorTest(MetadataTest): populator = _metadata.MetadataPopulator.with_model_file(self._model_file) populator.load_metadata_file(self._metadata_file) populator.load_associated_files([self._file1]) - # Suppose to populate self._file2, because it is recorded in the metadta. + # Suppose to populate self._file2, because it is recorded in the metadata. with self.assertRaises(ValueError) as error: populator.populate() self.assertEqual(("File, '{0}', is recorded in the metadata, but has " diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 6ea873274..89a988be9 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -197,3 +197,22 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "face_stylizer", + srcs = [ + "face_stylizer.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index 0c8262d4b..768d392f1 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -144,7 +144,7 @@ class BaseVisionTaskApi(object): set. By default, it's set to True. Returns: - A normalized rect proto that repesents the image processing options. + A normalized rect proto that represents the image processing options. """ normalized_rect = _NormalizedRect( rotation=0, x_center=0.5, y_center=0.5, width=1, height=1) diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py new file mode 100644 index 000000000..0bbd9c4d1 --- /dev/null +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -0,0 +1,279 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe face stylizer task.""" + +import dataclasses +from typing import Callable, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2 +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_FaceStylizerGraphOptionsProto = ( + face_stylizer_graph_options_pb2.FaceStylizerGraphOptions +) +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_STYLIZED_IMAGE_NAME = 'stylized_image' +_STYLIZED_IMAGE_TAG = 'STYLIZED_IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class FaceStylizerOptions: + """Options for the face stylizer task. + + Attributes: + base_options: Base options for the face stylizer task. + running_mode: The running mode of the task. Default to the image mode. Face + stylizer task has three running modes: 1) The image mode for stylizing one + face on a single image input. 2) The video mode for stylizing one face per + frame on the decoded frames of a video. 3) The live stream mode for + stylizing one face on a live stream of input data, such as from camera. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. + """ + + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + result_callback: Optional[ + Callable[[image_module.Image, image_module.Image, int], None] + ] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceStylizerGraphOptionsProto: + """Generates an FaceStylizerOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) + return _FaceStylizerGraphOptionsProto(base_options=base_options_proto) + + +class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face stylization on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceStylizer': + """Creates an `FaceStylizer` object from a TensorFlow Lite model and the default `FaceStylizerOptions`. + + Note that the created `FaceStylizer` instance is in image mode, for + stylizing one face on a single image input. + + Args: + model_path: Path to the model. + + Returns: + `FaceStylizer` object that's created from the model file and the default + `FaceStylizerOptions`. 
+ + Raises: + ValueError: If failed to create `FaceStylizer` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceStylizerOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE + ) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, options: FaceStylizerOptions) -> 'FaceStylizer': + """Creates the `FaceStylizer` object from face stylizer options. + + Args: + options: Options for the face stylizer task. + + Returns: + `FaceStylizer` object that's created from `options`. + + Raises: + ValueError: If failed to create `FaceStylizer` object from + `FaceStylizerOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME] + stylized_image = packet_getter.get_image(stylized_image_packet) + + options.result_callback( + stylized_image, + image, + stylized_image_packet.timestamp.value + // _MICRO_SECONDS_PER_MILLISECOND, + ) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=[ + ':'.join([_STYLIZED_IMAGE_TAG, _STYLIZED_IMAGE_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ], + task_options=options, + ) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) + + def stylize( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> image_module.Image: + """Performs face stylization on the provided MediaPipe Image. + + Only use this method when the FaceStylizer is created with the image + running mode. + + To ensure that the output image has reasonable quality, the stylized output + image size is the smaller of the model output size and the size of the + `region_of_interest` specified in `image_processing_options`. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + The stylized image. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face stylization failed to run. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) + return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME]) + + def stylize_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> image_module.Image: + """Performs face stylization on the provided video frames. + + Only use this method when the FaceStylizer is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. 
+ + To ensure that the output image has reasonable quality, the stylized output + image size is the smaller of the model output size and the size of the + `region_of_interest` specified in `image_processing_options`. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. + + Returns: + The stylized image. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face stylization failed to run. + """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options) + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) + return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME]) + + def stylize_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> None: + """Sends live image data (an Image with a unique timestamp) to perform face stylization. + + Only use this method when the FaceStylizer is created with the live stream + running mode. The input timestamps should be monotonically increasing for + adjacent calls of this method. This method will return immediately after the + input image is accepted. The results will be available via the + `result_callback` provided in the `FaceStylizerOptions`. The + `stylize_async` method is designed to process live stream data such as + camera input. To lower the overall latency, face stylizer may drop the input + images if needed. In other words, it's not guaranteed to have output per + input image. + + To ensure that the stylized image has reasonable quality, the stylized + output image size is the smaller of the model output size and the size of + the `region_of_interest` specified in `image_processing_options`. + + The `result_callback` provides: + - The stylized image. + - The input image that the face stylizer runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. + + Raises: + ValueError: If the current input timestamp is smaller than what the face + stylizer has already processed. 
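A minimal sketch of driving the new Python `FaceStylizer` in image mode, based on the API above; the `.task` bundle path and the input file name are placeholders rather than files shipped with this change:

```python
import mediapipe as mp

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_stylizer

# Placeholder paths; any face stylizer .task bundle and portrait image will do.
options = face_stylizer.FaceStylizerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='face_stylizer.task'))

with face_stylizer.FaceStylizer.create_from_options(options) as stylizer:
  image = mp.Image.create_from_file('portrait.jpg')
  stylized = stylizer.stylize(image)  # returns the stylized mediapipe Image
```

`running_mode` is left at its `IMAGE` default here; the video and live-stream variants follow the same pattern via `stylize_for_video` and `stylize_async`.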
+ """ + normalized_rect = self.convert_to_normalized_rect(image_processing_options) + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ed0ee6ea5..0de0c255c 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -78,6 +78,7 @@ mediapipe_files(srcs = [ "pose.jpg", "pose_detection.tflite", "pose_landmark_lite.tflite", + "pose_landmarker.task", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -187,6 +188,7 @@ filegroup( "palm_detection_full.tflite", "pose_detection.tflite", "pose_landmark_lite.tflite", + "pose_landmarker.task", "selfie_segm_128_128_3.tflite", "selfie_segm_144_256_3.tflite", "selfie_segmentation.tflite", diff --git a/mediapipe/tasks/testdata/vision/pose_landmarker.task b/mediapipe/tasks/testdata/vision/pose_landmarker.task new file mode 100644 index 000000000..d57dd9e0d Binary files /dev/null and b/mediapipe/tasks/testdata/vision/pose_landmarker.task differ diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 927146c04..714b4613b 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -4,6 +4,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") package(default_visibility = ["//mediapipe/tasks:internal"]) +mediapipe_ts_declaration( + name = "bounding_box", + srcs = ["bounding_box.d.ts"], +) + mediapipe_ts_declaration( name = "category", srcs = ["category.d.ts"], @@ -15,6 +20,16 @@ mediapipe_ts_declaration( deps = [":category"], ) +mediapipe_ts_declaration( + name = "detection_result", + srcs = ["detection_result.d.ts"], + deps = [ + ":bounding_box", + ":category", + ":keypoint", + ], +) + mediapipe_ts_declaration( name = "keypoint", srcs = ["keypoint.d.ts"], diff --git a/mediapipe/tasks/web/components/containers/bounding_box.d.ts b/mediapipe/tasks/web/components/containers/bounding_box.d.ts new file mode 100644 index 000000000..69174f4c7 --- /dev/null +++ b/mediapipe/tasks/web/components/containers/bounding_box.d.ts @@ -0,0 +1,27 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** An integer bounding box, axis aligned. */ +export declare interface BoundingBox { + /** The X coordinate of the top-left corner, in pixels. */ + originX: number; + /** The Y coordinate of the top-left corner, in pixels. */ + originY: number; + /** The width of the bounding box, in pixels. */ + width: number; + /** The height of the bounding box, in pixels. 
*/ + height: number; +} diff --git a/mediapipe/tasks/web/components/containers/detection_result.d.ts b/mediapipe/tasks/web/components/containers/detection_result.d.ts new file mode 100644 index 000000000..37817307c --- /dev/null +++ b/mediapipe/tasks/web/components/containers/detection_result.d.ts @@ -0,0 +1,43 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; +import {Category} from '../../../../tasks/web/components/containers/category'; +import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint'; + +/** Represents one detection by a detection task. */ +export declare interface Detection { + /** A list of `Category` objects. */ + categories: Category[]; + + /** The bounding box of the detected objects. */ + boundingBox?: BoundingBox; + + /** + * Optional list of keypoints associated with the detection. Keypoints + * represent interesting points related to the detection. For example, the + * keypoints represent the eye, ear and mouth from face detection model. Or + * in the template matching detection, e.g. KNIFT, they can represent the + * feature points for template matching. + */ + keypoints?: NormalizedKeypoint[]; +} + +/** Detection results of a model. */ +export interface DetectionResult { + /** A list of Detections. */ + detections: Detection[]; +} diff --git a/mediapipe/tasks/web/components/containers/matrix.d.ts b/mediapipe/tasks/web/components/containers/matrix.d.ts index fd4bda4c3..e0bad58c8 100644 --- a/mediapipe/tasks/web/components/containers/matrix.d.ts +++ b/mediapipe/tasks/web/components/containers/matrix.d.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -/** A two-dimenionsal matrix. */ +/** A two-dimensional matrix. */ export declare interface Matrix { /** The number of rows. 
*/ rows: number; diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index b83f73eb2..a5f93a147 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -56,6 +56,27 @@ jasmine_node_test( deps = [":classifier_result_test_lib"], ) +mediapipe_ts_library( + name = "detection_result", + srcs = ["detection_result.ts"], + deps = [ + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/web/components/containers:detection_result", + ], +) + +mediapipe_ts_library( + name = "detection_result_test_lib", + testonly = True, + srcs = ["detection_result.test.ts"], + deps = [ + ":detection_result", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/components/containers:detection_result", + ], +) + mediapipe_ts_library( name = "embedder_result", srcs = ["embedder_result.ts"], diff --git a/mediapipe/tasks/web/components/processors/classifier_result.test.ts b/mediapipe/tasks/web/components/processors/classifier_result.test.ts index 4b93d0a76..2e8f9956c 100644 --- a/mediapipe/tasks/web/components/processors/classifier_result.test.ts +++ b/mediapipe/tasks/web/components/processors/classifier_result.test.ts @@ -32,12 +32,12 @@ describe('convertFromClassificationResultProto()', () => { classifcations.setHeadIndex(1); classifcations.setHeadName('headName'); const classificationList = new ClassificationList(); - const clasification = new Classification(); - clasification.setIndex(2); - clasification.setScore(0.3); - clasification.setDisplayName('displayName'); - clasification.setLabel('categoryName'); - classificationList.addClassification(clasification); + const classification = new Classification(); + classification.setIndex(2); + classification.setScore(0.3); + classification.setDisplayName('displayName'); + classification.setLabel('categoryName'); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); @@ -62,8 +62,8 @@ describe('convertFromClassificationResultProto()', () => { const classificationResult = new ClassificationResult(); const classifcations = new Classifications(); const classificationList = new ClassificationList(); - const clasification = new Classification(); - classificationList.addClassification(clasification); + const classification = new Classification(); + classificationList.addClassification(classification); classifcations.setClassificationList(classificationList); classificationResult.addClassifications(classifcations); diff --git a/mediapipe/tasks/web/components/processors/detection_result.test.ts b/mediapipe/tasks/web/components/processors/detection_result.test.ts new file mode 100644 index 000000000..289043506 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/detection_result.test.ts @@ -0,0 +1,91 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; + +import {convertFromDetectionProto} from './detection_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertFromDetectionProto()', () => { + it('transforms custom values', () => { + const detection = new DetectionProto(); + detection.addScore(0.1); + detection.addLabelId(1); + detection.addLabel('foo'); + detection.addDisplayName('bar'); + + const locationData = new LocationData(); + const boundingBox = new LocationData.BoundingBox(); + boundingBox.setXmin(1); + boundingBox.setYmin(2); + boundingBox.setWidth(3); + boundingBox.setHeight(4); + locationData.setBoundingBox(boundingBox); + + const keypoint = new LocationData.RelativeKeypoint(); + keypoint.setX(5); + keypoint.setY(6); + keypoint.setScore(0.7); + keypoint.setKeypointLabel('bar'); + locationData.addRelativeKeypoints(new LocationData.RelativeKeypoint()); + + detection.setLocationData(locationData); + + const result = convertFromDetectionProto(detection); + + expect(result).toEqual({ + categories: [{ + score: 0.1, + index: 1, + categoryName: 'foo', + displayName: 'bar', + }], + boundingBox: {originX: 1, originY: 2, width: 3, height: 4}, + keypoints: [{ + x: 5, + y: 6, + score: 0.7, + label: 'bar', + }], + }); + }); + + it('transforms default values', () => { + const detection = new DetectionProto(); + detection.addScore(0.2); + const locationData = new LocationData(); + const boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + + const result = convertFromDetectionProto(detection); + + expect(result).toEqual({ + categories: [{ + score: 0.2, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/detection_result.ts b/mediapipe/tasks/web/components/processors/detection_result.ts new file mode 100644 index 000000000..6b38820bf --- /dev/null +++ b/mediapipe/tasks/web/components/processors/detection_result.ts @@ -0,0 +1,63 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {Detection} from '../../../../tasks/web/components/containers/detection_result'; + +const DEFAULT_CATEGORY_INDEX = -1; + +/** Converts a Detection proto into a Detection object. 
*/ +export function convertFromDetectionProto(source: DetectionProto): Detection { + const scores = source.getScoreList(); + const indexes = source.getLabelIdList(); + const labels = source.getLabelList(); + const displayNames = source.getDisplayNameList(); + + const detection: Detection = {categories: []}; + for (let i = 0; i < scores.length; i++) { + detection.categories.push({ + score: scores[i], + index: indexes[i] ?? DEFAULT_CATEGORY_INDEX, + categoryName: labels[i] ?? '', + displayName: displayNames[i] ?? '', + }); + } + + const boundingBox = source.getLocationData()?.getBoundingBox(); + if (boundingBox) { + detection.boundingBox = { + originX: boundingBox.getXmin() ?? 0, + originY: boundingBox.getYmin() ?? 0, + width: boundingBox.getWidth() ?? 0, + height: boundingBox.getHeight() ?? 0 + }; + } + + if (source.getLocationData()?.getRelativeKeypointsList().length) { + detection.keypoints = []; + for (const keypoint of + source.getLocationData()!.getRelativeKeypointsList()) { + detection.keypoints.push({ + x: keypoint.getX() ?? 0.0, + y: keypoint.getY() ?? 0.0, + score: keypoint.getScore() ?? 0.0, + label: keypoint.getKeypointLabel() ?? '', + }); + } + } + + return detection; +} diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 1f28cb0fe..19c795fd9 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -19,6 +19,7 @@ mediapipe_files(srcs = [ VISION_LIBS = [ "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/face_detector", "//mediapipe/tasks/web/vision/face_landmarker", "//mediapipe/tasks/web/vision/face_stylizer", "//mediapipe/tasks/web/vision/gesture_recognizer", diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index ebeac54c5..d5109142b 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -2,6 +2,22 @@ This package contains the vision tasks for MediaPipe. +## Face Detection + +The MediaPipe Face Detector task lets you detect the presence and location of +faces within images or videos. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const faceDetector = await FaceDetector.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +const detections = faceDetector.detect(image); +``` + ## Face Landmark Detection The MediaPipe Face Landmarker task lets you detect the landmarks of faces in diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index c04366ac0..b48b5045d 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -19,7 +19,7 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke /** * The segmentation tasks return the segmentation either as a WebGLTexture (when * the output is on GPU) or as a typed JavaScript arrays for CPU-based - * category or confidence masks. `Uint8ClampedArray`s are used to represend + * category or confidence masks. `Uint8ClampedArray`s are used to represent * CPU-based category masks and `Float32Array`s are used for CPU-based * confidence masks. 
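 * For example, a hypothetical 2x2 category mask might arrive as a
 * `Uint8ClampedArray` holding four class indices (one per pixel), while the
 * corresponding confidence mask would be a `Float32Array` of four per-pixel
 * scores.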
*/ diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 72bc2efb1..a45efd6d3 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -27,7 +27,7 @@ export type RunningMode = 'IMAGE'|'VIDEO'; export declare interface VisionTaskOptions extends TaskRunnerOptions { /** * The canvas element to bind textures to. This has to be set for GPU - * processing. The task will initialize a WebGL context and throw an eror if + * processing. The task will initialize a WebGL context and throw an error if * this fails (e.g. if you have already initialized a different type of * context). */ diff --git a/mediapipe/tasks/web/vision/face_detector/BUILD b/mediapipe/tasks/web/vision/face_detector/BUILD new file mode 100644 index 000000000..8225e4948 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/BUILD @@ -0,0 +1,71 @@ +# This contains the MediaPipe Face Detector Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more faces, using Face Detector. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "face_detector", + srcs = ["face_detector.ts"], + visibility = ["//visibility:public"], + deps = [ + ":face_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/processors:detection_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "face_detector_types", + srcs = [ + "face_detector_options.d.ts", + "face_detector_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:bounding_box", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:detection_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "face_detector_test_lib", + testonly = True, + srcs = [ + "face_detector_test.ts", + ], + deps = [ + ":face_detector", + ":face_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "face_detector_test", + tags = ["nomsan"], + deps = [":face_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector.ts b/mediapipe/tasks/web/vision/face_detector/face_detector.ts new file mode 100644 index 000000000..039f7dd44 --- /dev/null +++ 
b/mediapipe/tasks/web/vision/face_detector/face_detector.ts @@ -0,0 +1,213 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {FaceDetectorGraphOptions as FaceDetectorGraphOptionsProto} from '../../../../tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb'; +import {convertFromDetectionProto} from '../../../../tasks/web/components/processors/detection_result'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {FaceDetectorOptions} from './face_detector_options'; +import {FaceDetectorResult} from './face_detector_result'; + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect_in'; +const DETECTIONS_STREAM = 'detections'; +const FACE_DETECTOR_GRAPH = + 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'; + +export * from './face_detector_options'; +export * from './face_detector_result'; +export {ImageSource}; // Used in the public API + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs face detection on images. */ +export class FaceDetector extends VisionTaskRunner { + private result: FaceDetectorResult = {detections: []}; + private readonly options = new FaceDetectorGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new face detector from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param faceDetectorOptions The options for the FaceDetector. Note that + * either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + faceDetectorOptions: FaceDetectorOptions): Promise { + return VisionTaskRunner.createVisionInstance( + FaceDetector, wasmFileset, faceDetectorOptions); + } + + /** + * Initializes the Wasm runtime and creates a new face detector based on the + * provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. 
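+   *
+   * A sketch of typical usage (`MODEL_URL` is a placeholder, not a published
+   * asset):
+   *
+   *   const response = await fetch(MODEL_URL);
+   *   const buffer = new Uint8Array(await response.arrayBuffer());
+   *   const detector =
+   *       await FaceDetector.createFromModelBuffer(wasmFileset, buffer);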
+   */
+  static createFromModelBuffer(
+      wasmFileset: WasmFileset,
+      modelAssetBuffer: Uint8Array): Promise<FaceDetector> {
+    return VisionTaskRunner.createVisionInstance(
+        FaceDetector, wasmFileset, {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new face detector based on the
+   * path to the model asset.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param modelAssetPath The path to the model asset.
+   */
+  static async createFromModelPath(
+      wasmFileset: WasmFileset,
+      modelAssetPath: string): Promise<FaceDetector> {
+    return VisionTaskRunner.createVisionInstance(
+        FaceDetector, wasmFileset, {baseOptions: {modelAssetPath}});
+  }
+
+  /** @hideconstructor */
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM, /* roiAllowed= */ false);
+    this.options.setBaseOptions(new BaseOptionsProto());
+    this.options.setMinDetectionConfidence(0.5);
+    this.options.setMinSuppressionThreshold(0.3);
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the FaceDetector.
+   *
+   * Calling `setOptions()` with a subset of options only affects those
+   * options. You can reset an option back to its default value by explicitly
+   * setting it to `undefined`.
+   *
+   * @param options The options for the FaceDetector.
+   */
+  override setOptions(options: FaceDetectorOptions): Promise<void> {
+    if ('minDetectionConfidence' in options) {
+      this.options.setMinDetectionConfidence(
+          options.minDetectionConfidence ?? 0.5);
+    }
+    if ('minSuppressionThreshold' in options) {
+      this.options.setMinSuppressionThreshold(
+          options.minSuppressionThreshold ?? 0.3);
+    }
+    return this.applyOptions(options);
+  }
+
+  /**
+   * Performs face detection on the provided single image and waits
+   * synchronously for the response. Only use this method when the
+   * FaceDetector is created with running mode `image`.
+   *
+   * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
+   * @return A result containing the list of detected faces.
+   */
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      FaceDetectorResult {
+    this.result = {detections: []};
+    this.processImageData(image, imageProcessingOptions);
+    return this.result;
+  }
+
+  /**
+   * Performs face detection on the provided video frame and waits
+   * synchronously for the response. Only use this method when the
+   * FaceDetector is created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *     to process the input image before running inference.
+   * @return A result containing the list of detected faces.
+   */
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): FaceDetectorResult {
+    this.result = {detections: []};
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.result;
+  }
+
+  /** Converts raw data into a Detection, and adds it to our detection list.
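+   * Each entry in `data` is a serialized `Detection` proto received from the
+   * graph; it is decoded with `DetectionProto.deserializeBinary()` and mapped
+   * to the shared result format via `convertFromDetectionProto()`.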
*/ + private addJsFaceDetections(data: Uint8Array[]): void { + for (const binaryProto of data) { + const detectionProto = DetectionProto.deserializeBinary(binaryProto); + this.result.detections.push(convertFromDetectionProto(detectionProto)); + } + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(DETECTIONS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + FaceDetectorGraphOptionsProto.ext, this.options); + + const detectorNode = new CalculatorGraphConfig.Node(); + detectorNode.setCalculator(FACE_DETECTOR_GRAPH); + detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM); + detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM); + detectorNode.setOptions(calculatorOptions); + + graphConfig.addNode(detectorNode); + + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, (binaryProto, timestamp) => { + this.addJsFaceDetections(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener(DETECTIONS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts new file mode 100644 index 000000000..665035f7e --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts @@ -0,0 +1,33 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Face Detector Task */ +export interface FaceDetectorOptions extends VisionTaskOptions { + /** + * The minimum confidence score for the face detection to be considered + * successful. Defaults to 0.5. + */ + minDetectionConfidence?: number|undefined; + + /** + * The minimum non-maximum-suppression threshold for face detection to be + * considered overlapped. Defaults to 0.3. + */ + minSuppressionThreshold?: number|undefined; +} diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts new file mode 100644 index 000000000..6a36559f7 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts @@ -0,0 +1,19 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; +export {Category} from '../../../../tasks/web/components/containers/category'; +export {Detection, DetectionResult as FaceDetectorResult} from '../../../../tasks/web/components/containers/detection_result'; diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts new file mode 100644 index 000000000..88dd20d2b --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts @@ -0,0 +1,193 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {FaceDetector} from './face_detector'; +import {FaceDetectorOptions} from './face_detector_options'; + +// The OSS JS API does not support the builder pattern. 
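+// (Protos in these tests are therefore populated with plain setters, e.g.
+// `detection.addScore(0.1)`, rather than a fluent builder chain.)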
+// tslint:disable:jspb-use-builder-pattern + +class FaceDetectorFake extends FaceDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('FaceDetector', () => { + let faceDetector: FaceDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + faceDetector = new FaceDetectorFake(); + await faceDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(faceDetector); + verifyListenersRegistered(faceDetector); + }); + + it('reloads graph when settings are changed', async () => { + await faceDetector.setOptions({minDetectionConfidence: 0.1}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]); + verifyListenersRegistered(faceDetector); + + await faceDetector.setOptions({minDetectionConfidence: 0.2}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.2]); + verifyListenersRegistered(faceDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await faceDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + faceDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await faceDetector.setOptions({minDetectionConfidence: 0.1}); + await faceDetector.setOptions({minSuppressionThreshold: 0.2}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]); + verifyGraph(faceDetector, ['minSuppressionThreshold', 0.2]); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof FaceDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'minDetectionConfidence', + protoName: 'minDetectionConfidence', + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionName: 'minSuppressionThreshold', + protoName: 'minSuppressionThreshold', + customValue: 0.2, + defaultValue: 0.3 + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await faceDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(faceDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await faceDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(faceDetector, 
[testCase.protoName, testCase.customValue]); + await faceDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph(faceDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + faceDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + const detection = new DetectionProto(); + detection.addScore(0.1); + const locationData = new LocationData(); + const boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + + const binaryProto = detection.serializeBinary(); + + // Pass the test data to our listener + faceDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(faceDetector); + faceDetector.protoListener!([binaryProto], 1337); + }); + + // Invoke the face detector + const {detections} = faceDetector.detect({} as HTMLImageElement); + + expect(faceDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(1); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 856d84683..4882e22c4 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -15,6 +15,7 @@ */ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceLandmarker as FaceLandmarkerImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; @@ -28,6 +29,7 @@ import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/ob // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. 
const FilesetResolver = FilesetResolverImpl; +const FaceDetector = FaceDetectorImpl; const FaceLandmarker = FaceLandmarkerImpl; const FaceStylizer = FaceStylizerImpl; const GestureRecognizer = GestureRecognizerImpl; @@ -40,6 +42,7 @@ const ObjectDetector = ObjectDetectorImpl; export { FilesetResolver, + FaceDetector, FaceLandmarker, FaceStylizer, GestureRecognizer, diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 76fa589c8..2c0fcbdf8 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/processors:detection_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", @@ -37,7 +38,9 @@ mediapipe_ts_declaration( ], visibility = ["//visibility:public"], deps = [ + "//mediapipe/tasks/web/components/containers:bounding_box", "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:detection_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_options", diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 42b62c9e2..5d406f1a0 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; +import {convertFromDetectionProto} from '../../../../tasks/web/components/processors/detection_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -26,15 +27,13 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; -import {Detection} from './object_detector_result'; +import {ObjectDetectorResult} from './object_detector_result'; const IMAGE_STREAM = 'input_frame_gpu'; const NORM_RECT_STREAM = 'norm_rect'; const DETECTIONS_STREAM = 'detections'; const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; -const DEFAULT_CATEGORY_INDEX = -1; - export * from './object_detector_options'; export * from './object_detector_result'; export {ImageSource}; // Used in the public API @@ -44,7 +43,7 @@ export {ImageSource}; // Used in the public API /** Performs object detection on images. 
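 *
 * A usage sketch, assuming the task's standard `createFromModelPath()`
 * factory and a placeholder MODEL_URL:
 *
 *   const detector =
 *       await ObjectDetector.createFromModelPath(vision, MODEL_URL);
 *   const {detections} = detector.detect(image);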
*/ export class ObjectDetector extends VisionTaskRunner { - private detections: Detection[] = []; + private result: ObjectDetectorResult = {detections: []}; private readonly options = new ObjectDetectorOptionsProto(); /** @@ -163,13 +162,13 @@ export class ObjectDetector extends VisionTaskRunner { * @param image An image to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. - * @return The list of detected objects + * @return A result containing a list of detected objects. */ detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): - Detection[] { - this.detections = []; + ObjectDetectorResult { + this.result = {detections: []}; this.processImageData(image, imageProcessingOptions); - return [...this.detections]; + return this.result; } /** @@ -181,46 +180,21 @@ export class ObjectDetector extends VisionTaskRunner { * @param timestamp The timestamp of the current frame, in ms. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. - * @return The list of detected objects + * @return A result containing a list of detected objects. */ detectForVideo( videoFrame: ImageSource, timestamp: number, - imageProcessingOptions?: ImageProcessingOptions): Detection[] { - this.detections = []; + imageProcessingOptions?: ImageProcessingOptions): ObjectDetectorResult { + this.result = {detections: []}; this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - return [...this.detections]; + return this.result; } /** Converts raw data into a Detection, and adds it to our detection list. */ private addJsObjectDetections(data: Uint8Array[]): void { for (const binaryProto of data) { const detectionProto = DetectionProto.deserializeBinary(binaryProto); - const scores = detectionProto.getScoreList(); - const indexes = detectionProto.getLabelIdList(); - const labels = detectionProto.getLabelList(); - const displayNames = detectionProto.getDisplayNameList(); - - const detection: Detection = {categories: []}; - for (let i = 0; i < scores.length; i++) { - detection.categories.push({ - score: scores[i], - index: indexes[i] ?? DEFAULT_CATEGORY_INDEX, - categoryName: labels[i] ?? '', - displayName: displayNames[i] ?? '', - }); - } - - const boundingBox = detectionProto.getLocationData()?.getBoundingBox(); - if (boundingBox) { - detection.boundingBox = { - originX: boundingBox.getXmin() ?? 0, - originY: boundingBox.getYmin() ?? 0, - width: boundingBox.getWidth() ?? 0, - height: boundingBox.getHeight() ?? 0 - }; - } - - this.detections.push(detection); + this.result.detections.push(convertFromDetectionProto(detectionProto)); } } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts index c9c87a1bf..54bc6f907 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts @@ -14,27 +14,6 @@ * limitations under the License. */ -import {Category} from '../../../../tasks/web/components/containers/category'; - -export {Category}; - -/** An integer bounding box, axis aligned. */ -export declare interface BoundingBox { - /** The X coordinate of the top-left corner, in pixels. */ - originX: number; - /** The Y coordinate of the top-left corner, in pixels. */ - originY: number; - /** The width of the bounding box, in pixels. 
*/ - width: number; - /** The height of the bounding box, in pixels. */ - height: number; -} - -/** Represents one object detected by the `ObjectDetector`. */ -export declare interface Detection { - /** A list of `Category` objects. */ - categories: Category[]; - - /** The bounding box of the detected objects. */ - boundingBox?: BoundingBox; -} +export {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; +export {Category} from '../../../../tasks/web/components/containers/category'; +export {Detection, DetectionResult as ObjectDetectorResult} from '../../../../tasks/web/components/containers/detection_result'; diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 9dd64c0b6..18e4a2062 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -179,56 +179,29 @@ describe('ObjectDetector', () => { }); it('transforms results', async () => { - const detectionProtos: Uint8Array[] = []; - - // Add a detection with all optional properties - let detection = new DetectionProto(); + const detection = new DetectionProto(); detection.addScore(0.1); - detection.addLabelId(1); - detection.addLabel('foo'); - detection.addDisplayName('bar'); - let locationData = new LocationData(); - let boundingBox = new LocationData.BoundingBox(); - boundingBox.setXmin(1); - boundingBox.setYmin(2); - boundingBox.setWidth(3); - boundingBox.setHeight(4); + const locationData = new LocationData(); + const boundingBox = new LocationData.BoundingBox(); locationData.setBoundingBox(boundingBox); detection.setLocationData(locationData); - detectionProtos.push(detection.serializeBinary()); - // Add a detection without optional properties - detection = new DetectionProto(); - detection.addScore(0.2); - locationData = new LocationData(); - boundingBox = new LocationData.BoundingBox(); - locationData.setBoundingBox(boundingBox); - detection.setLocationData(locationData); - detectionProtos.push(detection.serializeBinary()); + const binaryProto = detection.serializeBinary(); // Pass the test data to our listener objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(objectDetector); - objectDetector.protoListener!(detectionProtos, 1337); + objectDetector.protoListener!([binaryProto], 1337); }); // Invoke the object detector - const detections = objectDetector.detect({} as HTMLImageElement); + const {detections} = objectDetector.detect({} as HTMLImageElement); expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(detections.length).toEqual(2); + expect(detections.length).toEqual(1); expect(detections[0]).toEqual({ categories: [{ score: 0.1, - index: 1, - categoryName: 'foo', - displayName: 'bar', - }], - boundingBox: {originX: 1, originY: 2, width: 3, height: 4} - }); - expect(detections[1]).toEqual({ - categories: [{ - score: 0.2, index: -1, categoryName: '', displayName: '', diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index 2756b05a5..f49161adf 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -15,6 +15,7 @@ */ export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from 
'../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; diff --git a/mediapipe/util/pose_util.cc b/mediapipe/util/pose_util.cc index dd907fcdd..e5d1f2c9f 100644 --- a/mediapipe/util/pose_util.cc +++ b/mediapipe/util/pose_util.cc @@ -89,19 +89,6 @@ void DrawPose(const mediapipe::NormalizedLandmarkList& pose, int target_width, constexpr int draw_line_width = 5; constexpr int draw_circle_radius = 7; - const int lm = static_cast(PoseLandmarkName::kMouthLeft); - const int rm = static_cast(PoseLandmarkName::kMouthRight); - const int ls = static_cast(PoseLandmarkName::kLeftShoulder); - const int rs = static_cast(PoseLandmarkName::kRightShoulder); - if (visible_landmarks.find(lm) != visible_landmarks.end() && - visible_landmarks.find(rm) != visible_landmarks.end() && - visible_landmarks.find(ls) != visible_landmarks.end() && - visible_landmarks.find(rs) != visible_landmarks.end()) { - cv::line(*image, (visible_landmarks[lm] + visible_landmarks[rm]) * 0.5f, - (visible_landmarks[ls] + visible_landmarks[rs]) * 0.5f, - cv::Scalar(255, 255, 255), draw_line_width); - } - for (int j = 0; j < 35; ++j) { if (visible_landmarks.find(kJointConnection[j][0]) != visible_landmarks.end() && @@ -115,6 +102,19 @@ void DrawPose(const mediapipe::NormalizedLandmarkList& pose, int target_width, } } + const int lm = static_cast(PoseLandmarkName::kMouthLeft); + const int rm = static_cast(PoseLandmarkName::kMouthRight); + const int ls = static_cast(PoseLandmarkName::kLeftShoulder); + const int rs = static_cast(PoseLandmarkName::kRightShoulder); + if (visible_landmarks.find(lm) != visible_landmarks.end() && + visible_landmarks.find(rm) != visible_landmarks.end() && + visible_landmarks.find(ls) != visible_landmarks.end() && + visible_landmarks.find(rs) != visible_landmarks.end()) { + cv::line(*image, (visible_landmarks[lm] + visible_landmarks[rm]) * 0.5f, + (visible_landmarks[ls] + visible_landmarks[rs]) * 0.5f, + cv::Scalar(255, 255, 255), draw_line_width); + } + for (const auto& landmark : visible_landmarks) { cv::circle(*image, landmark.second, draw_circle_radius, cv::Scalar(kJointColorMap[landmark.first][0], diff --git a/mediapipe/util/sequence/media_sequence.cc b/mediapipe/util/sequence/media_sequence.cc index f76c53295..287db6181 100644 --- a/mediapipe/util/sequence/media_sequence.cc +++ b/mediapipe/util/sequence/media_sequence.cc @@ -57,13 +57,13 @@ bool ImageMetadata(const std::string& image_str, std::string* format_string, // Finds the nearest timestamp in a FeatureList of timestamps. The FeatureList // must contain int64 values and only the first value at each step is used. -int NearestIndex(int64 timestamp, +int NearestIndex(int64_t timestamp, const tensorflow::FeatureList& int64_feature_list) { - int64 closest_distance = std::numeric_limits::max(); + int64_t closest_distance = std::numeric_limits::max(); int index = -1; for (int i = 0; i < int64_feature_list.feature_size(); ++i) { - int64 current_value = int64_feature_list.feature(i).int64_list().value(0); - int64 current_distance = std::abs(current_value - timestamp); + int64_t current_value = int64_feature_list.feature(i).int64_list().value(0); + int64_t current_distance = std::abs(current_value - timestamp); if (current_distance < closest_distance) { index = i; closest_distance = current_distance; @@ -74,8 +74,8 @@ int NearestIndex(int64 timestamp, // Find the numerical sampling rate between two values in seconds if the input // timestamps are in microseconds. 
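// For example, timestamps of 0 and 33333 microseconds yield a rate of roughly
// 1.0 / (33333 / 1e6) ~= 30 samples per second.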
-float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { - int64 timestamp_diff = second_timestamp - first_timestamp; +float TimestampsToRate(int64_t first_timestamp, int64_t second_timestamp) { + int64_t timestamp_diff = second_timestamp - first_timestamp; // convert from microseconds to seconds. float rate = 1.0 / (static_cast(timestamp_diff) / 1000000); return rate; @@ -100,18 +100,18 @@ absl::Status ReconcileAnnotationIndicesByImageTimestamps( << "start: " << segment_size << ", end: " << GetSegmentEndTimestampSize(*sequence); - std::vector start_indices; + std::vector start_indices; start_indices.reserve(segment_size); - for (const int64& timestamp : GetSegmentStartTimestamp(*sequence)) { + for (const int64_t& timestamp : GetSegmentStartTimestamp(*sequence)) { index = NearestIndex(timestamp, GetFeatureList(*sequence, kImageTimestampKey)); start_indices.push_back(index); } SetSegmentStartIndex(start_indices, sequence); - std::vector end_indices; + std::vector end_indices; end_indices.reserve(segment_size); - for (const int64& timestamp : GetSegmentEndTimestamp(*sequence)) { + for (const int64_t& timestamp : GetSegmentEndTimestamp(*sequence)) { index = NearestIndex(timestamp, GetFeatureList(*sequence, kImageTimestampKey)); end_indices.push_back(index); @@ -167,8 +167,8 @@ absl::Status ReconcileMetadataFeatureFloats( int number_of_elements = GetFeatureFloatsAt(prefix, *sequence, 0).size(); if (HasFeatureDimensions(prefix, *sequence) && !GetFeatureDimensions(prefix, *sequence).empty()) { - int64 product = 1; - for (int64 value : GetFeatureDimensions(prefix, *sequence)) { + int64_t product = 1; + for (int64_t value : GetFeatureDimensions(prefix, *sequence)) { product *= value; } RET_CHECK_EQ(number_of_elements, product) @@ -249,14 +249,14 @@ absl::Status ReconcileMetadataBoxAnnotations( // Collect which timestamps currently match to which indices in timestamps. // skip empty timestamps. // Requires sorted indices. 
- ::std::vector box_timestamps(num_bboxes); + ::std::vector box_timestamps(num_bboxes); int bbox_index = 0; std::string timestamp_key = merge_prefix(prefix, kRegionTimestampKey); for (auto& feature : GetFeatureList(*sequence, timestamp_key).feature()) { box_timestamps[bbox_index] = feature.int64_list().value(0); ++bbox_index; } - ::std::vector box_is_annotated(num_bboxes); + ::std::vector box_is_annotated(num_bboxes); bbox_index = 0; std::string is_annotated_key = merge_prefix(prefix, kRegionIsAnnotatedKey); for (auto& feature : @@ -264,7 +264,7 @@ absl::Status ReconcileMetadataBoxAnnotations( box_is_annotated[bbox_index] = feature.int64_list().value(0); ++bbox_index; } - ::std::vector image_timestamps(num_frames); + ::std::vector image_timestamps(num_frames); int frame_index = 0; for (auto& feature : GetFeatureList(*sequence, kImageTimestampKey).feature()) { diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index 0797ed472..e220eace0 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -67,7 +67,7 @@ TEST(MediaSequenceTest, RoundTripEncodedMediaBytes) { TEST(MediaSequenceTest, RoundTripEncodedVideoStartTimestamp) { tensorflow::SequenceExample sequence; - int64 data = 47; + int64_t data = 47; SetClipEncodedMediaStartTimestamp(data, &sequence); ASSERT_EQ(GetClipEncodedMediaStartTimestamp(sequence), data); } @@ -92,7 +92,7 @@ TEST(MediaSequenceTest, RoundTripClipEndTimestamp) { TEST(MediaSequenceTest, RoundTripClipLabelIndex) { tensorflow::SequenceExample sequence; - std::vector label = {5, 3}; + std::vector label = {5, 3}; SetClipLabelIndex(label, &sequence); ASSERT_THAT(GetClipLabelIndex(sequence), testing::ElementsAreArray(label)); } @@ -115,46 +115,46 @@ TEST(MediaSequenceTest, RoundTripFloatListFrameRate) { TEST(MediaSequenceTest, RoundTripSegmentStartTimestamp) { tensorflow::SequenceExample sequence; EXPECT_FALSE(HasContext(sequence, kSegmentStartTimestampKey)); - SetSegmentStartTimestamp(::std::vector({123, 456}), &sequence); + SetSegmentStartTimestamp(::std::vector({123, 456}), &sequence); ASSERT_EQ(2, GetSegmentStartTimestampSize(sequence)); ASSERT_THAT(GetSegmentStartTimestamp(sequence), - testing::ElementsAreArray(::std::vector({123, 456}))); + testing::ElementsAreArray(::std::vector({123, 456}))); } TEST(MediaSequenceTest, RoundTripSegmentEndTimestamp) { tensorflow::SequenceExample sequence; EXPECT_FALSE(HasContext(sequence, kSegmentEndTimestampKey)); - SetSegmentEndTimestamp(::std::vector({123, 456}), &sequence); + SetSegmentEndTimestamp(::std::vector({123, 456}), &sequence); ASSERT_EQ(2, GetSegmentEndTimestampSize(sequence)); ASSERT_THAT(GetSegmentEndTimestamp(sequence), - testing::ElementsAreArray(::std::vector({123, 456}))); + testing::ElementsAreArray(::std::vector({123, 456}))); } TEST(MediaSequenceTest, RoundTripSegmentStartIndex) { tensorflow::SequenceExample sequence; EXPECT_FALSE(HasContext(sequence, kSegmentStartIndexKey)); - SetSegmentStartIndex(::std::vector({123, 456}), &sequence); + SetSegmentStartIndex(::std::vector({123, 456}), &sequence); ASSERT_EQ(2, GetSegmentStartIndexSize(sequence)); ASSERT_THAT(GetSegmentStartIndex(sequence), - testing::ElementsAreArray(::std::vector({123, 456}))); + testing::ElementsAreArray(::std::vector({123, 456}))); } TEST(MediaSequenceTest, RoundTripSegmentEndIndex) { tensorflow::SequenceExample sequence; EXPECT_FALSE(HasContext(sequence, kSegmentEndIndexKey)); - SetSegmentEndIndex(::std::vector({123, 456}), 
&sequence); + SetSegmentEndIndex(::std::vector({123, 456}), &sequence); ASSERT_EQ(2, GetSegmentEndIndexSize(sequence)); ASSERT_THAT(GetSegmentEndIndex(sequence), - testing::ElementsAreArray(::std::vector({123, 456}))); + testing::ElementsAreArray(::std::vector({123, 456}))); } TEST(MediaSequenceTest, RoundTripSegmentLabelIndex) { tensorflow::SequenceExample sequence; EXPECT_FALSE(HasContext(sequence, kSegmentLabelIndexKey)); - SetSegmentLabelIndex(::std::vector({5, 7}), &sequence); + SetSegmentLabelIndex(::std::vector({5, 7}), &sequence); ASSERT_EQ(2, GetSegmentLabelIndexSize(sequence)); ASSERT_THAT(GetSegmentLabelIndex(sequence), - testing::ElementsAreArray(::std::vector({5, 7}))); + testing::ElementsAreArray(::std::vector({5, 7}))); } TEST(MediaSequenceTest, RoundTripSegmentLabelString) { @@ -180,8 +180,8 @@ TEST(MediaSequenceTest, RoundTripSegmentLabelConfidence) { TEST(MediaSequenceTest, RoundTripImageWidthHeight) { tensorflow::SequenceExample sequence; - int64 height = 2; - int64 width = 3; + int64_t height = 2; + int64_t width = 3; SetImageHeight(height, &sequence); ASSERT_EQ(GetImageHeight(sequence), height); SetImageWidth(width, &sequence); @@ -190,8 +190,8 @@ TEST(MediaSequenceTest, RoundTripImageWidthHeight) { TEST(MediaSequenceTest, RoundTripForwardFlowWidthHeight) { tensorflow::SequenceExample sequence; - int64 height = 2; - int64 width = 3; + int64_t height = 2; + int64_t width = 3; SetForwardFlowHeight(height, &sequence); ASSERT_EQ(GetForwardFlowHeight(sequence), height); SetForwardFlowWidth(width, &sequence); @@ -200,8 +200,8 @@ TEST(MediaSequenceTest, RoundTripForwardFlowWidthHeight) { TEST(MediaSequenceTest, RoundTripClassSegmentationWidthHeightFormat) { tensorflow::SequenceExample sequence; - int64 height = 2; - int64 width = 3; + int64_t height = 2; + int64_t width = 3; std::string format = "JPEG"; SetClassSegmentationHeight(height, &sequence); EXPECT_EQ(GetClassSegmentationHeight(sequence), height); @@ -213,7 +213,7 @@ TEST(MediaSequenceTest, RoundTripClassSegmentationWidthHeightFormat) { TEST(MediaSequenceTest, RoundTripClassSegmentationLabelIndex) { tensorflow::SequenceExample sequence; - std::vector classes = {5, 3}; + std::vector classes = {5, 3}; SetClassSegmentationClassLabelIndex(classes, &sequence); ASSERT_THAT(GetClassSegmentationClassLabelIndex(sequence), testing::ElementsAreArray({5, 3})); @@ -233,8 +233,8 @@ TEST(MediaSequenceTest, RoundTripClassSegmentationLabelString) { TEST(MediaSequenceTest, RoundTripInstanceSegmentationWidthHeightFormat) { tensorflow::SequenceExample sequence; - int64 height = 2; - int64 width = 3; + int64_t height = 2; + int64_t width = 3; std::string format = "JPEG"; SetInstanceSegmentationHeight(height, &sequence); EXPECT_EQ(GetInstanceSegmentationHeight(sequence), height); @@ -246,7 +246,7 @@ TEST(MediaSequenceTest, RoundTripInstanceSegmentationWidthHeightFormat) { TEST(MediaSequenceTest, RoundTripInstanceSegmentationClass) { tensorflow::SequenceExample sequence; - std::vector classes = {5, 3}; + std::vector classes = {5, 3}; SetInstanceSegmentationObjectClassIndex(classes, &sequence); ASSERT_THAT(GetInstanceSegmentationObjectClassIndex(sequence), testing::ElementsAreArray({5, 3})); @@ -286,7 +286,7 @@ TEST(MediaSequenceTest, RoundTripBBoxNumRegions) { TEST(MediaSequenceTest, RoundTripBBoxLabelIndex) { tensorflow::SequenceExample sequence; - std::vector> labels = {{5, 3}, {1, 2}}; + std::vector> labels = {{5, 3}, {1, 2}}; for (int i = 0; i < labels.size(); ++i) { AddBBoxLabelIndex(labels[i], &sequence); 
ASSERT_EQ(GetBBoxLabelIndexSize(sequence), i + 1); @@ -312,7 +312,7 @@ TEST(MediaSequenceTest, RoundTripBBoxLabelString) { TEST(MediaSequenceTest, RoundTripBBoxClassIndex) { tensorflow::SequenceExample sequence; - std::vector> classes = {{5, 3}, {1, 2}}; + std::vector> classes = {{5, 3}, {1, 2}}; for (int i = 0; i < classes.size(); ++i) { AddBBoxClassIndex(classes[i], &sequence); ASSERT_EQ(GetBBoxClassIndexSize(sequence), i + 1); @@ -338,7 +338,7 @@ TEST(MediaSequenceTest, RoundTripBBoxClassString) { TEST(MediaSequenceTest, RoundTripBBoxTrackIndex) { tensorflow::SequenceExample sequence; - std::vector> tracks = {{5, 3}, {1, 2}}; + std::vector> tracks = {{5, 3}, {1, 2}}; for (int i = 0; i < tracks.size(); ++i) { AddBBoxTrackIndex(tracks[i], &sequence); ASSERT_EQ(GetBBoxTrackIndexSize(sequence), i + 1); @@ -499,7 +499,7 @@ TEST(MediaSequenceTest, RoundTripPredictedBBox) { TEST(MediaSequenceTest, RoundTripPredictedBBoxTimestamp) { tensorflow::SequenceExample sequence; - std::vector timestamps = {3, 6}; + std::vector timestamps = {3, 6}; for (int i = 0; i < timestamps.size(); ++i) { AddPredictedBBoxTimestamp(timestamps[i], &sequence); ASSERT_EQ(GetPredictedBBoxTimestampSize(sequence), i + 1); @@ -659,7 +659,7 @@ TEST(MediaSequenceTest, RoundTripContextFeatureBytes) { TEST(MediaSequenceTest, RoundTripContextFeatureInts) { tensorflow::SequenceExample sequence; std::string feature_key = "TEST"; - std::vector vi = {0, 1, 2, 4}; + std::vector vi = {0, 1, 2, 4}; SetContextFeatureInts(feature_key, vi, &sequence); ASSERT_EQ(GetContextFeatureInts(feature_key, sequence).size(), vi.size()); ASSERT_EQ(GetContextFeatureInts(feature_key, sequence)[3], vi[3]); @@ -725,7 +725,7 @@ TEST(MediaSequenceTest, RoundTripTextContent) { TEST(MediaSequenceTest, RoundTripTextDuration) { tensorflow::SequenceExample sequence; - std::vector timestamps = {4, 7}; + std::vector timestamps = {4, 7}; for (int i = 0; i < timestamps.size(); ++i) { AddTextTimestamp(timestamps[i], &sequence); ASSERT_EQ(GetTextTimestampSize(sequence), i + 1); @@ -765,7 +765,7 @@ TEST(MediaSequenceTest, RoundTripTextEmbedding) { TEST(MediaSequenceTest, RoundTripTextTokenId) { tensorflow::SequenceExample sequence; - std::vector ids = {4, 7}; + std::vector ids = {4, 7}; for (int i = 0; i < ids.size(); ++i) { AddTextTokenId(ids[i], &sequence); ASSERT_EQ(GetTextTokenIdSize(sequence), i + 1); @@ -783,8 +783,8 @@ TEST(MediaSequenceTest, ReconcileMetadataOnEmptySequence) { TEST(MediaSequenceTest, ReconcileMetadataImagestoLabels) { // Need image timestamps and label timestamps. 
tensorflow::SequenceExample sequence; - SetSegmentStartTimestamp(::std::vector({3, 4}), &sequence); - SetSegmentEndTimestamp(::std::vector({4, 5}), &sequence); + SetSegmentStartTimestamp(::std::vector({3, 4}), &sequence); + SetSegmentEndTimestamp(::std::vector({4, 5}), &sequence); // Skip 0, so the indices are the timestamp - 1 AddImageTimestamp(1, &sequence); @@ -1027,20 +1027,20 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) { AddBBoxNumRegions(1, &sequence); AddBBoxNumRegions(1, &sequence); - AddBBoxLabelIndex(::std::vector({1}), &sequence); - AddBBoxLabelIndex(::std::vector({2}), &sequence); + AddBBoxLabelIndex(::std::vector({1}), &sequence); + AddBBoxLabelIndex(::std::vector({2}), &sequence); AddBBoxLabelString(::std::vector({"one"}), &sequence); AddBBoxLabelString(::std::vector({"two"}), &sequence); - AddBBoxClassIndex(::std::vector({1}), &sequence); - AddBBoxClassIndex(::std::vector({2}), &sequence); + AddBBoxClassIndex(::std::vector({1}), &sequence); + AddBBoxClassIndex(::std::vector({2}), &sequence); AddBBoxClassString(::std::vector({"one"}), &sequence); AddBBoxClassString(::std::vector({"two"}), &sequence); - AddBBoxTrackIndex(::std::vector({1}), &sequence); - AddBBoxTrackIndex(::std::vector({2}), &sequence); + AddBBoxTrackIndex(::std::vector({1}), &sequence); + AddBBoxTrackIndex(::std::vector({2}), &sequence); AddBBoxTrackString(::std::vector({"one"}), &sequence); AddBBoxTrackString(::std::vector({"two"}), &sequence); @@ -1083,11 +1083,11 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) { ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 1), ::testing::ElementsAreArray({2})); ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 2), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 3), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxLabelIndexAt(sequence, 4), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxLabelStringAt(sequence, 0), ::testing::ElementsAreArray({"one"})); @@ -1105,11 +1105,11 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) { ASSERT_THAT(GetBBoxClassIndexAt(sequence, 1), ::testing::ElementsAreArray({2})); ASSERT_THAT(GetBBoxClassIndexAt(sequence, 2), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxClassIndexAt(sequence, 3), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxClassIndexAt(sequence, 4), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxClassStringAt(sequence, 0), ::testing::ElementsAreArray({"one"})); @@ -1127,11 +1127,11 @@ TEST(MediaSequenceTest, ReconcileMetadataBoxAnnotationsUpdatesAllFeatures) { ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 1), ::testing::ElementsAreArray({2})); ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 2), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 3), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxTrackIndexAt(sequence, 4), - ::testing::ElementsAreArray(::std::vector())); + ::testing::ElementsAreArray(::std::vector())); ASSERT_THAT(GetBBoxTrackStringAt(sequence, 0), 
::testing::ElementsAreArray({"one"})); diff --git a/mediapipe/util/sequence/media_sequence_util_test.cc b/mediapipe/util/sequence/media_sequence_util_test.cc index 56d3b4868..8709165e3 100644 --- a/mediapipe/util/sequence/media_sequence_util_test.cc +++ b/mediapipe/util/sequence/media_sequence_util_test.cc @@ -253,7 +253,7 @@ TEST_F(MediaSequenceUtilTest, RoundTripFloatList) { TEST_F(MediaSequenceUtilTest, RoundTripInt64List) { tensorflow::SequenceExample sequence_example; std::string key = "key"; - std::vector expected_values{1, 3}; + std::vector expected_values{1, 3}; AddInt64Container(key, expected_values, &sequence_example); auto values = GetInt64sAt(sequence_example, key, 0); ASSERT_EQ(expected_values.size(), values.size()); @@ -302,7 +302,7 @@ TEST_F(MediaSequenceUtilTest, RoundTripContextFeatureList) { } // Test context in64 list. std::string clip_label_index_key = "clip_label_index"; - std::vector clip_label_indices{2, 0}; + std::vector clip_label_indices{2, 0}; SetContextInt64List(clip_label_index_key, clip_label_indices, &sequence_example); for (int i = 0; i < clip_label_indices.size(); ++i) { @@ -333,7 +333,7 @@ TEST_F(MediaSequenceUtilTest, ContextKeyMissing) { TEST_F(MediaSequenceUtilTest, RoundTripFeatureListsFeature) { tensorflow::SequenceExample sequence_example; std::string timestamp_key = "timestamp"; - int64 timestamp = 1000; + int64_t timestamp = 1000; MutableFeatureList(timestamp_key, &sequence_example) ->add_feature() ->mutable_int64_list() @@ -413,7 +413,7 @@ TEST_F(MediaSequenceUtilTest, StringFeature) { TEST_F(MediaSequenceUtilTest, Int64Feature) { tensorflow::SequenceExample example; - int64 test_value = 47; + int64_t test_value = 47; ASSERT_FALSE(HasInt64Feature(example)); SetInt64Feature(test_value, &example); @@ -426,7 +426,7 @@ TEST_F(MediaSequenceUtilTest, Int64Feature) { TEST_F(MediaSequenceUtilTest, FloatFeature) { tensorflow::SequenceExample example; - int64 test_value = 47.0f; + int64_t test_value = 47.0f; ASSERT_FALSE(HasFloatFeature(example)); SetFloatFeature(test_value, &example); @@ -464,7 +464,7 @@ TEST_F(MediaSequenceUtilTest, StringVectorFeature) { TEST_F(MediaSequenceUtilTest, Int64VectorFeature) { tensorflow::SequenceExample example; - ::std::vector test_value = {47, 42}; + ::std::vector test_value = {47, 42}; ASSERT_FALSE(HasInt64VectorFeature(example)); ASSERT_EQ(0, GetInt64VectorFeatureSize(example)); @@ -535,7 +535,7 @@ TEST_F(MediaSequenceUtilTest, StringFeatureList) { TEST_F(MediaSequenceUtilTest, Int64FeatureList) { tensorflow::SequenceExample example; - ::std::vector test_value = {47, 42}; + ::std::vector test_value = {47, 42}; ASSERT_FALSE(HasInt64FeatureList(example)); ASSERT_EQ(0, GetInt64FeatureListSize(example)); @@ -602,7 +602,7 @@ TEST_F(MediaSequenceUtilTest, VectorStringFeatureList) { TEST_F(MediaSequenceUtilTest, VectorInt64FeatureList) { tensorflow::SequenceExample example; - ::std::vector<::std::vector> test_value = {{47, 42}, {3, 5}}; + ::std::vector<::std::vector> test_value = {{47, 42}, {3, 5}}; ASSERT_FALSE(HasVectorInt64FeatureList(example)); ASSERT_EQ(0, GetVectorInt64FeatureListSize(example)); @@ -704,8 +704,8 @@ TEST_F(MediaSequenceUtilTest, VariablePrefixStringFeature) { TEST_F(MediaSequenceUtilTest, FixedPrefixInt64Feature) { tensorflow::SequenceExample example; - int64 test_value_1 = 47; - int64 test_value_2 = 49; + int64_t test_value_1 = 47; + int64_t test_value_2 = 49; ASSERT_FALSE(HasOneInt64Feature(example)); SetOneInt64Feature(test_value_1, &example); @@ -727,8 +727,8 @@ TEST_F(MediaSequenceUtilTest, 
FixedPrefixInt64Feature) { TEST_F(MediaSequenceUtilTest, FixedPrefixFloatFeature) { tensorflow::SequenceExample example; - int64 test_value_1 = 47.0f; - int64 test_value_2 = 49.0f; + int64_t test_value_1 = 47.0f; + int64_t test_value_2 = 49.0f; ASSERT_FALSE(HasOneFloatFeature(example)); SetOneFloatFeature(test_value_1, &example); @@ -795,8 +795,8 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixStringVectorFeature) { TEST_F(MediaSequenceUtilTest, FixedPrefixInt64VectorFeature) { tensorflow::SequenceExample example; - ::std::vector<int64> test_value_1 = {47, 42}; - ::std::vector<int64> test_value_2 = {49, 47}; + ::std::vector<int64_t> test_value_1 = {47, 42}; + ::std::vector<int64_t> test_value_2 = {49, 47}; ASSERT_FALSE(HasOneInt64VectorFeature(example)); ASSERT_EQ(0, GetOneInt64VectorFeatureSize(example)); @@ -905,7 +905,7 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixStringFeatureList) { TEST_F(MediaSequenceUtilTest, FixedPrefixInt64FeatureList) { tensorflow::SequenceExample example; - ::std::vector<int64> test_value = {47, 42}; + ::std::vector<int64_t> test_value = {47, 42}; ASSERT_FALSE(HasInt64FeatureList(example)); ASSERT_EQ(0, GetInt64FeatureListSize(example)); @@ -990,8 +990,8 @@ TEST_F(MediaSequenceUtilTest, FixedPrefixVectorStringFeatureList) { TEST_F(MediaSequenceUtilTest, FixedPrefixVectorInt64FeatureList) { tensorflow::SequenceExample example; - ::std::vector<::std::vector<int64>> test_value_1 = {{47, 42}, {3, 5}}; - ::std::vector<::std::vector<int64>> test_value_2 = {{49, 47}, {3, 5}}; + ::std::vector<::std::vector<int64_t>> test_value_1 = {{47, 42}, {3, 5}}; + ::std::vector<::std::vector<int64_t>> test_value_2 = {{49, 47}, {3, 5}}; ASSERT_FALSE(HasOneVectorInt64FeatureList(example)); ASSERT_EQ(0, GetOneVectorInt64FeatureListSize(example)); diff --git a/mediapipe/util/tensor_to_detection.cc b/mediapipe/util/tensor_to_detection.cc index 0b3d1f68a..1c19b3510 100644 --- a/mediapipe/util/tensor_to_detection.cc +++ b/mediapipe/util/tensor_to_detection.cc @@ -87,7 +87,7 @@ Status TensorsToDetections(const ::tensorflow::Tensor& num_detections, const auto& num_boxes_scalar = num_detections.scalar<float>(); num_boxes = static_cast<int>(num_boxes_scalar()); } else { - num_boxes = num_detections.scalar<int64>()(); + num_boxes = num_detections.scalar<int64_t>()(); } if (boxes.dim_size(0) < num_boxes) { return InvalidArgumentError( diff --git a/mediapipe/util/time_series_util.cc b/mediapipe/util/time_series_util.cc index 1e20daa59..87f69475a 100644 --- a/mediapipe/util/time_series_util.cc +++ b/mediapipe/util/time_series_util.cc @@ -29,7 +29,7 @@ namespace time_series_util { bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, const Timestamp& initial_timestamp, - int64 cumulative_samples, + int64_t cumulative_samples, double sample_rate) { // Ignore the "special" timestamp value Done().
if (current_timestamp == Timestamp::Done()) return true; @@ -122,11 +122,11 @@ absl::Status IsMatrixShapeConsistentWithHeader(const Matrix& matrix, return absl::OkStatus(); } -int64 SecondsToSamples(double time_in_seconds, double sample_rate) { +int64_t SecondsToSamples(double time_in_seconds, double sample_rate) { return round(time_in_seconds * sample_rate); } -double SamplesToSeconds(int64 num_samples, double sample_rate) { +double SamplesToSeconds(int64_t num_samples, double sample_rate) { DCHECK_NE(sample_rate, 0.0); return (num_samples / sample_rate); } diff --git a/mediapipe/util/time_series_util_test.cc b/mediapipe/util/time_series_util_test.cc index 807bc4f03..e8d47dbc6 100644 --- a/mediapipe/util/time_series_util_test.cc +++ b/mediapipe/util/time_series_util_test.cc @@ -186,7 +186,7 @@ TEST(TimeSeriesUtilTest, SecondsToSamples) { TEST(TimeSeriesUtilTest, SamplesToSeconds) { double sample_rate = 32.5; - int64 num_samples = 128; + int64_t num_samples = 128; EXPECT_EQ(num_samples / sample_rate, SamplesToSeconds(num_samples, sample_rate)); } diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 816af2533..5a271ffac 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -13,24 +13,24 @@ # limitations under the License. # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -proto_library( +mediapipe_proto_library( name = "tone_models_proto", srcs = ["tone_models.proto"], ) -proto_library( +mediapipe_proto_library( name = "tone_estimation_proto", srcs = ["tone_estimation.proto"], deps = [":tone_models_proto"], ) -proto_library( +mediapipe_proto_library( name = "region_flow_computation_proto", srcs = ["region_flow_computation.proto"], deps = [ @@ -38,17 +38,17 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "motion_saliency_proto", srcs = ["motion_saliency.proto"], ) -proto_library( +mediapipe_proto_library( name = "motion_estimation_proto", srcs = ["motion_estimation.proto"], ) -proto_library( +mediapipe_proto_library( name = "motion_analysis_proto", srcs = ["motion_analysis.proto"], deps = [ @@ -58,33 +58,33 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "region_flow_proto", srcs = ["region_flow.proto"], ) -proto_library( +mediapipe_proto_library( name = "motion_models_proto", srcs = ["motion_models.proto"], ) -proto_library( +mediapipe_proto_library( name = "camera_motion_proto", srcs = ["camera_motion.proto"], deps = [":motion_models_proto"], ) -proto_library( +mediapipe_proto_library( name = "push_pull_filtering_proto", srcs = ["push_pull_filtering.proto"], ) -proto_library( +mediapipe_proto_library( name = "frame_selection_solution_evaluator_proto", srcs = ["frame_selection_solution_evaluator.proto"], ) -proto_library( +mediapipe_proto_library( name = "frame_selection_proto", srcs = ["frame_selection.proto"], deps = [ @@ -94,7 +94,7 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "flow_packager_proto", srcs = ["flow_packager.proto"], deps = [ @@ -103,7 +103,7 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "tracking_proto", srcs = ["tracking.proto"], deps = [ @@ -111,18 +111,18 @@ proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "box_tracker_proto", srcs = ["box_tracker.proto"], deps = [":tracking_proto"], ) -proto_library( 
+mediapipe_proto_library( name = "tracked_detection_manager_config_proto", srcs = ["tracked_detection_manager_config.proto"], ) -proto_library( +mediapipe_proto_library( name = "box_detector_proto", srcs = ["box_detector.proto"], deps = [ @@ -131,135 +131,6 @@ proto_library( ], ) -mediapipe_cc_proto_library( - name = "tone_models_cc_proto", - srcs = ["tone_models.proto"], - deps = [":tone_models_proto"], -) - -mediapipe_cc_proto_library( - name = "tone_estimation_cc_proto", - srcs = ["tone_estimation.proto"], - cc_deps = [":tone_models_cc_proto"], - deps = [":tone_estimation_proto"], -) - -mediapipe_cc_proto_library( - name = "region_flow_computation_cc_proto", - srcs = ["region_flow_computation.proto"], - cc_deps = [ - ":tone_estimation_cc_proto", - ":tone_models_cc_proto", - ], - deps = [":region_flow_computation_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_saliency_cc_proto", - srcs = ["motion_saliency.proto"], - deps = [":motion_saliency_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_estimation_cc_proto", - srcs = ["motion_estimation.proto"], - deps = [":motion_estimation_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_analysis_cc_proto", - srcs = ["motion_analysis.proto"], - cc_deps = [ - ":motion_estimation_cc_proto", - ":motion_saliency_cc_proto", - ":region_flow_computation_cc_proto", - ], - deps = [":motion_analysis_proto"], -) - -mediapipe_cc_proto_library( - name = "region_flow_cc_proto", - srcs = ["region_flow.proto"], - cc_deps = [":motion_models_cc_proto"], - deps = [":region_flow_proto"], -) - -mediapipe_cc_proto_library( - name = "motion_models_cc_proto", - srcs = ["motion_models.proto"], - deps = [":motion_models_proto"], -) - -mediapipe_cc_proto_library( - name = "camera_motion_cc_proto", - srcs = ["camera_motion.proto"], - cc_deps = [":motion_models_cc_proto"], - deps = [":camera_motion_proto"], -) - -mediapipe_cc_proto_library( - name = "push_pull_filtering_cc_proto", - srcs = ["push_pull_filtering.proto"], - deps = [":push_pull_filtering_proto"], -) - -mediapipe_cc_proto_library( - name = "frame_selection_solution_evaluator_cc_proto", - srcs = ["frame_selection_solution_evaluator.proto"], - deps = [":frame_selection_solution_evaluator_proto"], -) - -mediapipe_cc_proto_library( - name = "frame_selection_cc_proto", - srcs = ["frame_selection.proto"], - cc_deps = [ - ":camera_motion_cc_proto", - ":frame_selection_solution_evaluator_cc_proto", - ":region_flow_cc_proto", - ], - deps = [":frame_selection_proto"], -) - -mediapipe_cc_proto_library( - name = "flow_packager_cc_proto", - srcs = ["flow_packager.proto"], - cc_deps = [ - ":motion_models_cc_proto", - ":region_flow_cc_proto", - ], - deps = [":flow_packager_proto"], -) - -mediapipe_cc_proto_library( - name = "tracking_cc_proto", - srcs = ["tracking.proto"], - cc_deps = [":motion_models_cc_proto"], - deps = [":tracking_proto"], -) - -mediapipe_cc_proto_library( - name = "box_tracker_cc_proto", - srcs = ["box_tracker.proto"], - cc_deps = [":tracking_cc_proto"], - deps = [":box_tracker_proto"], -) - -mediapipe_cc_proto_library( - name = "tracked_detection_manager_config_cc_proto", - srcs = ["tracked_detection_manager_config.proto"], - deps = [":tracked_detection_manager_config_proto"], -) - -mediapipe_cc_proto_library( - name = "box_detector_cc_proto", - srcs = ["box_detector.proto"], - cc_deps = [ - ":box_tracker_cc_proto", - ":region_flow_cc_proto", - ], - deps = [":box_detector_proto"], -) - cc_library( name = "motion_models", srcs = ["motion_models.cc"], diff --git 
a/requirements.txt b/requirements.txt index 326f21694..85d02d59a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ matplotlib numpy opencv-contrib-python protobuf>=3.11,<4 -sounddevice +sounddevice>=0.4.4 diff --git a/setup.py b/setup.py index 992430cf1..6e7493caa 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ from setuptools.command import install __version__ = 'dev' IS_WINDOWS = (platform.system() == 'Windows') +IS_MAC = (platform.system() == 'Darwin') MP_ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) MP_DIR_INIT_PY = os.path.join(MP_ROOT_PATH, 'mediapipe/__init__.py') MP_THIRD_PARTY_BUILD = os.path.join(MP_ROOT_PATH, 'third_party/BUILD') @@ -343,11 +344,36 @@ class BuildExtension(build_ext.build_ext): def run(self): _check_bazel() - for ext in self.extensions: - self._build_binary(ext) + if IS_MAC: + for ext in self.extensions: + target_name = self.get_ext_fullpath(ext.name) + # Build x86 + self._build_binary(ext) + x86_name = self.get_ext_fullpath(ext.name) + # Build Arm64 + ext.name = ext.name + '.arm64' + self._build_binary( + ext, + ['--cpu=darwin_arm64', '--ios_multi_cpus=i386,x86_64,armv7,arm64'], + ) + arm64_name = self.get_ext_fullpath(ext.name) + # Merge architectures + lipo_command = [ + 'lipo', + '-create', + '-output', + target_name, + x86_name, + arm64_name, + ] + if subprocess.call(lipo_command) != 0: + sys.exit(-1) + else: + for ext in self.extensions: + self._build_binary(ext) build_ext.build_ext.run(self) - def _build_binary(self, ext): + def _build_binary(self, ext, extra_args=None): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) bazel_command = [ @@ -359,6 +385,8 @@ class BuildExtension(build_ext.build_ext): '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable), str(ext.bazel_target + '.so'), ] + if extra_args: + bazel_command += extra_args if not self.link_opencv and not IS_WINDOWS: bazel_command.append('--define=OPENCV=source') if subprocess.call(bazel_command) != 0: diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD index 05f830e81..62f91b515 100644 --- a/third_party/apple_frameworks/BUILD +++ b/third_party/apple_frameworks/BUILD @@ -32,6 +32,11 @@ cc_library( linkopts = ["-framework Metal"], ) +cc_library( + name = "MetalKit", + linkopts = ["-framework MetalKit"], +) + cc_library( name = "MetalPerformanceShaders", linkopts = ["-framework MetalPerformanceShaders"], diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index e438afd98..d32c973a0 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -67,13 +67,7 @@ def external_files(): http_file( name = "com_google_mediapipe_BUILD", sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"], - ) - - http_file( - name = "com_google_mediapipe_BUILD_orig", - sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1679955080207504"], + urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976167832357639365316787374795996401679955080207504"], ) http_file( @@ -262,6 +256,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665707319890725"], ) + http_file( + name = "com_google_mediapipe_dynamic_input_classifier_tflite", 
+ sha256 = "fb34b05e1cd4081f3c2bb882092f617efb19266b3353d51b3790a172cae09784", + urls = ["https://storage.googleapis.com/mediapipe-assets/dynamic_input_classifier.tflite?generation=1680543275416843"], + ) + http_file( name = "com_google_mediapipe_efficientdet_lite0_v1_json", sha256 = "7a9e1fb625a6130a251e612637fc546cfc8cfabfadc7dbdade44c87f1d8996ca", @@ -306,8 +306,8 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_pose_landmarks_prototxt", - sha256 = "75dfd2825fc23f51e3906f3a0a050caa8ae9f502cc358af1e7c9fda7ea89c9a5", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1679955083449778"], + sha256 = "0bb27e9d9729c4171419abf7edd746b4234cb91198d663f3a4363248a49dad1a", + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1680543279295598"], ) http_file( @@ -934,12 +934,24 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1678737492211540"], ) + http_file( + name = "com_google_mediapipe_pose_expected_expanded_rect_pbtxt", + sha256 = "babb2a2d50077f6fa9ee15e30d81abb6e98a920e35acad7542bb3d27b5ce7ffd", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_expanded_rect.pbtxt?generation=1680543294008098"], + ) + http_file( name = "com_google_mediapipe_pose_jpg", sha256 = "c8a830ed683c0276d713dd5aeda28f415f10cd6291972084a40d0d8b934ed62b", urls = ["https://storage.googleapis.com/mediapipe-assets/pose.jpg?generation=1678737494661975"], ) + http_file( + name = "com_google_mediapipe_pose_landmarker_task", + sha256 = "ca4137626f0dc04f87893ccf2ad01949a3b1d4b55fa85ba957dde44a29dd956e", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarker.task?generation=1680543298177615"], + ) + http_file( name = "com_google_mediapipe_pose_landmark_full_tflite", sha256 = "e9a5c5cb17f736fafd4c2ec1da3b3d331d6edbe8a0d32395855aeb2cdfd64b9f", @@ -1258,12 +1270,6 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/vocab_with_index.txt?generation=1661875977280658"], ) - http_file( - name = "com_google_mediapipe_w_avg_npy", - sha256 = "a044e35609986d18a972532f2980e939832b5b7d559659959d11ecc752a58bbe", - urls = ["https://storage.googleapis.com/mediapipe-assets/w_avg.npy?generation=1679955100435717"], - ) - http_file( name = "com_google_mediapipe_yamnet_audio_classifier_with_metadata_tflite", sha256 = "10c95ea3eb9a7bb4cb8bddf6feb023250381008177ac162ce169694d05c317de", @@ -1276,60 +1282,6 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/yamnet_embedding_metadata.tflite?generation=1668295071595506"], ) - http_file( - name = "com_google_mediapipe_decoder_fingerprint_pb", - sha256 = "0bf6239c4855d78edb60f3349b46cdb2c6f83def64f1b31589b6e298e5cbec3c", - urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/fingerprint.pb?generation=1679955102906559"], - ) - - http_file( - name = "com_google_mediapipe_decoder_keras_metadata_pb", - sha256 = "1631ee698455aea52d4467fe6118800718a86ec49c29f4f3c904785b72f425ff", - urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/keras_metadata.pb?generation=1679955105294959"], - ) - - http_file( - name = "com_google_mediapipe_decoder_saved_model_pb", - sha256 = "b424d30c63548e93390b2944b9bd9dc29773a56197bb462d3bd9e7a0bd1270ff", - urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/saved_model.pb?generation=1679955107808916"], - ) - - http_file( - name = 
"com_google_mediapipe_discriminator_fingerprint_pb", - sha256 = "1fa6201d253c9218f7054138b9ce273266ce431e00cbce2d74d557f6b97223fd", - urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/fingerprint.pb?generation=1679955110094297"], - ) - - http_file( - name = "com_google_mediapipe_discriminator_keras_metadata_pb", - sha256 = "59a8601790d615dd37ec24e788743ce737e9999ce6ea6593fcf1ee43f674987f", - urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/keras_metadata.pb?generation=1679955112486389"], - ) - - http_file( - name = "com_google_mediapipe_discriminator_saved_model_pb", - sha256 = "280d6097a9b3d4c3756e028c597fe3d3c76eb14f76f24d49a22ed7b6df1e3878", - urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/saved_model.pb?generation=1679955114873053"], - ) - - http_file( - name = "com_google_mediapipe_encoder_fingerprint_pb", - sha256 = "06cb4319f8178edf447a7a2442e89303a14a48cc4fc5ae27354eac2ba11ae120", - urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/fingerprint.pb?generation=1679955117213209"], - ) - - http_file( - name = "com_google_mediapipe_encoder_keras_metadata_pb", - sha256 = "8b1429ee95c130fad0c78077b2b544cd03c9e288658aae93e81df4959b84009e", - urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/keras_metadata.pb?generation=1679955119546778"], - ) - - http_file( - name = "com_google_mediapipe_encoder_saved_model_pb", - sha256 = "ce48392c71485ecd9b142b46e54442581a299df5560102337038b76a62e02a09", - urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/saved_model.pb?generation=1679955122069362"], - ) - http_file( name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa", @@ -1342,24 +1294,6 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], ) - http_file( - name = "com_google_mediapipe_mapping_fingerprint_pb", - sha256 = "6320890f1b9a57e5f4e50e3b56d96fd39d815aa2de51dd1c9b635aa6107d982b", - urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/fingerprint.pb?generation=1679955124430234"], - ) - - http_file( - name = "com_google_mediapipe_mapping_keras_metadata_pb", - sha256 = "22582a2ec1d4883b52f50e628c1a2d69a2610b38d72a48a0bd9939c26be304f6", - urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/keras_metadata.pb?generation=1679955126858694"], - ) - - http_file( - name = "com_google_mediapipe_mapping_saved_model_pb", - sha256 = "6a79de45d00f49110304bf0a6746bc717c45f77824cad22690f700f2fbdc1470", - urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/saved_model.pb?generation=1679955129259768"], - ) - http_file( name = "com_google_mediapipe_mobilebert_tiny_keras_metadata_pb", sha256 = "cef8131a414c602b9d4742ac57f4f90bc5d8a42baec36b65deece884e2d0cf0f", @@ -1408,42 +1342,6 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/saved_model.pb?generation=1661875999264354"], ) - http_file( - name = "com_google_mediapipe_decoder_variables_variables_data-00000-of-00001", - sha256 = "d720ddf354036f17fa210951f9ebfb009453b244913a493f494f1441cfc2eca3", - urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/variables/variables.data-00000-of-00001?generation=1679955132326947"], - ) - - http_file( - name = "com_google_mediapipe_decoder_variables_variables_index", - sha256 = 
"245f69af6e53fb8b163059fe9936f57b68a7844e15d696393fcddf94c771dfcc", - urls = ["https://storage.googleapis.com/mediapipe-assets/decoder/variables/variables.index?generation=1679955134518344"], - ) - - http_file( - name = "com_google_mediapipe_discriminator_variables_variables_data-00000-of-00001", - sha256 = "50b00e1898a573588fb0d5d24d74346d99b7153b5d79441d0350c2c6ca89fb02", - urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/variables/variables.data-00000-of-00001?generation=1679955138489595"], - ) - - http_file( - name = "com_google_mediapipe_discriminator_variables_variables_index", - sha256 = "e5cb4be6442a5741504ce7da9487445637ad89b1f4b6a993bb9e762c7bd5621d", - urls = ["https://storage.googleapis.com/mediapipe-assets/discriminator/variables/variables.index?generation=1679955140891136"], - ) - - http_file( - name = "com_google_mediapipe_encoder_variables_variables_data-00000-of-00001", - sha256 = "09bcd1e2f1c6261bd1842af2da95651d54c9b4b9343eb9b8f0004a97f9bc84bf", - urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/variables/variables.data-00000-of-00001?generation=1679955144875765"], - ) - - http_file( - name = "com_google_mediapipe_encoder_variables_variables_index", - sha256 = "964f5ac6ced7b19f76b7856d9dad47594a5b2fa89c52840f82996b809372aec9", - urls = ["https://storage.googleapis.com/mediapipe-assets/encoder/variables/variables.index?generation=1679955147123313"], - ) - http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001", sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d", @@ -1456,18 +1354,6 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) - http_file( - name = "com_google_mediapipe_mapping_variables_variables_data-00000-of-00001", - sha256 = "4187055e7f69fcc913ee2b11151a56149dda3017c75621d1e160596bde874c07", - urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/variables/variables.data-00000-of-00001?generation=1679955149680908"], - ) - - http_file( - name = "com_google_mediapipe_mapping_variables_variables_index", - sha256 = "a04fcae7083715613f93ac89943f5fe1f5ba2e6efb9efd14eee7314f25502e4a", - urls = ["https://storage.googleapis.com/mediapipe-assets/mapping/variables/variables.index?generation=1679955152034297"], - ) - http_file( name = "com_google_mediapipe_mobilebert_tiny_assets_vocab_txt", sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3", diff --git a/third_party/halide.BUILD b/third_party/halide.BUILD index 02e701585..677fa9f38 100644 --- a/third_party/halide.BUILD +++ b/third_party/halide.BUILD @@ -43,8 +43,8 @@ cc_library( name = "lib_halide_static", srcs = select({ "@halide//:halide_config_windows_x86_64": [ - "lib/Release/Halide.lib", "bin/Release/Halide.dll", + "lib/Release/Halide.lib", ], "//conditions:default": [ "lib/libHalide.a",