diff --git a/.bazelrc b/.bazelrc index 724dd23fd..44bc3d0a1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -87,6 +87,9 @@ build:ios_fat --config=ios build:ios_fat --ios_multi_cpus=armv7,arm64 build:ios_fat --watchos_cpus=armv7k +build:ios_sim_fat --config=ios +build:ios_sim_fat --ios_multi_cpus=x86_64,sim_arm64 + build:darwin_x86_64 --apple_platform_type=macos build:darwin_x86_64 --macos_minimum_os=10.12 build:darwin_x86_64 --cpu=darwin_x86_64 diff --git a/.github/ISSUE_TEMPLATE/task_issue_template.yaml b/.github/ISSUE_TEMPLATE/00-task-issue-template.yaml similarity index 96% rename from .github/ISSUE_TEMPLATE/task_issue_template.yaml rename to .github/ISSUE_TEMPLATE/00-task-issue-template.yaml index da3ec013e..d6130edb6 100644 --- a/.github/ISSUE_TEMPLATE/task_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/00-task-issue-template.yaml @@ -40,18 +40,16 @@ body: label: Programming Language and version (e.g. C++, Python, Java) validations: required: true - - type: textarea + - type: input id: current_model attributes: label: Describe the actual behavior - render: shell validations: required: true - - type: textarea + - type: input id: expected_model attributes: label: Describe the expected behaviour - render: shell validations: required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/model_maker_issue_template.yaml b/.github/ISSUE_TEMPLATE/11-model-maker-issue-template.yaml similarity index 96% rename from .github/ISSUE_TEMPLATE/model_maker_issue_template.yaml rename to .github/ISSUE_TEMPLATE/11-model-maker-issue-template.yaml index 23fa08245..7a6d92152 100644 --- a/.github/ISSUE_TEMPLATE/model_maker_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/11-model-maker-issue-template.yaml @@ -41,18 +41,16 @@ body: label: Task name (e.g. Image classification, Gesture recognition etc.) 
validations: required: true - - type: textarea + - type: input id: current_model attributes: label: Describe the actual behavior - render: shell validations: required: true - - type: textarea + - type: input id: expected_model attributes: label: Describe the expected behaviour - render: shell validations: required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/studio_issue_template.yaml b/.github/ISSUE_TEMPLATE/12-studio-issue-template.yaml similarity index 96% rename from .github/ISSUE_TEMPLATE/studio_issue_template.yaml rename to .github/ISSUE_TEMPLATE/12-studio-issue-template.yaml index 645a9e4e4..ffaa315f9 100644 --- a/.github/ISSUE_TEMPLATE/studio_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/12-studio-issue-template.yaml @@ -31,18 +31,16 @@ body: label: URL that shows the problem validations: required: false - - type: textarea + - type: input id: current_model attributes: label: Describe the actual behavior - render: shell validations: required: false - - type: textarea + - type: input id: expected_model attributes: label: Describe the expected behaviour - render: shell validations: required: false - type: textarea diff --git a/.github/ISSUE_TEMPLATE/feature_request_issue_template.yaml b/.github/ISSUE_TEMPLATE/14-feature-request-issue-template.yaml similarity index 89% rename from .github/ISSUE_TEMPLATE/feature_request_issue_template.yaml rename to .github/ISSUE_TEMPLATE/14-feature-request-issue-template.yaml index d707f09cc..c170891fb 100644 --- a/.github/ISSUE_TEMPLATE/feature_request_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/14-feature-request-issue-template.yaml @@ -28,37 +28,33 @@ body: - 'No' validations: required: false - - type: textarea + - type: input id: behaviour attributes: label: Describe the feature and the current behaviour/state - render: shell validations: required: true - - type: textarea + - type: input id: api_change attributes: label: Will this change the current API? How? - render: shell validations: required: false - - type: textarea + - type: input id: benifit attributes: label: Who will benefit with this feature? validations: required: false - - type: textarea + - type: input id: use_case attributes: label: Please specify the use cases for this feature - render: shell validations: required: true - - type: textarea + - type: input id: info_other attributes: label: Any Other info - render: shell validations: required: false diff --git a/.github/ISSUE_TEMPLATE/build.install_issue_template.yaml b/.github/ISSUE_TEMPLATE/15-build-install-issue-template.yaml similarity index 98% rename from .github/ISSUE_TEMPLATE/build.install_issue_template.yaml rename to .github/ISSUE_TEMPLATE/15-build-install-issue-template.yaml index 20cf9723d..ded9d09a6 100644 --- a/.github/ISSUE_TEMPLATE/build.install_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/15-build-install-issue-template.yaml @@ -87,14 +87,13 @@ body: placeholder: validations: required: false - - type: textarea + - type: input id: what-happened attributes: label: Describe the problem description: Provide the exact sequence of commands / steps that you executed before running into the [problem](https://google.github.io/mediapipe/getting_started/getting_started.html) placeholder: Tell us what you see! value: "A bug happened!" 
- render: shell validations: required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/bug_issue_template.yaml b/.github/ISSUE_TEMPLATE/16-bug-issue-template.yaml similarity index 97% rename from .github/ISSUE_TEMPLATE/bug_issue_template.yaml rename to .github/ISSUE_TEMPLATE/16-bug-issue-template.yaml index e997958ae..efa925b44 100644 --- a/.github/ISSUE_TEMPLATE/bug_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/16-bug-issue-template.yaml @@ -80,18 +80,16 @@ body: label: Xcode & Tulsi version (if issue is related to building for iOS) validations: required: false - - type: textarea + - type: input id: current_model attributes: label: Describe the actual behavior - render: shell validations: required: true - - type: textarea + - type: input id: expected_model attributes: label: Describe the expected behaviour - render: shell validations: required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/Documentation_issue_template.yaml b/.github/ISSUE_TEMPLATE/17-documentation-issue-template.yaml similarity index 100% rename from .github/ISSUE_TEMPLATE/Documentation_issue_template.yaml rename to .github/ISSUE_TEMPLATE/17-documentation-issue-template.yaml diff --git a/.github/ISSUE_TEMPLATE/Solution(Legacy_issue_template).yaml b/.github/ISSUE_TEMPLATE/18-solution-legacy-issue-template.yaml similarity index 97% rename from .github/ISSUE_TEMPLATE/Solution(Legacy_issue_template).yaml rename to .github/ISSUE_TEMPLATE/18-solution-legacy-issue-template.yaml index 26c59737b..acb0f5b89 100644 --- a/.github/ISSUE_TEMPLATE/Solution(Legacy_issue_template).yaml +++ b/.github/ISSUE_TEMPLATE/18-solution-legacy-issue-template.yaml @@ -48,18 +48,16 @@ body: placeholder: e.g. C++, Python, Java validations: required: false - - type: textarea + - type: input id: current_model attributes: label: Describe the actual behavior - render: shell validations: required: false - - type: textarea + - type: input id: expected_model attributes: label: Describe the expected behaviour - render: shell validations: required: false - type: textarea diff --git a/.github/ISSUE_TEMPLATE/50-other-issues.md b/.github/ISSUE_TEMPLATE/19-other-issues.md similarity index 100% rename from .github/ISSUE_TEMPLATE/50-other-issues.md rename to .github/ISSUE_TEMPLATE/19-other-issues.md diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index 03c67d0f6..000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2021 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -# -# This file was assembled from multiple pieces, whose use is documented -# throughout. Please refer to the TensorFlow dockerfiles documentation -# for more information. 
- -# Number of days of inactivity before an Issue or Pull Request becomes stale -daysUntilStale: 7 -# Number of days of inactivity before a stale Issue or Pull Request is closed -daysUntilClose: 7 -# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled) -onlyLabels: - - stat:awaiting response -# Comment to post when marking as stale. Set to `false` to disable -markComment: > - This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you. -# Comment to post when removing the stale label. Set to `false` to disable -unmarkComment: false -closeComment: > - Closing as stale. Please reopen if you'd like to work on this further. diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml new file mode 100644 index 000000000..c5e15bce6 --- /dev/null +++ b/.github/workflows/stale.yaml @@ -0,0 +1,66 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This workflow alerts and then closes the stale issues/PRs after specific time +# You can adjust the behavior by modifying this file. +# For more information, see: +# https://github.com/actions/stale + +name: 'Close stale issues and PRs' +"on": + schedule: + - cron: "30 1 * * *" +permissions: + contents: read + issues: write + pull-requests: write +jobs: + stale: + runs-on: ubuntu-latest + steps: + - uses: 'actions/stale@v7' + with: + # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale. + exempt-issue-labels: 'override-stale' + # Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale. + exempt-pr-labels: "override-stale" + # Limit the No. of API calls in one run default value is 30. + operations-per-run: 500 + # Prevent to remove stale label when PRs or issues are updated. + remove-stale-when-updated: false + # comment on issue if not active for more then 7 days. + stale-issue-message: 'This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.' + # comment on PR if not active for more then 14 days. + stale-pr-message: 'This PR has been marked stale because it has no recent activity since 14 days. It will be closed if no further activity occurs. Thank you.' + # comment on issue if stale for more then 7 days. + close-issue-message: This issue was closed due to lack of activity after being marked stale for past 7 days. + # comment on PR if stale for more then 14 days. + close-pr-message: This PR was closed due to lack of activity after being marked stale for past 14 days. 
+ # Number of days of inactivity before an Issue Request becomes stale + days-before-issue-stale: 7 + # Number of days of inactivity before a stale Issue is closed + days-before-issue-close: 7 + # reason for closed the issue default value is not_planned + close-issue-reason: completed + # Number of days of inactivity before a stale PR is closed + days-before-pr-close: 14 + # Number of days of inactivity before an PR Request becomes stale + days-before-pr-stale: 14 + # Check for label to stale or close the issue/PR + any-of-labels: 'stat:awaiting response' + # override stale to stalled for PR + stale-pr-label: 'stale' + # override stale to stalled for Issue + stale-issue-label: "stale" diff --git a/README.md b/README.md index a82c88ab1..e4f5dd182 100644 --- a/README.md +++ b/README.md @@ -1,99 +1,121 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe title: Home nav_order: 1 --- -![MediaPipe](https://mediapipe.dev/images/mediapipe_small.png) - ---- -**Attention:** *Thanks for your interest in MediaPipe! We have moved to +**Attention:** *We have moved to [https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) as the primary developer documentation site for MediaPipe as of April 3, 2023.* -*This notice and web page will be removed on June 1, 2023.* +![MediaPipe](https://developers.google.com/static/mediapipe/images/home/hero_01_1920.png) ----- +**Attention**: MediaPipe Solutions Preview is an early release. [Learn +more](https://developers.google.com/mediapipe/solutions/about#notice). -
<br><br><br><br><br><br><br><br><br><br>
-<br><br><br><br><br><br><br><br><br><br>
-<br><br><br><br><br><br><br><br><br><br>
+**On-device machine learning for everyone** --------------------------------------------------------------------------------- +Delight your customers with innovative machine learning features. MediaPipe +contains everything that you need to customize and deploy to mobile (Android, +iOS), web, desktop, edge devices, and IoT, effortlessly. -## Live ML anywhere +* [See demos](https://goo.gle/mediapipe-studio) +* [Learn more](https://developers.google.com/mediapipe/solutions) -[MediaPipe](https://google.github.io/mediapipe/) offers cross-platform, customizable -ML solutions for live and streaming media. +## Get started -![accelerated.png](https://mediapipe.dev/images/accelerated_small.png) | ![cross_platform.png](https://mediapipe.dev/images/cross_platform_small.png) -:------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------: -***End-to-End acceleration***: *Built-in fast ML inference and processing accelerated even on common hardware* | ***Build once, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* -![ready_to_use.png](https://mediapipe.dev/images/ready_to_use_small.png) | ![open_source.png](https://mediapipe.dev/images/open_source_small.png) -***Ready-to-use solutions***: *Cutting-edge ML solutions demonstrating full power of the framework* | ***Free and open source***: *Framework and solutions both under Apache 2.0, fully extensible and customizable* +You can get started with MediaPipe Solutions by by checking out any of the +developer guides for +[vision](https://developers.google.com/mediapipe/solutions/vision/object_detector), +[text](https://developers.google.com/mediapipe/solutions/text/text_classifier), +and +[audio](https://developers.google.com/mediapipe/solutions/audio/audio_classifier) +tasks. If you need help setting up a development environment for use with +MediaPipe Tasks, check out the setup guides for +[Android](https://developers.google.com/mediapipe/solutions/setup_android), [web +apps](https://developers.google.com/mediapipe/solutions/setup_web), and +[Python](https://developers.google.com/mediapipe/solutions/setup_python). ----- +## Solutions -## ML solutions in MediaPipe +MediaPipe Solutions provides a suite of libraries and tools for you to quickly +apply artificial intelligence (AI) and machine learning (ML) techniques in your +applications. You can plug these solutions into your applications immediately, +customize them to your needs, and use them across multiple development +platforms. MediaPipe Solutions is part of the MediaPipe [open source +project](https://github.com/google/mediapipe), so you can further customize the +solutions code to meet your application needs. 
-Face Detection | Face Mesh | Iris | Hands | Pose | Holistic -:----------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :------: -[![face_detection](https://mediapipe.dev/images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](https://mediapipe.dev/images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](https://mediapipe.dev/images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](https://mediapipe.dev/images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](https://mediapipe.dev/images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](https://mediapipe.dev/images/mobile/holistic_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/holistic) +These libraries and resources provide the core functionality for each MediaPipe +Solution: -Hair Segmentation | Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT -:-------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: -[![hair_segmentation](https://mediapipe.dev/images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) | [![object_detection](https://mediapipe.dev/images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](https://mediapipe.dev/images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](https://mediapipe.dev/images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](https://mediapipe.dev/images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](https://mediapipe.dev/images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) +* **MediaPipe Tasks**: Cross-platform APIs and libraries for deploying + solutions. [Learn + more](https://developers.google.com/mediapipe/solutions/tasks). 
+* **MediaPipe models**: Pre-trained, ready-to-run models for use with each + solution. - - +These tools let you customize and evaluate solutions: -[]() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) -:---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ -[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | -[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Selfie Segmentation](https://google.github.io/mediapipe/solutions/selfie_segmentation) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | -[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ -[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | -[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | -[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | -[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | -[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | -[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | +* **MediaPipe Model Maker**: Customize models for solutions with your data. + [Learn more](https://developers.google.com/mediapipe/solutions/model_maker). +* **MediaPipe Studio**: Visualize, evaluate, and benchmark solutions in your + browser. [Learn + more](https://developers.google.com/mediapipe/solutions/studio). -See also -[MediaPipe Models and Model Cards](https://google.github.io/mediapipe/solutions/models) -for ML models released in MediaPipe. +### Legacy solutions -## Getting started +We have ended support for [these MediaPipe Legacy Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +as of March 1, 2023. All other MediaPipe Legacy Solutions will be upgraded to +a new MediaPipe Solution. See the [Solutions guide](https://developers.google.com/mediapipe/solutions/guide#legacy) +for details. The [code repository](https://github.com/google/mediapipe/tree/master/mediapipe) +and prebuilt binaries for all MediaPipe Legacy Solutions will continue to be +provided on an as-is basis. 
-To start using MediaPipe -[solutions](https://google.github.io/mediapipe/solutions/solutions) with only a few -lines code, see example code and demos in -[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) and -[MediaPipe in JavaScript](https://google.github.io/mediapipe/getting_started/javascript). +For more on the legacy solutions, see the [documentation](https://github.com/google/mediapipe/tree/master/docs/solutions). -To use MediaPipe in C++, Android and iOS, which allow further customization of -the [solutions](https://google.github.io/mediapipe/solutions/solutions) as well as -building your own, learn how to -[install](https://google.github.io/mediapipe/getting_started/install) MediaPipe and -start building example applications in -[C++](https://google.github.io/mediapipe/getting_started/cpp), -[Android](https://google.github.io/mediapipe/getting_started/android) and -[iOS](https://google.github.io/mediapipe/getting_started/ios). +## Framework -The source code is hosted in the -[MediaPipe Github repository](https://github.com/google/mediapipe), and you can -run code search using -[Google Open Source Code Search](https://cs.opensource.google/mediapipe/mediapipe). +To start using MediaPipe Framework, [install MediaPipe +Framework](https://developers.google.com/mediapipe/framework/getting_started/install) +and start building example applications in C++, Android, and iOS. -## Publications +[MediaPipe Framework](https://developers.google.com/mediapipe/framework) is the +low-level component used to build efficient on-device machine learning +pipelines, similar to the premade MediaPipe Solutions. + +Before using MediaPipe Framework, familiarize yourself with the following key +[Framework +concepts](https://developers.google.com/mediapipe/framework/framework_concepts/overview.md): + +* [Packets](https://developers.google.com/mediapipe/framework/framework_concepts/packets.md) +* [Graphs](https://developers.google.com/mediapipe/framework/framework_concepts/graphs.md) +* [Calculators](https://developers.google.com/mediapipe/framework/framework_concepts/calculators.md) + +## Community + +* [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe + users. +* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General + community discussion around MediaPipe. +* [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe) - A + curated list of awesome MediaPipe related frameworks, libraries and + software. + +## Contributing + +We welcome contributions. Please follow these +[guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md). + +We use GitHub issues for tracking requests and bugs. Please post questions to +the MediaPipe Stack Overflow with a `mediapipe` tag. 
+ +## Resources + +### Publications * [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html) in Google Developers Blog @@ -102,7 +124,8 @@ run code search using * [SignAll SDK: Sign language interface using MediaPipe is now available for developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html) in Google Developers Blog -* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) +* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on + Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) in Google AI Blog @@ -130,43 +153,6 @@ run code search using in Google AI Blog * [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172) -## Videos +### Videos * [YouTube Channel](https://www.youtube.com/c/MediaPipe) - -## Events - -* [MediaPipe Seattle Meetup, Google Building Waterside, 13 Feb 2020](https://mediapipe.page.link/seattle2020) -* [AI Nextcon 2020, 12-16 Feb 2020, Seattle](http://aisea20.xnextcon.com/) -* [MediaPipe Madrid Meetup, 16 Dec 2019](https://www.meetup.com/Madrid-AI-Developers-Group/events/266329088/) -* [MediaPipe London Meetup, Google 123 Building, 12 Dec 2019](https://www.meetup.com/London-AI-Tech-Talk/events/266329038) -* [ML Conference, Berlin, 11 Dec 2019](https://mlconference.ai/machine-learning-advanced-development/mediapipe-building-real-time-cross-platform-mobile-web-edge-desktop-video-audio-ml-pipelines/) -* [MediaPipe Berlin Meetup, Google Berlin, 11 Dec 2019](https://www.meetup.com/Berlin-AI-Tech-Talk/events/266328794/) -* [The 3rd Workshop on YouTube-8M Large Scale Video Understanding Workshop, - Seoul, Korea ICCV - 2019](https://research.google.com/youtube8m/workshop2019/index.html) -* [AI DevWorld 2019, 10 Oct 2019, San Jose, CA](https://aidevworld.com) -* [Google Industry Workshop at ICIP 2019, 24 Sept 2019, Taipei, Taiwan](http://2019.ieeeicip.org/?action=page4&id=14#Google) - ([presentation](https://docs.google.com/presentation/d/e/2PACX-1vRIBBbO_LO9v2YmvbHHEt1cwyqH6EjDxiILjuT0foXy1E7g6uyh4CesB2DkkEwlRDO9_lWfuKMZx98T/pub?start=false&loop=false&delayms=3000&slide=id.g556cc1a659_0_5)) -* [Open sourced at CVPR 2019, 17~20 June, Long Beach, CA](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) - -## Community - -* [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe) - A - curated list of awesome MediaPipe related frameworks, libraries and software -* [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users -* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General - community discussion around MediaPipe - -## Alpha disclaimer - -MediaPipe is currently in alpha at v0.7. We may be still making breaking API -changes and expect to get to stable APIs by v1.0. - -## Contributing - -We welcome contributions. Please follow these -[guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md). - -We use GitHub issues for tracking requests and bugs. Please post questions to -the MediaPipe Stack Overflow with a `mediapipe` tag. 
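As a concrete illustration of the Tasks APIs the rewritten README now points to, here is a minimal sketch of the Python object-detection flow covered by the vision guide. It assumes the `mediapipe` PyPI package (0.10 or later) is installed; the model and image paths are placeholders.

```python
# Minimal sketch of the MediaPipe Tasks Python flow referenced in the README.
# Assumes `pip install mediapipe`; the .tflite and .jpg paths are placeholders.
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

options = vision.ObjectDetectorOptions(
    base_options=python.BaseOptions(model_asset_path="efficientdet_lite0.tflite"),
    score_threshold=0.5,  # drop low-confidence detections
)
detector = vision.ObjectDetector.create_from_options(options)

image = mp.Image.create_from_file("input.jpg")
result = detector.detect(image)
for detection in result.detections:
    top = detection.categories[0]
    print(top.category_name, top.score)
```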
diff --git a/WORKSPACE b/WORKSPACE index 760898185..5341b094a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -45,12 +45,13 @@ http_archive( ) http_archive( - name = "rules_foreign_cc", - strip_prefix = "rules_foreign_cc-0.1.0", - url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip", + name = "rules_foreign_cc", + sha256 = "2a4d07cd64b0719b39a7c12218a3e507672b82a97b98c6a89d38565894cf7c51", + strip_prefix = "rules_foreign_cc-0.9.0", + url = "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.9.0.tar.gz", ) -load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") +load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies") rules_foreign_cc_dependencies() @@ -72,12 +73,9 @@ http_archive( http_archive( name = "zlib", build_file = "@//third_party:zlib.BUILD", - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - urls = [ - "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", - "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 - ], + sha256 = "b3a24de97a8fdbc835b9833169501030b8977031bcb54b3b3ac13740f846ab30", + strip_prefix = "zlib-1.2.13", + url = "http://zlib.net/fossils/zlib-1.2.13.tar.gz", patches = [ "@//third_party:zlib.diff", ], @@ -156,22 +154,22 @@ http_archive( # 2020-08-21 http_archive( name = "com_github_glog_glog", - strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6", - sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab", + strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372", + sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb", urls = [ - "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", + "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip", ], ) http_archive( name = "com_github_glog_glog_no_gflags", - strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6", - sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab", + strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372", + sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb", build_file = "@//third_party:glog_no_gflags.BUILD", urls = [ - "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", + "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip", ], patches = [ - "@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff", + "@//third_party:com_github_glog_glog.diff", ], patch_args = [ "-p1", @@ -266,10 +264,10 @@ http_archive( http_archive( name = "com_googlesource_code_re2", - sha256 = "e06b718c129f4019d6e7aa8b7631bee38d3d450dd980246bfaf493eb7db67868", - strip_prefix = "re2-fe4a310131c37f9a7e7f7816fa6ce2a8b27d65a8", + sha256 = "ef516fb84824a597c4d5d0d6d330daedb18363b5a99eda87d027e6bdd9cba299", + strip_prefix = "re2-03da4fc0857c285e3a26782f6bc8931c4c950df4", urls = [ - "https://github.com/google/re2/archive/fe4a310131c37f9a7e7f7816fa6ce2a8b27d65a8.tar.gz", + "https://github.com/google/re2/archive/03da4fc0857c285e3a26782f6bc8931c4c950df4.tar.gz", ], ) @@ -375,6 +373,22 @@ http_archive( url = "https://github.com/opencv/opencv/releases/download/3.2.0/opencv-3.2.0-ios-framework.zip", ) +# Building an opencv.xcframework from the OpenCV 4.5.3 sources is necessary for +# MediaPipe iOS Task Libraries to be supported on arm64(M1) Macs. 
An +# `opencv.xcframework` archive has not been released and it is recommended to +# build the same from source using a script provided in OpenCV 4.5.0 upwards. +# OpenCV is fixed to version to 4.5.3 since swift support can only be disabled +# from 4.5.3 upwards. This is needed to avoid errors when the library is linked +# in Xcode. Swift support will be added in when the final binary MediaPipe iOS +# Task libraries are built. +http_archive( + name = "ios_opencv_source", + sha256 = "a61e7a4618d353140c857f25843f39b2abe5f451b018aab1604ef0bc34cd23d5", + build_file = "@//third_party:opencv_ios_source.BUILD", + type = "zip", + url = "https://github.com/opencv/opencv/archive/refs/tags/4.5.3.zip", +) + http_archive( name = "stblib", strip_prefix = "stb-b42009b3b9d4ca35bc703f5310eedc74f584be58", @@ -468,9 +482,10 @@ http_archive( ) # TensorFlow repo should always go after the other external dependencies. -# TF on 2023-04-12. -_TENSORFLOW_GIT_COMMIT = "d712c0c9e24519cc8cd3720279666720d1000eee" -_TENSORFLOW_SHA256 = "ba98de6ea5f720071246691a1536ecd5e1b1763033e8c82a1e721a06d3dfd4c1" +# TF on 2023-07-26. +_TENSORFLOW_GIT_COMMIT = "e92261fd4cec0b726692081c4d2966b75abf31dd" +# curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | shasum -a 256 +_TENSORFLOW_SHA256 = "478a229bd4ec70a5b568ac23b5ea013d9fca46a47d6c43e30365a0412b9febf4" http_archive( name = "org_tensorflow", urls = [ @@ -478,6 +493,7 @@ http_archive( ], patches = [ "@//third_party:org_tensorflow_compatibility_fixes.diff", + "@//third_party:org_tensorflow_system_python.diff", # Diff is generated with a script, don't update it manually. "@//third_party:org_tensorflow_custom_ops.diff", ], diff --git a/docs/build_java_api_docs.py b/docs/build_java_api_docs.py index b13e8d1df..c30426557 100644 --- a/docs/build_java_api_docs.py +++ b/docs/build_java_api_docs.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # ============================================================================== """Generate Java reference docs for MediaPipe.""" import pathlib +import shutil from absl import app from absl import flags @@ -41,7 +42,9 @@ def main(_) -> None: mp_root = pathlib.Path(__file__) while (mp_root := mp_root.parent).name != 'mediapipe': # Find the nearest `mediapipe` dir. - pass + if not mp_root.name: + # We've hit the filesystem root - abort. + raise FileNotFoundError('"mediapipe" root not found') # Find the root from which all packages are relative. root = mp_root.parent @@ -51,6 +54,14 @@ def main(_) -> None: if (mp_root / 'mediapipe').exists(): mp_root = mp_root / 'mediapipe' + # We need to copy this into the tasks dir to ensure we don't leave broken + # links in the generated docs. + old_api_dir = 'java/com/google/mediapipe/framework/image' + shutil.copytree( + mp_root / old_api_dir, + mp_root / 'tasks' / old_api_dir, + dirs_exist_ok=True) + gen_java.gen_java_docs( package='com.google.mediapipe', source_path=mp_root / 'tasks/java', diff --git a/docs/build_model_maker_api_docs.py b/docs/build_model_maker_api_docs.py index 7732b7d56..377536c33 100644 --- a/docs/build_model_maker_api_docs.py +++ b/docs/build_model_maker_api_docs.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index 02eb04074..10b799320 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/docs/getting_started/hello_world_cpp.md b/docs/getting_started/hello_world_cpp.md index 7c8f9be3e..f0c7ff0f9 100644 --- a/docs/getting_started/hello_world_cpp.md +++ b/docs/getting_started/hello_world_cpp.md @@ -50,7 +50,7 @@ as the primary developer documentation site for MediaPipe as of April 3, 2023.* 3. The [`hello world`] example uses a simple MediaPipe graph in the `PrintHelloWorld()` function, defined in a [`CalculatorGraphConfig`] proto. - ```C++ + ```c++ absl::Status PrintHelloWorld() { // Configures a simple graph, which concatenates 2 PassThroughCalculators. CalculatorGraphConfig config = ParseTextProtoOrDie(R"( @@ -126,7 +126,7 @@ as the primary developer documentation site for MediaPipe as of April 3, 2023.* ```c++ mediapipe::Packet packet; while (poller.Next(&packet)) { - LOG(INFO) << packet.Get(); + ABSL_LOG(INFO) << packet.Get(); } ``` diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 4be097646..118b9a05b 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -138,7 +138,7 @@ Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build rules: ``` -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" load( "@build_bazel_rules_apple//apple:ios.bzl", diff --git a/docs/index.md b/docs/index.md index a82c88ab1..e4f5dd182 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,99 +1,121 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe title: Home nav_order: 1 --- -![MediaPipe](https://mediapipe.dev/images/mediapipe_small.png) - ---- -**Attention:** *Thanks for your interest in MediaPipe! We have moved to +**Attention:** *We have moved to [https://developers.google.com/mediapipe](https://developers.google.com/mediapipe) as the primary developer documentation site for MediaPipe as of April 3, 2023.* -*This notice and web page will be removed on June 1, 2023.* +![MediaPipe](https://developers.google.com/static/mediapipe/images/home/hero_01_1920.png) ----- +**Attention**: MediaPipe Solutions Preview is an early release. [Learn +more](https://developers.google.com/mediapipe/solutions/about#notice). -
<br><br><br><br><br><br><br><br><br><br>
-<br><br><br><br><br><br><br><br><br><br>
-<br><br><br><br><br><br><br><br><br><br>
+**On-device machine learning for everyone** --------------------------------------------------------------------------------- +Delight your customers with innovative machine learning features. MediaPipe +contains everything that you need to customize and deploy to mobile (Android, +iOS), web, desktop, edge devices, and IoT, effortlessly. -## Live ML anywhere +* [See demos](https://goo.gle/mediapipe-studio) +* [Learn more](https://developers.google.com/mediapipe/solutions) -[MediaPipe](https://google.github.io/mediapipe/) offers cross-platform, customizable -ML solutions for live and streaming media. +## Get started -![accelerated.png](https://mediapipe.dev/images/accelerated_small.png) | ![cross_platform.png](https://mediapipe.dev/images/cross_platform_small.png) -:------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------: -***End-to-End acceleration***: *Built-in fast ML inference and processing accelerated even on common hardware* | ***Build once, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* -![ready_to_use.png](https://mediapipe.dev/images/ready_to_use_small.png) | ![open_source.png](https://mediapipe.dev/images/open_source_small.png) -***Ready-to-use solutions***: *Cutting-edge ML solutions demonstrating full power of the framework* | ***Free and open source***: *Framework and solutions both under Apache 2.0, fully extensible and customizable* +You can get started with MediaPipe Solutions by by checking out any of the +developer guides for +[vision](https://developers.google.com/mediapipe/solutions/vision/object_detector), +[text](https://developers.google.com/mediapipe/solutions/text/text_classifier), +and +[audio](https://developers.google.com/mediapipe/solutions/audio/audio_classifier) +tasks. If you need help setting up a development environment for use with +MediaPipe Tasks, check out the setup guides for +[Android](https://developers.google.com/mediapipe/solutions/setup_android), [web +apps](https://developers.google.com/mediapipe/solutions/setup_web), and +[Python](https://developers.google.com/mediapipe/solutions/setup_python). ----- +## Solutions -## ML solutions in MediaPipe +MediaPipe Solutions provides a suite of libraries and tools for you to quickly +apply artificial intelligence (AI) and machine learning (ML) techniques in your +applications. You can plug these solutions into your applications immediately, +customize them to your needs, and use them across multiple development +platforms. MediaPipe Solutions is part of the MediaPipe [open source +project](https://github.com/google/mediapipe), so you can further customize the +solutions code to meet your application needs. 
-Face Detection | Face Mesh | Iris | Hands | Pose | Holistic -:----------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :------: -[![face_detection](https://mediapipe.dev/images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](https://mediapipe.dev/images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](https://mediapipe.dev/images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](https://mediapipe.dev/images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](https://mediapipe.dev/images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](https://mediapipe.dev/images/mobile/holistic_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/holistic) +These libraries and resources provide the core functionality for each MediaPipe +Solution: -Hair Segmentation | Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT -:-------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: -[![hair_segmentation](https://mediapipe.dev/images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) | [![object_detection](https://mediapipe.dev/images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](https://mediapipe.dev/images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](https://mediapipe.dev/images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](https://mediapipe.dev/images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](https://mediapipe.dev/images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) +* **MediaPipe Tasks**: Cross-platform APIs and libraries for deploying + solutions. [Learn + more](https://developers.google.com/mediapipe/solutions/tasks). 
+* **MediaPipe models**: Pre-trained, ready-to-run models for use with each + solution. - - +These tools let you customize and evaluate solutions: -[]() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) -:---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ -[Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | -[Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Pose](https://google.github.io/mediapipe/solutions/pose) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Holistic](https://google.github.io/mediapipe/solutions/holistic) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Selfie Segmentation](https://google.github.io/mediapipe/solutions/selfie_segmentation) | ✅ | ✅ | ✅ | ✅ | ✅ | -[Hair Segmentation](https://google.github.io/mediapipe/solutions/hair_segmentation) | ✅ | | ✅ | | | -[Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ -[Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | -[Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | ✅ | ✅ | ✅ | -[KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | -[AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | -[MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | -[YouTube 8M](https://google.github.io/mediapipe/solutions/youtube_8m) | | | ✅ | | | +* **MediaPipe Model Maker**: Customize models for solutions with your data. + [Learn more](https://developers.google.com/mediapipe/solutions/model_maker). +* **MediaPipe Studio**: Visualize, evaluate, and benchmark solutions in your + browser. [Learn + more](https://developers.google.com/mediapipe/solutions/studio). -See also -[MediaPipe Models and Model Cards](https://google.github.io/mediapipe/solutions/models) -for ML models released in MediaPipe. +### Legacy solutions -## Getting started +We have ended support for [these MediaPipe Legacy Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +as of March 1, 2023. All other MediaPipe Legacy Solutions will be upgraded to +a new MediaPipe Solution. See the [Solutions guide](https://developers.google.com/mediapipe/solutions/guide#legacy) +for details. The [code repository](https://github.com/google/mediapipe/tree/master/mediapipe) +and prebuilt binaries for all MediaPipe Legacy Solutions will continue to be +provided on an as-is basis. 
-To start using MediaPipe -[solutions](https://google.github.io/mediapipe/solutions/solutions) with only a few -lines code, see example code and demos in -[MediaPipe in Python](https://google.github.io/mediapipe/getting_started/python) and -[MediaPipe in JavaScript](https://google.github.io/mediapipe/getting_started/javascript). +For more on the legacy solutions, see the [documentation](https://github.com/google/mediapipe/tree/master/docs/solutions). -To use MediaPipe in C++, Android and iOS, which allow further customization of -the [solutions](https://google.github.io/mediapipe/solutions/solutions) as well as -building your own, learn how to -[install](https://google.github.io/mediapipe/getting_started/install) MediaPipe and -start building example applications in -[C++](https://google.github.io/mediapipe/getting_started/cpp), -[Android](https://google.github.io/mediapipe/getting_started/android) and -[iOS](https://google.github.io/mediapipe/getting_started/ios). +## Framework -The source code is hosted in the -[MediaPipe Github repository](https://github.com/google/mediapipe), and you can -run code search using -[Google Open Source Code Search](https://cs.opensource.google/mediapipe/mediapipe). +To start using MediaPipe Framework, [install MediaPipe +Framework](https://developers.google.com/mediapipe/framework/getting_started/install) +and start building example applications in C++, Android, and iOS. -## Publications +[MediaPipe Framework](https://developers.google.com/mediapipe/framework) is the +low-level component used to build efficient on-device machine learning +pipelines, similar to the premade MediaPipe Solutions. + +Before using MediaPipe Framework, familiarize yourself with the following key +[Framework +concepts](https://developers.google.com/mediapipe/framework/framework_concepts/overview.md): + +* [Packets](https://developers.google.com/mediapipe/framework/framework_concepts/packets.md) +* [Graphs](https://developers.google.com/mediapipe/framework/framework_concepts/graphs.md) +* [Calculators](https://developers.google.com/mediapipe/framework/framework_concepts/calculators.md) + +## Community + +* [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe + users. +* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General + community discussion around MediaPipe. +* [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe) - A + curated list of awesome MediaPipe related frameworks, libraries and + software. + +## Contributing + +We welcome contributions. Please follow these +[guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md). + +We use GitHub issues for tracking requests and bugs. Please post questions to +the MediaPipe Stack Overflow with a `mediapipe` tag. 
+ +## Resources + +### Publications * [Bringing artworks to life with AR](https://developers.googleblog.com/2021/07/bringing-artworks-to-life-with-ar.html) in Google Developers Blog @@ -102,7 +124,8 @@ run code search using * [SignAll SDK: Sign language interface using MediaPipe is now available for developers](https://developers.googleblog.com/2021/04/signall-sdk-sign-language-interface-using-mediapipe-now-available.html) in Google Developers Blog -* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) +* [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on + Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) in Google AI Blog * [Background Features in Google Meet, Powered by Web ML](https://ai.googleblog.com/2020/10/background-features-in-google-meet.html) in Google AI Blog @@ -130,43 +153,6 @@ run code search using in Google AI Blog * [MediaPipe: A Framework for Building Perception Pipelines](https://arxiv.org/abs/1906.08172) -## Videos +### Videos * [YouTube Channel](https://www.youtube.com/c/MediaPipe) - -## Events - -* [MediaPipe Seattle Meetup, Google Building Waterside, 13 Feb 2020](https://mediapipe.page.link/seattle2020) -* [AI Nextcon 2020, 12-16 Feb 2020, Seattle](http://aisea20.xnextcon.com/) -* [MediaPipe Madrid Meetup, 16 Dec 2019](https://www.meetup.com/Madrid-AI-Developers-Group/events/266329088/) -* [MediaPipe London Meetup, Google 123 Building, 12 Dec 2019](https://www.meetup.com/London-AI-Tech-Talk/events/266329038) -* [ML Conference, Berlin, 11 Dec 2019](https://mlconference.ai/machine-learning-advanced-development/mediapipe-building-real-time-cross-platform-mobile-web-edge-desktop-video-audio-ml-pipelines/) -* [MediaPipe Berlin Meetup, Google Berlin, 11 Dec 2019](https://www.meetup.com/Berlin-AI-Tech-Talk/events/266328794/) -* [The 3rd Workshop on YouTube-8M Large Scale Video Understanding Workshop, - Seoul, Korea ICCV - 2019](https://research.google.com/youtube8m/workshop2019/index.html) -* [AI DevWorld 2019, 10 Oct 2019, San Jose, CA](https://aidevworld.com) -* [Google Industry Workshop at ICIP 2019, 24 Sept 2019, Taipei, Taiwan](http://2019.ieeeicip.org/?action=page4&id=14#Google) - ([presentation](https://docs.google.com/presentation/d/e/2PACX-1vRIBBbO_LO9v2YmvbHHEt1cwyqH6EjDxiILjuT0foXy1E7g6uyh4CesB2DkkEwlRDO9_lWfuKMZx98T/pub?start=false&loop=false&delayms=3000&slide=id.g556cc1a659_0_5)) -* [Open sourced at CVPR 2019, 17~20 June, Long Beach, CA](https://sites.google.com/corp/view/perception-cv4arvr/mediapipe) - -## Community - -* [Awesome MediaPipe](https://mediapipe.page.link/awesome-mediapipe) - A - curated list of awesome MediaPipe related frameworks, libraries and software -* [Slack community](https://mediapipe.page.link/joinslack) for MediaPipe users -* [Discuss](https://groups.google.com/forum/#!forum/mediapipe) - General - community discussion around MediaPipe - -## Alpha disclaimer - -MediaPipe is currently in alpha at v0.7. We may be still making breaking API -changes and expect to get to stable APIs by v1.0. - -## Contributing - -We welcome contributions. Please follow these -[guidelines](https://github.com/google/mediapipe/blob/master/CONTRIBUTING.md). - -We use GitHub issues for tracking requests and bugs. Please post questions to -the MediaPipe Stack Overflow with a `mediapipe` tag. 
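The Framework concepts listed above (packets, graphs, calculators) can also be exercised from Python. A minimal sketch, mirroring the C++ hello-world graph updated earlier in this diff and assuming the `mediapipe` PyPI package:

```python
# Sketch: drive a two-node pass-through graph from Python to illustrate
# packets, graphs, and calculators.
import mediapipe as mp

graph = mp.CalculatorGraph(graph_config="""
  input_stream: 'in'
  output_stream: 'out'
  node { calculator: 'PassThroughCalculator' input_stream: 'in' output_stream: 'mid' }
  node { calculator: 'PassThroughCalculator' input_stream: 'mid' output_stream: 'out' }
""")
received = []
graph.observe_output_stream(
    "out",
    lambda stream_name, packet: received.append(mp.packet_getter.get_str(packet)))
graph.start_run()
for t in range(3):
    # Every packet carries a payload plus a timestamp.
    graph.add_packet_to_input_stream(
        "in", mp.packet_creator.create_string("hello world").at(t))
graph.close()  # blocks until the graph is done
print(received)  # ['hello world', 'hello world', 'hello world']
```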
diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index f060d062c..93f239c37 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -20,9 +20,9 @@ nav_order: 1 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe +As of May 10, 2023, this solution was upgraded to a new MediaPipe Solution. For more information, see the -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/face_detector) site.* ---- diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index ab34ba401..a859bafaa 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -20,9 +20,9 @@ nav_order: 2 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe +As of May 10, 2023, this solution was upgraded to a new MediaPipe Solution. For more information, see the -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) site.* ---- diff --git a/docs/solutions/iris.md b/docs/solutions/iris.md index eab3dedf6..c0af4342c 100644 --- a/docs/solutions/iris.md +++ b/docs/solutions/iris.md @@ -20,9 +20,9 @@ nav_order: 3 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe +As of May 10, 2023, this solution was upgraded to a new MediaPipe Solution. For more information, see the -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/face_landmarker) site.* ---- diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index b6f9408ec..09c313b5e 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -22,9 +22,9 @@ nav_order: 5 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe +As of May 10, 2023, this solution was upgraded to a new MediaPipe Solution. For more information, see the -[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/) +[MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/pose_landmarker) site.* ---- diff --git a/docs/solutions/pose_classification.md b/docs/solutions/pose_classification.md index 8420e2d7c..091b0c998 100644 --- a/docs/solutions/pose_classification.md +++ b/docs/solutions/pose_classification.md @@ -21,7 +21,7 @@ nav_order: 1 --- **Attention:** *Thank you for your interest in MediaPipe Solutions. -As of March 1, 2023, this solution is planned to be upgraded to a new MediaPipe +As of May 10, 2023, this solution was upgraded to a new MediaPipe Solution. 
For more information, see the [MediaPipe Solutions](https://developers.google.com/mediapipe/solutions/vision/pose_landmarker/) site.* diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index 7bc32d169..10551b7c9 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -1,5 +1,6 @@ --- -layout: default +layout: forward +target: https://developers.google.com/mediapipe/solutions/guide#legacy title: MediaPipe Legacy Solutions nav_order: 3 has_children: true @@ -13,8 +14,7 @@ has_toc: false {:toc} --- -**Attention:** *Thank you for your interest in MediaPipe Solutions. We have -ended support for +**Attention:** *We have ended support for [these MediaPipe Legacy Solutions](https://developers.google.com/mediapipe/solutions/guide#legacy) as of March 1, 2023. All other [MediaPipe Legacy Solutions will be upgraded](https://developers.google.com/mediapipe/solutions/guide#legacy) @@ -25,14 +25,6 @@ be provided on an as-is basis. We encourage you to check out the new MediaPipe Solutions at: [https://developers.google.com/mediapipe/solutions](https://developers.google.com/mediapipe/solutions)* -*This notice and web page will be removed on June 1, 2023.* - ----- - -
- ---- MediaPipe offers open source cross-platform, customizable ML solutions for live diff --git a/mediapipe/BUILD b/mediapipe/BUILD index 3187c0cf7..432ed18f6 100644 --- a/mediapipe/BUILD +++ b/mediapipe/BUILD @@ -14,81 +14,155 @@ licenses(["notice"]) # Apache 2.0 -# Note: yes, these need to use "//external:android/crosstool", not -# @androidndk//:default_crosstool. +load("@mediapipe//mediapipe:platforms.bzl", "config_setting_and_platform") +# Generic Android config_setting( name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, + constraint_values = [ + "@platforms//os:android", + ], visibility = ["//visibility:public"], ) -config_setting( +# Android x86 32-bit. +config_setting_and_platform( name = "android_x86", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "x86", - }, + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:x86_32", + ], visibility = ["//visibility:public"], ) -config_setting( +# Android x86 64-bit. +config_setting_and_platform( name = "android_x86_64", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "x86_64", - }, + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:x86_64", + ], visibility = ["//visibility:public"], ) -config_setting( - name = "android_armeabi", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "armeabi", - }, - visibility = ["//visibility:public"], -) - -config_setting( +# Android ARMv7. +config_setting_and_platform( name = "android_arm", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "armeabi-v7a", - }, + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:armv7", + ], visibility = ["//visibility:public"], ) -config_setting( +# Android ARM64. +config_setting_and_platform( name = "android_arm64", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "arm64-v8a", - }, + constraint_values = [ + "@platforms//os:android", + "@platforms//cpu:arm64", + ], visibility = ["//visibility:public"], ) -# Note: this cannot just match "apple_platform_type": "macos" because that option -# defaults to "macos" even when building on Linux! -alias( +# Generic MacOS. +config_setting( name = "macos", - actual = select({ - ":macos_i386": ":macos_i386", - ":macos_x86_64": ":macos_x86_64", - ":macos_arm64": ":macos_arm64", - "//conditions:default": ":macos_i386", # Arbitrarily chosen from above. - }), + constraint_values = [ + "@platforms//os:macos", + ], visibility = ["//visibility:public"], ) -# Note: this also matches on crosstool_top so that it does not produce ambiguous -# selectors when used together with "android". +# MacOS x86 64-bit. +config_setting_and_platform( + name = "macos_x86_64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], + visibility = ["//visibility:public"], +) + +# MacOS ARM64. +config_setting_and_platform( + name = "macos_arm64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:arm64", + ], + visibility = ["//visibility:public"], +) + +# Generic iOS. config_setting( name = "ios", - values = { - "crosstool_top": "@bazel_tools//tools/cpp:toolchain", - "apple_platform_type": "ios", - }, + constraint_values = [ + "@platforms//os:ios", + ], + visibility = ["//visibility:public"], +) + +# iOS device ARM32. +config_setting_and_platform( + name = "ios_armv7", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm", + ], + visibility = ["//visibility:public"], +) + +# iOS device ARM64. 
+config_setting_and_platform( + name = "ios_arm64", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm64", + ], + visibility = ["//visibility:public"], +) + +# iOS device ARM64E. +config_setting_and_platform( + name = "ios_arm64e", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm64e", + ], + visibility = ["//visibility:public"], +) + +# iOS simulator x86 32-bit. +config_setting_and_platform( + name = "ios_i386", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:x86_32", + "@build_bazel_apple_support//constraints:simulator", + ], + visibility = ["//visibility:public"], +) + +# iOS simulator x86 64-bit. +config_setting_and_platform( + name = "ios_x86_64", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:x86_64", + "@build_bazel_apple_support//constraints:simulator", + ], + visibility = ["//visibility:public"], +) + +# iOS simulator ARM64. +config_setting_and_platform( + name = "ios_sim_arm64", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm64", + "@build_bazel_apple_support//constraints:simulator", + ], visibility = ["//visibility:public"], ) @@ -102,51 +176,24 @@ alias( visibility = ["//visibility:public"], ) -config_setting( - name = "macos_i386", - values = { - "apple_platform_type": "macos", - "cpu": "darwin", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "macos_x86_64", - values = { - "apple_platform_type": "macos", - "cpu": "darwin_x86_64", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "macos_arm64", - values = { - "apple_platform_type": "macos", - "cpu": "darwin_arm64", - }, - visibility = ["//visibility:public"], -) - -[ - config_setting( - name = arch, - values = {"cpu": arch}, - visibility = ["//visibility:public"], - ) - for arch in [ - "ios_i386", - "ios_x86_64", - "ios_armv7", - "ios_arm64", - "ios_arm64e", - ] -] - -config_setting( +# Windows 64-bit. +config_setting_and_platform( name = "windows", - values = {"cpu": "x64_windows"}, + constraint_values = [ + "@platforms//os:windows", + "@platforms//cpu:x86_64", + ], + visibility = ["//visibility:public"], +) + +# Linux 64-bit. +config_setting_and_platform( + name = "linux", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], + visibility = ["//visibility:public"], ) exports_files( diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index 4a8f0f598..c12583e5b 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Placeholder: load py_proto_library load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") licenses(["notice"]) @@ -145,6 +146,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/util:time_series_util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp/mfcc", "@eigen_archive//:eigen3", @@ -163,8 +165,9 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/util:time_series_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp:resampler", "@com_google_audio_tools//audio/dsp:resampler_q", @@ -185,6 +188,7 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:status", "//mediapipe/util:time_series_util", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) @@ -219,13 +223,12 @@ cc_library( deps = [ ":time_series_framer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:time_series_header_cc_proto", - "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", "//mediapipe/util:time_series_util", + "@com_google_absl//absl/log:absl_check", "@com_google_audio_tools//audio/dsp:window_functions", "@eigen_archive//:eigen3", ], @@ -296,6 +299,7 @@ cc_test( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:status", "//mediapipe/util:time_series_test_util", + "@com_google_absl//absl/log:absl_log", "@com_google_audio_tools//audio/dsp:number_util", "@eigen_archive//:eigen3", ], @@ -319,6 +323,21 @@ cc_test( ], ) +cc_binary( + name = "time_series_framer_calculator_benchmark", + srcs = ["time_series_framer_calculator_benchmark.cc"], + deps = [ + ":time_series_framer_calculator", + ":time_series_framer_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", + "@com_google_absl//absl/log:absl_check", + "@com_google_benchmark//:benchmark", + ], +) + cc_test( name = "time_series_framer_calculator_test", srcs = ["time_series_framer_calculator_test.cc"], @@ -333,6 +352,7 @@ cc_test( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:status", "//mediapipe/util:time_series_test_util", + "@com_google_absl//absl/log:absl_log", "@com_google_audio_tools//audio/dsp:window_functions", "@eigen_archive//:eigen3", ], diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators.cc b/mediapipe/calculators/audio/mfcc_mel_calculators.cc index a63b9d6ea..ec936c844 100644 --- a/mediapipe/calculators/audio/mfcc_mel_calculators.cc +++ b/mediapipe/calculators/audio/mfcc_mel_calculators.cc @@ -23,6 +23,7 @@ #include #include "Eigen/Core" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" @@ -138,7 +139,7 @@ absl::Status FramewiseTransformCalculatorBase::Process(CalculatorContext* cc) { TransformFrame(input_frame, &output_frame); // Copy output from vector to Eigen::Vector. 
- CHECK_EQ(output_frame.size(), num_output_channels_); + ABSL_CHECK_EQ(output_frame.size(), num_output_channels_); Eigen::Map<const Eigen::MatrixXd> output_frame_map(&output_frame[0], output_frame.size(), 1); output->col(frame) = output_frame_map.cast<float>(); diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc index 1a4210c30..e01bf5269 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc @@ -16,6 +16,8 @@ #include "mediapipe/calculators/audio/rational_factor_resample_calculator.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "audio/dsp/resampler_q.h" using audio_dsp::Resampler; @@ -45,9 +47,9 @@ void CopyVectorToChannel(const std::vector<float>& vec, Matrix* matrix, if (matrix->cols() == 0) { matrix->resize(matrix->rows(), vec.size()); } else { - CHECK_EQ(vec.size(), matrix->cols()); + ABSL_CHECK_EQ(vec.size(), matrix->cols()); } - CHECK_LT(channel, matrix->rows()); + ABSL_CHECK_LT(channel, matrix->rows()); matrix->row(channel) = Eigen::Map<const Eigen::RowVectorXf>(vec.data(), vec.size()); } @@ -77,7 +79,7 @@ absl::Status RationalFactorResampleCalculator::Open(CalculatorContext* cc) { r = ResamplerFromOptions(source_sample_rate_, target_sample_rate_, resample_options); if (!r) { - LOG(ERROR) << "Failed to initialize resampler."; + ABSL_LOG(ERROR) << "Failed to initialize resampler."; return absl::UnknownError("Failed to initialize resampler."); } } diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.h b/mediapipe/calculators/audio/rational_factor_resample_calculator.h index 325886dc7..2c9df30b4 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.h +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.h @@ -27,7 +27,6 @@ #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/util/time_series_util.h" namespace mediapipe { diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index 939e721ab..7f6528ec1 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator); // Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0). const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518; +namespace { +std::unique_ptr<audio_dsp::WindowFunction> MakeWindowFun( + const SpectrogramCalculatorOptions::WindowType window_type) { + switch (window_type) { + // The cosine window and square root of Hann are equivalent.
+ case SpectrogramCalculatorOptions::COSINE: + case SpectrogramCalculatorOptions::SQRT_HANN: + return std::make_unique<audio_dsp::CosineWindow>(); + case SpectrogramCalculatorOptions::HANN: + return std::make_unique<audio_dsp::HannWindow>(); + case SpectrogramCalculatorOptions::HAMMING: + return std::make_unique<audio_dsp::HammingWindow>(); + } + return nullptr; +} +} // namespace + absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { SpectrogramCalculatorOptions spectrogram_options = cc->Options<SpectrogramCalculatorOptions>(); @@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { output_scale_ = spectrogram_options.output_scale(); - std::vector<double> window; - switch (spectrogram_options.window_type()) { - case SpectrogramCalculatorOptions::COSINE: - audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - break; - case SpectrogramCalculatorOptions::HANN: - audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - break; - case SpectrogramCalculatorOptions::HAMMING: - audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - break; - case SpectrogramCalculatorOptions::SQRT_HANN: { - audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - absl::c_transform(window, window.begin(), - [](double x) { return std::sqrt(x); }); - break; - } + auto window_fun = MakeWindowFun(spectrogram_options.window_type()); + if (window_fun == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("Invalid window type ", + spectrogram_options.window_type())); } + std::vector<double> window; + window_fun->GetPeriodicSamples(frame_duration_samples_, &window); // Propagate settings down to the actual Spectrogram object. spectrogram_generators_.clear(); @@ -433,9 +436,9 @@ absl::Status SpectrogramCalculator::ProcessVectorToOutput( absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream, CalculatorContext* cc) { switch (output_type_) { - // These blocks deliberately ignore clang-format to preserve the - // "silhouette" of the different cases. - // clang-format off + // These blocks deliberately ignore clang-format to preserve the + // "silhouette" of the different cases. + // clang-format off case SpectrogramCalculatorOptions::COMPLEX: { return ProcessVectorToOutput( input_stream, diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto index ddfca1d1c..d8bca3f76 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.proto +++ b/mediapipe/calculators/audio/spectrogram_calculator.proto @@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions { HANN = 0; HAMMING = 1; COSINE = 2; - SQRT_HANN = 4; + SQRT_HANN = 4; // Alias of COSINE.
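Editor's note on why SQRT_HANN can be an alias of COSINE: a sketch of the standard periodic window definitions (not taken from the patch), with N the frame length in samples:

```latex
% Periodic Hann window of length N, and its square root:
\[
  w_{\mathrm{Hann}}[n]
    = \tfrac{1}{2}\left(1 - \cos\tfrac{2\pi n}{N}\right)
    = \sin^{2}\!\left(\tfrac{\pi n}{N}\right),
  \qquad
  \sqrt{w_{\mathrm{Hann}}[n]}
    = \sin\!\left(\tfrac{\pi n}{N}\right)
    = w_{\mathrm{Cosine}}[n].
\]
```

This is the identity the removed SQRT_HANN branch computed numerically with absl::c_transform, and it is why MakeWindowFun can return audio_dsp::CosineWindow for both enum values.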
} optional WindowType window_type = 6 [default = HANN]; diff --git a/mediapipe/calculators/audio/spectrogram_calculator_test.cc b/mediapipe/calculators/audio/spectrogram_calculator_test.cc index b35f30583..14cd74a3c 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator_test.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator_test.cc @@ -22,6 +22,7 @@ #include #include "Eigen/Core" +#include "absl/log/absl_log.h" #include "audio/dsp/number_util.h" #include "mediapipe/calculators/audio/spectrogram_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -882,11 +883,11 @@ void BM_ProcessDC(benchmark::State& state) { const CalculatorRunner::StreamContents& output = runner.Outputs().Index(0); const Matrix& output_matrix = output.packets[0].Get(); - LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x" - << output_matrix.cols(); - LOG(INFO) << "First values=" << output_matrix(0, 0) << ", " - << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", " - << output_matrix(3, 0); + ABSL_LOG(INFO) << "Output matrix=" << output_matrix.rows() << "x" + << output_matrix.cols(); + ABSL_LOG(INFO) << "First values=" << output_matrix(0, 0) << ", " + << output_matrix(1, 0) << ", " << output_matrix(2, 0) << ", " + << output_matrix(3, 0); } BENCHMARK(BM_ProcessDC); diff --git a/mediapipe/calculators/audio/stabilized_log_calculator.cc b/mediapipe/calculators/audio/stabilized_log_calculator.cc index 0c697a196..a7de6a37c 100644 --- a/mediapipe/calculators/audio/stabilized_log_calculator.cc +++ b/mediapipe/calculators/audio/stabilized_log_calculator.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" @@ -59,7 +60,7 @@ class StabilizedLogCalculator : public CalculatorBase { output_scale_ = stabilized_log_calculator_options.output_scale(); check_nonnegativity_ = stabilized_log_calculator_options.check_nonnegativity(); - CHECK_GE(stabilizer_, 0.0) + ABSL_CHECK_GE(stabilizer_, 0.0) << "stabilizer must be >= 0.0, received a value of " << stabilizer_; // If the input packets have a header, propagate the header to the output. diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc index a200b898a..d8cda5149 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator.cc @@ -15,19 +15,17 @@ // Defines TimeSeriesFramerCalculator. #include -#include -#include -#include +#include #include "Eigen/Core" +#include "absl/log/absl_check.h" #include "audio/dsp/window_functions.h" #include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" -#include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/timestamp.h" #include "mediapipe/util/time_series_util.h" namespace mediapipe { @@ -88,11 +86,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase { absl::Status Close(CalculatorContext* cc) override; private: - // Adds input data to the internal buffer. - void EnqueueInput(CalculatorContext* cc); - // Constructs and emits framed output packets. 
- void FrameOutput(CalculatorContext* cc); - Timestamp CurrentOutputTimestamp() { if (use_local_timestamp_) { return current_timestamp_; @@ -106,21 +99,13 @@ class TimeSeriesFramerCalculator : public CalculatorBase { Timestamp::kTimestampUnitsPerSecond); } - // Returns the timestamp of a sample on a base, which is usually the time - // stamp of a packet. - Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base, - int64_t number_of_samples) { - return timestamp_base + round(number_of_samples / sample_rate_ * - Timestamp::kTimestampUnitsPerSecond); - } - // The number of input samples to advance after the current output frame is // emitted. int next_frame_step_samples() const { // All numbers are in input samples. const int64_t current_output_frame_start = static_cast<int64_t>( round(cumulative_output_frames_ * average_frame_step_samples_)); - CHECK_EQ(current_output_frame_start, cumulative_completed_samples_); + ABSL_CHECK_EQ(current_output_frame_start, cumulative_completed_samples_); const int64_t next_output_frame_start = static_cast<int64_t>( round((cumulative_output_frames_ + 1) * average_frame_step_samples_)); return next_output_frame_start - current_output_frame_start; @@ -142,61 +127,174 @@ class TimeSeriesFramerCalculator : public CalculatorBase { Timestamp initial_input_timestamp_; // The current timestamp is updated along with the incoming packets. Timestamp current_timestamp_; - int num_channels_; - // Each entry in this deque consists of a single sample, i.e. a - // single column vector, and its timestamp. - std::deque<std::pair<Matrix, Timestamp>> sample_buffer_; + // Samples are buffered in a vector of sample blocks. + class SampleBlockBuffer { + public: + // Initializes the buffer. + void Init(double sample_rate, int num_channels) { + ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate; + num_channels_ = num_channels; + num_samples_ = 0; + first_block_offset_ = 0; + } + + // Number of channels, equal to the number of rows in each Matrix. + int num_channels() const { return num_channels_; } + // Total number of available samples over all blocks. + int num_samples() const { return num_samples_; } + + // Pushes a new block of samples on the back of the buffer with `timestamp` + // being the input timestamp of the packet containing the Matrix. + void Push(const Matrix& samples, Timestamp timestamp); + // Copies `count` samples from the front of the buffer. If there are fewer + // samples than this, the result is zero padded to have `count` samples. + // The timestamp of the last copied sample is written to *last_timestamp. + // This output is used below to update `current_timestamp_`, which is only + // used when `use_local_timestamp` is true. + Matrix CopySamples(int count, Timestamp* last_timestamp) const; + // Drops `count` samples from the front of the buffer. If `count` exceeds + // `num_samples()`, the buffer is emptied. Returns how many samples were + // dropped. + int DropSamples(int count); + + private: + struct Block { + // Matrix of num_channels rows by num_samples columns, a block of possibly + // multiple samples. + Matrix samples; + // Timestamp of the first sample in the Block. This comes from the input + // packet's timestamp that contains this Matrix. + Timestamp timestamp; + + Block() : timestamp(Timestamp::Unstarted()) {} + Block(const Matrix& samples, Timestamp timestamp) + : samples(samples), timestamp(timestamp) {} + int num_samples() const { return samples.cols(); } + }; + std::vector<Block> blocks_; + // Number of timestamp units per sample.
Used to compute timestamps as + // nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n). + double ts_units_per_sample_; + // Number of rows in each Matrix. + int num_channels_; + // The total number of samples over all blocks, equal to + // (sum_i blocks_[i].num_samples()) - first_block_offset_. + int num_samples_; + // The number of samples in the first block that have been discarded. This + // way we can cheaply represent "partially discarding" a block. + int first_block_offset_; + } sample_buffer_; bool use_window_; - Matrix window_; + Eigen::RowVectorXf window_; bool use_local_timestamp_; }; REGISTER_CALCULATOR(TimeSeriesFramerCalculator); -void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) { - const Matrix& input_frame = cc->Inputs().Index(0).Get(); - - for (int i = 0; i < input_frame.cols(); ++i) { - sample_buffer_.emplace_back(std::make_pair( - input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i))); - } +void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples, + Timestamp timestamp) { + num_samples_ += samples.cols(); + blocks_.emplace_back(samples, timestamp); } -void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { - while (sample_buffer_.size() >= +Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples( + int count, Timestamp* last_timestamp) const { + Matrix copied(num_channels_, count); + + if (!blocks_.empty()) { + int num_copied = 0; + // First block has an offset for samples that have been discarded. + int offset = first_block_offset_; + int n; + Timestamp last_block_ts; + int last_sample_index; + + for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) { + n = std::min(it->num_samples() - offset, count); + // Copy `n` samples from the next block. + copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n); + count -= n; + num_copied += n; + last_block_ts = it->timestamp; + last_sample_index = offset + n - 1; + offset = 0; // No samples have been discarded in subsequent blocks. + } + + // Compute the timestamp of the last copied sample. + *last_timestamp = + last_block_ts + std::round(ts_units_per_sample_ * last_sample_index); + } + + if (count > 0) { + copied.rightCols(count).setZero(); // Zero pad if needed. + } + + return copied; +} + +int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) { + if (blocks_.empty()) { + return 0; + } + + auto block_it = blocks_.begin(); + if (first_block_offset_ + count < block_it->num_samples()) { + // `count` is less than the remaining samples in the first block. + first_block_offset_ += count; + num_samples_ -= count; + return count; + } + + int num_samples_dropped = block_it->num_samples() - first_block_offset_; + count -= num_samples_dropped; + first_block_offset_ = 0; + + for (++block_it; block_it != blocks_.end(); ++block_it) { + if (block_it->num_samples() > count) { + break; + } + num_samples_dropped += block_it->num_samples(); + count -= block_it->num_samples(); + } + + blocks_.erase(blocks_.begin(), block_it); // Drop whole blocks. + if (!blocks_.empty()) { + first_block_offset_ = count; // Drop part of the next block. 
+ num_samples_dropped += count; + } + + num_samples_ -= num_samples_dropped; + return num_samples_dropped; +} + +absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { + if (initial_input_timestamp_ == Timestamp::Unstarted()) { + initial_input_timestamp_ = cc->InputTimestamp(); + current_timestamp_ = initial_input_timestamp_; + } + + // Add input data to the internal buffer. + sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(), + cc->InputTimestamp()); + + // Construct and emit framed output packets. + while (sample_buffer_.num_samples() >= frame_duration_samples_ + samples_still_to_drop_) { - while (samples_still_to_drop_ > 0) { - sample_buffer_.pop_front(); - --samples_still_to_drop_; - } + sample_buffer_.DropSamples(samples_still_to_drop_); + Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_, &current_timestamp_); const int frame_step_samples = next_frame_step_samples(); - std::unique_ptr<Matrix> output_frame( - new Matrix(num_channels_, frame_duration_samples_)); - for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_); - ++i) { - output_frame->col(i) = sample_buffer_.front().first; - current_timestamp_ = sample_buffer_.front().second; - sample_buffer_.pop_front(); - } - const int frame_overlap_samples = - frame_duration_samples_ - frame_step_samples; - if (frame_overlap_samples > 0) { - for (int i = 0; i < frame_overlap_samples; ++i) { - output_frame->col(i + frame_step_samples) = sample_buffer_[i].first; - current_timestamp_ = sample_buffer_[i].second; - } - } else { - samples_still_to_drop_ = -frame_overlap_samples; - } + samples_still_to_drop_ = frame_step_samples; if (use_window_) { - *output_frame = (output_frame->array() * window_.array()).matrix(); + // Apply the window to each row of output_frame. + output_frame.array().rowwise() *= window_.array(); } - cc->Outputs().Index(0).Add(output_frame.release(), - CurrentOutputTimestamp()); + cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame)) .At(CurrentOutputTimestamp())); ++cumulative_output_frames_; cumulative_completed_samples_ += frame_step_samples; } @@ -206,35 +304,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { // fact to enable packet queueing optimizations.
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp()); } -} - -absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { - if (initial_input_timestamp_ == Timestamp::Unstarted()) { - initial_input_timestamp_ = cc->InputTimestamp(); - current_timestamp_ = initial_input_timestamp_; - } - - EnqueueInput(cc); - FrameOutput(cc); return absl::OkStatus(); } absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { - while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) { - sample_buffer_.pop_front(); - --samples_still_to_drop_; - } - if (!sample_buffer_.empty() && pad_final_packet_) { - std::unique_ptr<Matrix> output_frame(new Matrix); - output_frame->setZero(num_channels_, frame_duration_samples_); - for (int i = 0; i < sample_buffer_.size(); ++i) { - output_frame->col(i) = sample_buffer_[i].first; - current_timestamp_ = sample_buffer_[i].second; - } + sample_buffer_.DropSamples(samples_still_to_drop_); - cc->Outputs().Index(0).Add(output_frame.release(), - CurrentOutputTimestamp()); + if (sample_buffer_.num_samples() > 0 && pad_final_packet_) { + Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_, &current_timestamp_); + cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame)) .At(CurrentOutputTimestamp())); } return absl::OkStatus(); @@ -258,7 +339,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { cc->Inputs().Index(0).Header(), &input_header)); sample_rate_ = input_header.sample_rate(); - num_channels_ = input_header.num_channels(); + sample_buffer_.Init(sample_rate_, input_header.num_channels()); frame_duration_samples_ = time_series_util::SecondsToSamples( framer_options.frame_duration_seconds(), sample_rate_); RET_CHECK_GT(frame_duration_samples_, 0) @@ -312,9 +393,8 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { } if (use_window_) { - window_ = Matrix::Ones(num_channels_, 1) * - Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1, - frame_duration_samples_) + window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(), frame_duration_samples_) .cast<float>(); } use_local_timestamp_ = framer_options.use_local_timestamp(); diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_benchmark.cc b/mediapipe/calculators/audio/time_series_framer_calculator_benchmark.cc new file mode 100644 index 000000000..6eada1ad3 --- /dev/null +++ b/mediapipe/calculators/audio/time_series_framer_calculator_benchmark.cc @@ -0,0 +1,93 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Benchmark for TimeSeriesFramerCalculator.
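Editor's aside before the benchmark body: a compact, hypothetical sketch of the block-buffer bookkeeping that the new SampleBlockBuffer implements above (simplified names, mono float samples instead of Eigen matrices; not code from the patch). Blocks stay whole, a front offset marks already-consumed samples, copies zero-pad when short, and drops erase only whole blocks:

```cpp
#include <cstdio>
#include <vector>

// Hypothetical simplified model of SampleBlockBuffer: one float per sample.
struct BlockBuffer {
  std::vector<std::vector<float>> blocks;
  int first_offset = 0;  // Samples of the front block already consumed.

  int num_samples() const {
    int n = -first_offset;
    for (const auto& b : blocks) n += b.size();
    return n;
  }

  // Copy `count` samples from the front, zero-padding if there are fewer.
  std::vector<float> Copy(int count) const {
    std::vector<float> out;
    out.reserve(count);
    int offset = first_offset;
    for (const auto& b : blocks) {
      for (int i = offset; i < (int)b.size() && (int)out.size() < count; ++i)
        out.push_back(b[i]);
      offset = 0;  // Only the front block has consumed samples.
    }
    out.resize(count, 0.0f);  // Zero pad.
    return out;
  }

  // Drop `count` samples from the front without copying: erase whole
  // blocks, then advance `first_offset` partway into the next one.
  void Drop(int count) {
    while (!blocks.empty() &&
           first_offset + count >= (int)blocks.front().size()) {
      count -= blocks.front().size() - first_offset;
      first_offset = 0;
      blocks.erase(blocks.begin());
    }
    if (!blocks.empty()) first_offset += count;
  }
};

int main() {
  BlockBuffer buf;
  buf.blocks = {{1, 2, 3}, {4, 5}};
  buf.Drop(2);  // Partially consumes the first block.
  std::vector<float> frame = buf.Copy(4);  // {3, 4, 5, 0} after zero pad.
  std::printf("%g %g %g %g (left: %d)\n", frame[0], frame[1], frame[2],
              frame[3], buf.num_samples());
}
```

The apparent motivation, judging from the removed code, is that the old per-sample std::deque paid one deque entry per sample, whereas this scheme costs one Matrix copy per input packet and makes dropping a frame step O(blocks) rather than O(samples).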
+#include <memory> +#include <random> +#include <vector> + +#include "absl/log/absl_check.h" +#include "benchmark/benchmark.h" +#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/packet.h" + +using ::mediapipe::Matrix; + +void BM_TimeSeriesFramerCalculator(benchmark::State& state) { + constexpr float kSampleRate = 32000.0; + constexpr int kNumChannels = 2; + constexpr int kFrameDurationSeconds = 5.0; + std::mt19937 rng(0 /*seed*/); + // Input around a half second's worth of samples at a time. + std::uniform_int_distribution<int> input_size_dist(15000, 17000); + // Generate a pool of random blocks of samples up front. + std::vector<Matrix> sample_pool; + sample_pool.reserve(20); + for (int i = 0; i < 20; ++i) { + sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng))); + } + std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1); + + mediapipe::CalculatorGraphConfig config; + config.add_input_stream("input"); + config.add_output_stream("output"); + auto* node = config.add_node(); + node->set_calculator("TimeSeriesFramerCalculator"); + node->add_input_stream("input"); + node->add_output_stream("output"); + mediapipe::TimeSeriesFramerCalculatorOptions* options = + node->mutable_options()->MutableExtension( + mediapipe::TimeSeriesFramerCalculatorOptions::ext); + options->set_frame_duration_seconds(kFrameDurationSeconds); + + for (auto _ : state) { + state.PauseTiming(); // Pause benchmark timing. + + // Prepare input packets of random blocks of samples. + std::vector<mediapipe::Packet> input_packets; + input_packets.reserve(32); + float t = 0; + for (int i = 0; i < 32; ++i) { + auto samples = + std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]); + const int num_samples = samples->cols(); + input_packets.push_back(mediapipe::Adopt(samples.release()) + .At(mediapipe::Timestamp::FromSeconds(t))); + t += num_samples / kSampleRate; + } + // Initialize graph. + mediapipe::CalculatorGraph graph; + ABSL_CHECK_OK(graph.Initialize(config)); + // Prepare input header. + auto header = std::make_unique<mediapipe::TimeSeriesHeader>(); + header->set_sample_rate(kSampleRate); + header->set_num_channels(kNumChannels); + + state.ResumeTiming(); // Resume benchmark timing.
+ + ABSL_CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}})); + for (auto& packet : input_packets) { + ABSL_CHECK_OK(graph.AddPacketToInputStream("input", packet)); + } + ABSL_CHECK(!graph.HasError()); + ABSL_CHECK_OK(graph.CloseAllInputStreams()); + ABSL_CHECK_OK(graph.WaitUntilIdle()); + } +} +BENCHMARK(BM_TimeSeriesFramerCalculator); + +BENCHMARK_MAIN(); diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc index 72e9c88f7..fe42ecb12 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc @@ -19,6 +19,7 @@ #include #include "Eigen/Core" +#include "absl/log/absl_log.h" #include "audio/dsp/window_functions.h" #include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -186,11 +187,12 @@ class TimeSeriesFramerCalculatorTest const int num_unique_output_samples = round((output().packets.size() - 1) * frame_step_samples) + frame_duration_samples; - LOG(INFO) << "packets.size()=" << output().packets.size() - << " frame_duration_samples=" << frame_duration_samples - << " frame_step_samples=" << frame_step_samples - << " num_input_samples_=" << num_input_samples_ - << " num_unique_output_samples=" << num_unique_output_samples; + ABSL_LOG(INFO) << "packets.size()=" << output().packets.size() + << " frame_duration_samples=" << frame_duration_samples + << " frame_step_samples=" << frame_step_samples + << " num_input_samples_=" << num_input_samples_ + << " num_unique_output_samples=" + << num_unique_output_samples; const int num_padding_samples = num_unique_output_samples - num_input_samples_; if (options_.pad_final_packet()) { diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index d5ba6c74f..02efc84ea 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -117,6 +117,7 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:classification_proto", "//mediapipe/framework/formats:landmark_proto", + "//mediapipe/framework/formats:matrix_data_proto", "//mediapipe/framework/formats:time_series_header_proto", ], ) @@ -192,17 +193,19 @@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:collection_item_id", "//mediapipe/framework:packet", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/gpu:gpu_buffer", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", ], alwayslink = 1, ) @@ -215,18 +218,18 @@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:collection_item_id", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", 
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", + "//mediapipe/gpu:gpu_buffer", "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/status", "@org_tensorflow//tensorflow/lite:framework", ], alwayslink = 1, @@ -287,6 +290,7 @@ cc_library( "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:integral_types", @@ -295,8 +299,7 @@ cc_library( "//mediapipe/util:render_data_cc_proto", "@org_tensorflow//tensorflow/lite:framework", ] + select({ - "//mediapipe/gpu:disable_gpu": [], - "//mediapipe:ios": [], + ":ios_or_disable_gpu": [], "//conditions:default": [ "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", ], @@ -378,17 +381,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "clip_detection_vector_size_calculator", - srcs = ["clip_detection_vector_size_calculator.cc"], - deps = [ - ":clip_vector_size_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", - ], - alwayslink = 1, -) - cc_test( name = "clip_vector_size_calculator_test", srcs = ["clip_vector_size_calculator_test.cc"], @@ -590,6 +582,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) @@ -605,6 +598,7 @@ cc_test( "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -637,6 +631,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -784,10 +779,11 @@ cc_library( "//mediapipe/framework/deps:random", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -843,6 +839,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:validate_type", + "@com_google_absl//absl/log:absl_check", "@eigen_archive//:eigen3", ], ) @@ -904,6 +901,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:rect_cc_proto", @@ -1029,6 +1027,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -1067,6 +1066,7 @@ 
cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/log:absl_log", ], ) @@ -1113,6 +1113,7 @@ cc_library( "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -1136,6 +1137,7 @@ cc_library( deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:status", ], alwayslink = 1, @@ -1164,6 +1166,7 @@ cc_library( "//mediapipe/framework:collection_item_id", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", @@ -1238,6 +1241,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ], diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc index 441c66937..d030bbbde 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.cc +++ b/mediapipe/calculators/core/begin_loop_calculator.cc @@ -17,10 +17,13 @@ #include #include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/gpu/gpu_buffer.h" namespace mediapipe { @@ -60,4 +63,22 @@ REGISTER_CALCULATOR(BeginLoopUint64tCalculator); typedef BeginLoopCalculator<std::vector<Tensor>> BeginLoopTensorCalculator; REGISTER_CALCULATOR(BeginLoopTensorCalculator); +// A calculator to process std::vector<ImageFrame>. +typedef BeginLoopCalculator<std::vector<ImageFrame>> + BeginLoopImageFrameCalculator; +REGISTER_CALCULATOR(BeginLoopImageFrameCalculator); + +// A calculator to process std::vector<GpuBuffer>. +typedef BeginLoopCalculator<std::vector<GpuBuffer>> + BeginLoopGpuBufferCalculator; +REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator); + +// A calculator to process std::vector<Image>. +typedef BeginLoopCalculator<std::vector<Image>> BeginLoopImageCalculator; +REGISTER_CALCULATOR(BeginLoopImageCalculator); + +// A calculator to process std::vector<float>.
+typedef BeginLoopCalculator<std::vector<float>> BeginLoopFloatCalculator; +REGISTER_CALCULATOR(BeginLoopFloatCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/begin_loop_calculator.h b/mediapipe/calculators/core/begin_loop_calculator.h index 81fff39da..c0b3022d4 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.h +++ b/mediapipe/calculators/core/begin_loop_calculator.h @@ -15,47 +15,57 @@ #ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_ +#include "absl/status/status.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/packet.h" -#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status.h" -#include "mediapipe/framework/port/status_macros.h" namespace mediapipe { // Calculator for implementing loops on iterable collections inside a MediaPipe -// graph. +// graph. Assume InputIterT is an iterable for type InputT, and OutputIterT is +// an iterable for type OutputT, e.g. vector<InputT> and vector<OutputT>. +// First, instantiate specializations in the loop calculators' implementations +// if missing: +// BeginLoopInputTCalculator = BeginLoopCalculator<InputIterT> +// EndLoopOutputTCalculator = EndLoopCalculator<OutputIterT> +// Then, the following graph transforms an item of type InputIterT to an +// OutputIterT by applying InputToOutputConverter to every element: // -// It is designed to be used like: -// -// node { -// calculator: "BeginLoopWithIterableCalculator" -// input_stream: "ITERABLE:input_iterable" # IterableT @ext_ts -// output_stream: "ITEM:input_element" # ItemT @loop_internal_ts -// output_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts +// node { # Type @timestamp +// calculator: "BeginLoopInputTCalculator" +// input_stream: "ITERABLE:input_iterable" # InputIterT @iterable_ts +// input_stream: "CLONE:extra_input" # ExtraT @extra_ts +// output_stream: "ITEM:input_iterator" # InputT @loop_internal_ts +// output_stream: "CLONE:cloned_extra_input" # ExtraT @loop_internal_ts +// output_stream: "BATCH_END:iterable_ts" # Timestamp @loop_internal_ts // } // // node { -// calculator: "ElementToBlaConverterSubgraph" -// input_stream: "ITEM:input_to_loop_body" # ItemT @loop_internal_ts -// output_stream: "BLA:output_of_loop_body" # ItemU @loop_internal_ts +// calculator: "InputToOutputConverter" +// input_stream: "INPUT:input_iterator" # InputT @loop_internal_ts +// input_stream: "EXTRA:cloned_extra_input" # ExtraT @loop_internal_ts +// output_stream: "OUTPUT:output_iterator" # OutputT @loop_internal_ts // } // // node { -// calculator: "EndLoopWithOutputCalculator" -// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts -// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts -// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts +// calculator: "EndLoopOutputTCalculator" +// input_stream: "ITEM:output_iterator" # OutputT @loop_internal_ts +// input_stream: "BATCH_END:iterable_ts" # Timestamp @loop_internal_ts +// output_stream: "ITERABLE:output_iterable" # OutputIterT @iterable_ts +// } // +// The resulting 'output_iterable' has the same timestamp as 'input_iterable'.
+// The output packets of this calculator are part of the loop body and have +// loop-internal timestamps that are unrelated to the input iterator timestamp. +// // Input streams tagged with "CLONE" are cloned to the corresponding output -// streams at loop timestamps. This ensures that a MediaPipe graph or sub-graph -// can run multiple times, once per element in the "ITERABLE" for each pakcet -// clone of the packets in the "CLONE" input streams. +// streams at loop-internal timestamps. This ensures that a MediaPipe graph or +// sub-graph can run multiple times, once per element in the "ITERABLE" for each +// packet clone of the packets in the "CLONE" input streams. Think of CLONEd +// inputs as loop-wide constants. template class BeginLoopCalculator : public CalculatorBase { using ItemT = typename IterableT::value_type; diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc index 0079aa98d..53b3debf1 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -17,6 +17,7 @@ #include #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/integral_types.h" @@ -55,6 +56,10 @@ MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator); typedef ConcatenateVectorCalculator ConcatenateBoolVectorCalculator; MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator); +typedef ConcatenateVectorCalculator + ConcatenateStringVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateStringVectorCalculator); + // Example config: // node { // calculator: "ConcatenateTfLiteTensorVectorCalculator" @@ -100,4 +105,7 @@ typedef ConcatenateVectorCalculator ConcatenateRenderDataVectorCalculator; MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator); +typedef ConcatenateVectorCalculator + ConcatenateImageVectorCalculator; +MEDIAPIPE_REGISTER_NODE(ConcatenateImageVectorCalculator); } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc index 5510b98a3..3fccf58fd 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc @@ -30,13 +30,15 @@ namespace mediapipe { typedef ConcatenateVectorCalculator TestConcatenateIntVectorCalculator; MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator); -void AddInputVector(int index, const std::vector& input, int64_t timestamp, +template +void AddInputVector(int index, const std::vector& input, int64_t timestamp, CalculatorRunner* runner) { runner->MutableInputs()->Index(index).packets.push_back( - MakePacket>(input).At(Timestamp(timestamp))); + MakePacket>(input).At(Timestamp(timestamp))); } -void AddInputVectors(const std::vector>& inputs, +template +void AddInputVectors(const std::vector>& inputs, int64_t timestamp, CalculatorRunner* runner) { for (int i = 0; i < inputs.size(); ++i) { AddInputVector(i, inputs[i], timestamp, runner); @@ -382,6 +384,23 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) { EXPECT_EQ(0, outputs.size()); } +TEST(ConcatenateStringVectorCalculatorTest, OneTimestamp) { + CalculatorRunner runner("ConcatenateStringVectorCalculator", + /*options_string=*/"", /*num_inputs=*/3, + 
                     /*num_outputs=*/1, /*num_side_packets=*/0);
+
+  std::vector<std::vector<std::string>> inputs = {
+      {"a", "b"}, {"c"}, {"d", "e", "f"}};
+  AddInputVectors(inputs, /*timestamp=*/1, &runner);
+  MP_ASSERT_OK(runner.Run());
+
+  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
+  std::vector<std::string> expected_vector = {"a", "b", "c", "d", "e", "f"};
+  EXPECT_EQ(expected_vector, outputs[0].Get<std::vector<std::string>>());
+}
+
 typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
     TestConcatenateUniqueIntPtrCalculator;
 MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator);
diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.cc b/mediapipe/calculators/core/constant_side_packet_calculator.cc
index 509f7e9dd..8762c9874 100644
--- a/mediapipe/calculators/core/constant_side_packet_calculator.cc
+++ b/mediapipe/calculators/core/constant_side_packet_calculator.cc
@@ -19,6 +19,7 @@
 #include "mediapipe/framework/collection_item_id.h"
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/matrix_data.pb.h"
 #include "mediapipe/framework/formats/time_series_header.pb.h"
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/integral_types.h"
@@ -85,8 +86,12 @@ class ConstantSidePacketCalculator : public CalculatorBase {
       packet.Set<LandmarkList>();
     } else if (packet_options.has_double_value()) {
       packet.Set<double>();
+    } else if (packet_options.has_matrix_data_value()) {
+      packet.Set<MatrixData>();
     } else if (packet_options.has_time_series_header_value()) {
       packet.Set<TimeSeriesHeader>();
+    } else if (packet_options.has_int64_value()) {
+      packet.Set<int64_t>();
     } else {
       return absl::InvalidArgumentError(
           "None of supported values were specified in options.");
@@ -121,9 +126,13 @@ class ConstantSidePacketCalculator : public CalculatorBase {
           MakePacket<LandmarkList>(packet_options.landmark_list_value()));
     } else if (packet_options.has_double_value()) {
       packet.Set(MakePacket<double>(packet_options.double_value()));
+    } else if (packet_options.has_matrix_data_value()) {
+      packet.Set(MakePacket<MatrixData>(packet_options.matrix_data_value()));
     } else if (packet_options.has_time_series_header_value()) {
       packet.Set(MakePacket<TimeSeriesHeader>(
           packet_options.time_series_header_value()));
+    } else if (packet_options.has_int64_value()) {
+      packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
     } else {
       return absl::InvalidArgumentError(
           "None of supported values were specified in options.");
diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.proto b/mediapipe/calculators/core/constant_side_packet_calculator.proto
index 78a773a6c..0d53175fc 100644
--- a/mediapipe/calculators/core/constant_side_packet_calculator.proto
+++ b/mediapipe/calculators/core/constant_side_packet_calculator.proto
@@ -19,6 +19,7 @@ package mediapipe;
 import "mediapipe/framework/calculator.proto";
 import "mediapipe/framework/formats/classification.proto";
 import "mediapipe/framework/formats/landmark.proto";
+import "mediapipe/framework/formats/matrix_data.proto";
 import "mediapipe/framework/formats/time_series_header.proto";
 
 message ConstantSidePacketCalculatorOptions {
@@ -29,14 +30,16 @@ message ConstantSidePacketCalculatorOptions {
   message ConstantSidePacket {
     oneof value {
       int32 int_value = 1;
+      uint64 uint64_value = 5;
+      int64 int64_value = 11;
       float float_value = 2;
+      double double_value = 9;
       bool bool_value = 3;
       string string_value = 4;
-      uint64 uint64_value = 5;
       ClassificationList classification_list_value = 6;
       LandmarkList landmark_list_value = 7;
-      double double_value = 9;
       TimeSeriesHeader time_series_header_value = 10;
+      MatrixData matrix_data_value = 12;
     }
   }
 }
diff --git a/mediapipe/calculators/core/constant_side_packet_calculator_test.cc b/mediapipe/calculators/core/constant_side_packet_calculator_test.cc
index a7ff808f4..6e8c0ec33 100644
--- a/mediapipe/calculators/core/constant_side_packet_calculator_test.cc
+++ b/mediapipe/calculators/core/constant_side_packet_calculator_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cstdint>
 #include <string>
 
 #include "absl/strings/string_view.h"
@@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
   DoTestSingleSidePacket<float>("{ float_value: 6.5f }", 6.5f);
   DoTestSingleSidePacket<bool>("{ bool_value: true }", true);
   DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
+  DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
 }
 
 TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
diff --git a/mediapipe/calculators/core/end_loop_calculator.cc b/mediapipe/calculators/core/end_loop_calculator.cc
index b3b889ecd..94f7ee22e 100644
--- a/mediapipe/calculators/core/end_loop_calculator.cc
+++ b/mediapipe/calculators/core/end_loop_calculator.cc
@@ -14,15 +14,19 @@
 
 #include "mediapipe/calculators/core/end_loop_calculator.h"
 
+#include <array>
+#include <utility>
 #include <vector>
 
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/gpu/gpu_buffer.h"
 #include "mediapipe/util/render_data.pb.h"
 #include "tensorflow/lite/interpreter.h"
 
@@ -68,8 +72,22 @@
 typedef EndLoopCalculator<std::vector<Tensor>> EndLoopTensorCalculator;
 REGISTER_CALCULATOR(EndLoopTensorCalculator);
 
+typedef EndLoopCalculator<std::vector<ImageFrame>> EndLoopImageFrameCalculator;
+REGISTER_CALCULATOR(EndLoopImageFrameCalculator);
+
+typedef EndLoopCalculator<std::vector<GpuBuffer>> EndLoopGpuBufferCalculator;
+REGISTER_CALCULATOR(EndLoopGpuBufferCalculator);
+
 typedef EndLoopCalculator<std::vector<Image>> EndLoopImageCalculator;
 REGISTER_CALCULATOR(EndLoopImageCalculator);
 
+typedef EndLoopCalculator<std::vector<std::array<float, 16>>>
+    EndLoopAffineMatrixCalculator;
+REGISTER_CALCULATOR(EndLoopAffineMatrixCalculator);
+
+typedef EndLoopCalculator<std::vector<std::pair<int, int>>>
+    EndLoopImageSizeCalculator;
+REGISTER_CALCULATOR(EndLoopImageSizeCalculator);
+
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/core/end_loop_calculator.h b/mediapipe/calculators/core/end_loop_calculator.h
index 2598194e6..1e258f046 100644
--- a/mediapipe/calculators/core/end_loop_calculator.h
+++ b/mediapipe/calculators/core/end_loop_calculator.h
@@ -17,13 +17,11 @@
 
 #include <type_traits>
 
+#include "absl/status/status.h"
 #include "mediapipe/framework/calculator_context.h"
 #include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
-#include "mediapipe/framework/collection_item_id.h"
-#include "mediapipe/framework/port/integral_types.h"
 #include "mediapipe/framework/port/ret_check.h"
-#include "mediapipe/framework/port/status.h"
 
 namespace mediapipe {
 
@@ -33,27 +31,7 @@ namespace mediapipe {
 // from the "BATCH_END" tagged input stream, it emits the aggregated results
 // at the original timestamp contained in the "BATCH_END" input stream.
 //
-// It is designed to be used like:
-//
-// node {
-//   calculator: "BeginLoopWithIterableCalculator"
-//   input_stream: "ITERABLE:input_iterable"  # IterableT @ext_ts
-//   output_stream: "ITEM:input_element"      # ItemT @loop_internal_ts
-//   output_stream: "BATCH_END:ext_ts"        # Timestamp @loop_internal_ts
-// }
-//
-// node {
-//   calculator: "ElementToBlaConverterSubgraph"
-//   input_stream: "ITEM:input_to_loop_body"   # ItemT @loop_internal_ts
-//   output_stream: "BLA:output_of_loop_body"  # ItemU @loop_internal_ts
-// }
-//
-// node {
-//   calculator: "EndLoopWithOutputCalculator"
-//   input_stream: "ITEM:output_of_loop_body"  # ItemU @loop_internal_ts
-//   input_stream: "BATCH_END:ext_ts"          # Timestamp @loop_internal_ts
-//   output_stream: "ITERABLE:aggregated_result"  # IterableU @ext_ts
-// }
+// See BeginLoopCalculator for a usage example.
 template <typename IterableT>
 class EndLoopCalculator : public CalculatorBase {
   using ItemT = typename IterableT::value_type;
@@ -79,7 +57,7 @@ class EndLoopCalculator : public CalculatorBase {
       }
       // Try to consume the item and move it into the collection. If the items
       // are not consumable, then try to copy them instead. If the items are
-      // not copiable, then an error will be returned.
+      // not copyable, then an error will be returned.
       auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
       if (item_ptr_or.ok()) {
         input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc
index 5b08f3af5..46e5bf6a3 100644
--- a/mediapipe/calculators/core/flow_limiter_calculator.cc
+++ b/mediapipe/calculators/core/flow_limiter_calculator.cc
@@ -42,7 +42,7 @@ constexpr char kOptionsTag[] = "OPTIONS";
 //
 // Increasing `max_in_flight` to 2 or more can yield the better throughput
 // when the graph exhibits a high degree of pipeline parallelism.  Decreasing
-// `max_in_flight` to 0 can yield a better average latency, but at the cost of
+// `max_in_queue` to 0 can yield a better average latency, but at the cost of
 // lower throughput (lower framerate) due to the time during which the graph
 // is idle awaiting the next input frame.
 //
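The `max_in_flight` / `max_in_queue` knobs discussed in the comment above are fields of FlowLimiterCalculatorOptions, and the calculator is typically deployed with the downstream result fed back on a FINISHED back edge. A minimal sketch of such a node, with illustrative stream names that are not part of this patch:

node {
  calculator: "FlowLimiterCalculator"
  input_stream: "input_video"
  input_stream: "FINISHED:output_video"
  input_stream_info: { tag_index: "FINISHED" back_edge: true }
  output_stream: "throttled_input_video"
  options: {
    [mediapipe.FlowLimiterCalculatorOptions.ext] {
      max_in_flight: 2  # more pipeline parallelism, higher throughput
      max_in_queue: 0   # no queued frame, lower average latency
    }
  }
}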
diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc
index 448329b88..e5e87b69b 100644
--- a/mediapipe/calculators/core/gate_calculator.cc
+++ b/mediapipe/calculators/core/gate_calculator.cc
@@ -26,19 +26,15 @@ constexpr char kStateChangeTag[] = "STATE_CHANGE";
 constexpr char kDisallowTag[] = "DISALLOW";
 constexpr char kAllowTag[] = "ALLOW";
 
-enum GateState {
-  GATE_UNINITIALIZED,
-  GATE_ALLOW,
-  GATE_DISALLOW,
-};
-
-std::string ToString(GateState state) {
+std::string ToString(GateCalculatorOptions::GateState state) {
   switch (state) {
-    case GATE_UNINITIALIZED:
+    case GateCalculatorOptions::UNSPECIFIED:
+      return "UNSPECIFIED";
+    case GateCalculatorOptions::GATE_UNINITIALIZED:
       return "UNINITIALIZED";
-    case GATE_ALLOW:
+    case GateCalculatorOptions::GATE_ALLOW:
       return "ALLOW";
-    case GATE_DISALLOW:
+    case GateCalculatorOptions::GATE_DISALLOW:
       return "DISALLOW";
   }
   DLOG(FATAL) << "Unknown GateState";
@@ -153,10 +149,12 @@ class GateCalculator : public CalculatorBase {
     cc->SetOffset(TimestampDiff(0));
 
     num_data_streams_ = cc->Inputs().NumEntries("");
-    last_gate_state_ = GATE_UNINITIALIZED;
-    RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
 
     const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>();
+    last_gate_state_ = options.initial_gate_state();
+
+    RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &cc->Outputs()));
+
     empty_packets_as_allow_ = options.empty_packets_as_allow();
 
     if (!use_side_packet_for_allow_disallow_ &&
@@ -184,10 +182,12 @@ class GateCalculator : public CalculatorBase {
         allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
       }
     }
-    const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
+    const GateCalculatorOptions::GateState new_gate_state =
+        allow ? GateCalculatorOptions::GATE_ALLOW
+              : GateCalculatorOptions::GATE_DISALLOW;
 
     if (cc->Outputs().HasTag(kStateChangeTag)) {
-      if (last_gate_state_ != GATE_UNINITIALIZED &&
+      if (last_gate_state_ != GateCalculatorOptions::GATE_UNINITIALIZED &&
           last_gate_state_ != new_gate_state) {
         VLOG(2) << "State transition in " << cc->NodeName() << " @ "
                 << cc->InputTimestamp().Value() << " from "
@@ -223,7 +223,8 @@ class GateCalculator : public CalculatorBase {
   }
 
  private:
-  GateState last_gate_state_ = GATE_UNINITIALIZED;
+  GateCalculatorOptions::GateState last_gate_state_ =
+      GateCalculatorOptions::GATE_UNINITIALIZED;
   int num_data_streams_;
   bool empty_packets_as_allow_;
   bool use_side_packet_for_allow_disallow_ = false;
diff --git a/mediapipe/calculators/core/gate_calculator.proto b/mediapipe/calculators/core/gate_calculator.proto
index b7d597a63..4153d5f32 100644
--- a/mediapipe/calculators/core/gate_calculator.proto
+++ b/mediapipe/calculators/core/gate_calculator.proto
@@ -31,4 +31,13 @@ message GateCalculatorOptions {
   // Whether to allow or disallow the input streams to pass when no
   // ALLOW/DISALLOW input or side input is specified.
   optional bool allow = 2 [default = false];
+
+  enum GateState {
+    UNSPECIFIED = 0;
+    GATE_UNINITIALIZED = 1;
+    GATE_ALLOW = 2;
+    GATE_DISALLOW = 3;
+  }
+
+  optional GateState initial_gate_state = 3 [default = GATE_UNINITIALIZED];
 }
diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc
index 192019820..0c49f1449 100644
--- a/mediapipe/calculators/core/gate_calculator_test.cc
+++ b/mediapipe/calculators/core/gate_calculator_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/port/gtest.h"
@@ -356,18 +357,18 @@ TEST_F(GateCalculatorTest, AllowWithStateChangeNoDataStreams) {
   RunTimeStepWithoutDataStream(kTimestampValue2, "ALLOW", true);
   constexpr int64_t kTimestampValue3 = 45;
   RunTimeStepWithoutDataStream(kTimestampValue3, "ALLOW", false);
-  LOG(INFO) << "a";
+  ABSL_LOG(INFO) << "a";
   const std::vector<Packet>& output =
       runner()->Outputs().Get("STATE_CHANGE", 0).packets;
-  LOG(INFO) << "s";
+  ABSL_LOG(INFO) << "s";
   ASSERT_EQ(2, output.size());
-  LOG(INFO) << "d";
+  ABSL_LOG(INFO) << "d";
   EXPECT_EQ(kTimestampValue1, output[0].Timestamp().Value());
   EXPECT_EQ(kTimestampValue3, output[1].Timestamp().Value());
-  LOG(INFO) << "f";
+  ABSL_LOG(INFO) << "f";
   EXPECT_EQ(true, output[0].Get<bool>());   // Allow.
   EXPECT_EQ(false, output[1].Get<bool>());  // Disallow.
-  LOG(INFO) << "g";
+  ABSL_LOG(INFO) << "g";
 }
 
 TEST_F(GateCalculatorTest, DisallowWithStateChange) {
@@ -458,5 +459,29 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
   ASSERT_EQ(0, output.size());
 }
 
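The new initial_gate_state option seeds last_gate_state_ in Open(), so the very first ALLOW/DISALLOW decision can itself be reported on STATE_CHANGE. Its default, GATE_UNINITIALIZED, reproduces the old behavior of suppressing the first transition, so existing graphs are unaffected. Opting in only requires the options block used by the test below:

options: {
  [mediapipe.GateCalculatorOptions.ext] {
    initial_gate_state: GATE_DISALLOW
  }
}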
+// Must detect allow value for first timestamp as a state change when the
+// initial state is set to GATE_DISALLOW.
+TEST_F(GateCalculatorTest, StateChangeTriggeredWithInitialGateStateOption) {
+  SetRunner(R"(
+        calculator: "GateCalculator"
+        input_stream: "test_input"
+        input_stream: "ALLOW:allow"
+        output_stream: "test_output"
+        output_stream: "STATE_CHANGE:state_change"
+        options: {
+          [mediapipe.GateCalculatorOptions.ext] {
+            initial_gate_state: GATE_DISALLOW
+          }
+        }
+  )");
+
+  constexpr int64_t kTimestampValue0 = 42;
+  RunTimeStep(kTimestampValue0, "ALLOW", true);
+
+  const std::vector<Packet>& output =
+      runner()->Outputs().Get("STATE_CHANGE", 0).packets;
+  ASSERT_EQ(1, output.size());
+}
+
 }  // namespace
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc
index 3306e4ff3..5dbda6d99 100644
--- a/mediapipe/calculators/core/get_vector_item_calculator.cc
+++ b/mediapipe/calculators/core/get_vector_item_calculator.cc
@@ -17,6 +17,7 @@
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 
 namespace mediapipe {
 namespace api2 {
@@ -37,5 +38,12 @@
 using GetDetectionVectorItemCalculator =
     GetVectorItemCalculator<mediapipe::Detection>;
 REGISTER_CALCULATOR(GetDetectionVectorItemCalculator);
 
+using GetNormalizedRectVectorItemCalculator =
+    GetVectorItemCalculator<NormalizedRect>;
+REGISTER_CALCULATOR(GetNormalizedRectVectorItemCalculator);
+
+using GetRectVectorItemCalculator = GetVectorItemCalculator<Rect>;
+REGISTER_CALCULATOR(GetRectVectorItemCalculator);
+
 }  // namespace api2
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/core/immediate_mux_calculator.cc b/mediapipe/calculators/core/immediate_mux_calculator.cc
index 0e51cda5e..05de05e40 100644
--- a/mediapipe/calculators/core/immediate_mux_calculator.cc
+++ b/mediapipe/calculators/core/immediate_mux_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
@@ -78,7 +79,7 @@ absl::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) {
     if (packet.Timestamp() >= cc->Outputs().Index(0).NextTimestampBound()) {
       cc->Outputs().Index(0).AddPacket(packet);
     } else {
-      LOG_FIRST_N(WARNING, 5)
+      ABSL_LOG_FIRST_N(WARNING, 5)
           << "Dropping a packet with timestamp " << packet.Timestamp();
     }
     if (cc->Outputs().NumEntries() >= 2) {
diff --git a/mediapipe/calculators/core/matrix_multiply_calculator_test.cc b/mediapipe/calculators/core/matrix_multiply_calculator_test.cc
index e62ca8073..60976577a 100644
--- a/mediapipe/calculators/core/matrix_multiply_calculator_test.cc
+++ b/mediapipe/calculators/core/matrix_multiply_calculator_test.cc
@@ -16,6 +16,7 @@
 #include <memory>
 
 #include "Eigen/Core"
+#include "absl/log/absl_check.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/formats/matrix.h"
@@ -209,7 +210,7 @@ TEST(MatrixMultiplyCalculatorTest, Multiply) {
   MatrixFromTextProto(kSamplesText, &samples);
   Matrix expected;
   MatrixFromTextProto(kExpectedText, &expected);
-  CHECK_EQ(samples.cols(), expected.cols());
+  ABSL_CHECK_EQ(samples.cols(), expected.cols());
 
   for (int i = 0; i < samples.cols(); ++i) {
     // Take a column from samples and produce a packet with just that
diff --git a/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc b/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc
index 1f994cbed..8b4254cbc 100644
--- a/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc
+++ b/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc
@@ -35,7 +35,7 @@ class MatrixToVectorCalculatorTest
   void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }
 
   void AppendInput(const std::vector<float>& column_major_data,
-                   int64 timestamp) {
+                   int64_t timestamp) {
     ASSERT_EQ(num_input_samples_ * num_input_channels_,
               column_major_data.size());
     Eigen::Map<const Matrix> data_map(&column_major_data[0],
diff --git a/mediapipe/calculators/core/merge_calculator.cc b/mediapipe/calculators/core/merge_calculator.cc
index a283842ae..43fc3b878 100644
--- a/mediapipe/calculators/core/merge_calculator.cc
+++ b/mediapipe/calculators/core/merge_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -53,7 +54,7 @@ class MergeCalculator : public Node {
   static absl::Status UpdateContract(CalculatorContract* cc) {
     RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream";
     if (kIn(cc).Count() == 1) {
-      LOG(WARNING)
+      ABSL_LOG(WARNING)
          << "MergeCalculator expects multiple input streams to merge but is "
             "receiving only one. Make sure the calculator is configured "
             "correctly or consider removing this calculator to reduce "
@@ -72,8 +73,8 @@ class MergeCalculator : public Node {
       }
     }
 
-    LOG(WARNING) << "Empty input packets at timestamp "
-                 << cc->InputTimestamp().Value();
+    ABSL_LOG(WARNING) << "Empty input packets at timestamp "
+                      << cc->InputTimestamp().Value();
 
     return absl::OkStatus();
   }
diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.cc b/mediapipe/calculators/core/merge_to_vector_calculator.cc
index fd053ed2b..4bb3c8a40 100644
--- a/mediapipe/calculators/core/merge_to_vector_calculator.cc
+++ b/mediapipe/calculators/core/merge_to_vector_calculator.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/calculators/core/merge_to_vector_calculator.h b/mediapipe/calculators/core/merge_to_vector_calculator.h
index b4f7a37c2..4ec674c05 100644
--- a/mediapipe/calculators/core/merge_to_vector_calculator.h
+++ b/mediapipe/calculators/core/merge_to_vector_calculator.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc
index 60ec40537..81a68f03f 100644
--- a/mediapipe/calculators/core/packet_resampler_calculator.cc
+++ b/mediapipe/calculators/core/packet_resampler_calculator.cc
@@ -16,6 +16,9 @@
 
 #include <memory>
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
+
 namespace {
 // Reflect an integer against the lower and upper bound of an interval.
 int64_t ReflectBetween(int64_t ts, int64_t ts_min, int64_t ts_max) {
@@ -177,7 +180,7 @@ PacketResamplerCalculator::GetSamplingStrategy(
     const PacketResamplerCalculatorOptions& options) {
   if (options.reproducible_sampling()) {
    if (!options.jitter_with_reflection()) {
-      LOG(WARNING)
+      ABSL_LOG(WARNING)
          << "reproducible_sampling enabled w/ jitter_with_reflection "
             "disabled. "
          << "reproducible_sampling always uses jitter with reflection, "
@@ -200,15 +203,15 @@ PacketResamplerCalculator::GetSamplingStrategy(
 
 Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(
     int64_t index) const {
-  CHECK_EQ(jitter_, 0.0);
-  CHECK_NE(first_timestamp_, Timestamp::Unset());
+  ABSL_CHECK_EQ(jitter_, 0.0);
+  ABSL_CHECK_NE(first_timestamp_, Timestamp::Unset());
   return first_timestamp_ + TimestampDiffFromSeconds(index / frame_rate_);
 }
 
 int64_t PacketResamplerCalculator::TimestampToPeriodIndex(
     Timestamp timestamp) const {
-  CHECK_EQ(jitter_, 0.0);
-  CHECK_NE(first_timestamp_, Timestamp::Unset());
+  ABSL_CHECK_EQ(jitter_, 0.0);
+  ABSL_CHECK_NE(first_timestamp_, Timestamp::Unset());
   return MathUtil::SafeRound<int64_t, double>(
       (timestamp - first_timestamp_).Seconds() * frame_rate_);
 }
@@ -229,13 +232,15 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
 
   if (resampler_options.output_header() !=
       PacketResamplerCalculatorOptions::NONE) {
-    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
-                    "the actual value.";
+    ABSL_LOG(WARNING)
+        << "VideoHeader::frame_rate holds the target value and not "
+           "the actual value.";
   }
 
   if (calculator_->flush_last_packet_) {
-    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
-                    "ignored, because we are adding jitter.";
+    ABSL_LOG(WARNING)
+        << "PacketResamplerCalculatorOptions.flush_last_packet is "
+           "ignored, because we are adding jitter.";
   }
 
   const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
@@ -254,7 +259,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
 }
 absl::Status LegacyJitterWithReflectionStrategy::Close(CalculatorContext* cc) {
   if (!packet_reservoir_->IsEmpty()) {
-    LOG(INFO) << "Emitting pack from reservoir.";
+    ABSL_LOG(INFO) << "Emitting pack from reservoir.";
     calculator_->OutputWithinLimits(cc, packet_reservoir_->GetSample());
   }
   return absl::OkStatus();
@@ -285,7 +290,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Process(
   if (calculator_->frame_time_usec_ <
       (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
-    LOG_FIRST_N(WARNING, 2)
+    ABSL_LOG_FIRST_N(WARNING, 2)
         << "Adding jitter is not very useful when upsampling.";
   }
 
@@ -340,8 +345,8 @@ void LegacyJitterWithReflectionStrategy::UpdateNextOutputTimestampWithJitter() {
   next_output_timestamp_ = Timestamp(ReflectBetween(
       next_output_timestamp_.Value(), next_output_timestamp_min_.Value(),
       next_output_timestamp_max_.Value()));
-  CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
-  CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
+  ABSL_CHECK_GE(next_output_timestamp_, next_output_timestamp_min_);
+  ABSL_CHECK_LT(next_output_timestamp_, next_output_timestamp_max_);
 }
 
 absl::Status ReproducibleJitterWithReflectionStrategy::Open(
@@ -352,13 +357,15 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open(
 
   if (resampler_options.output_header() !=
       PacketResamplerCalculatorOptions::NONE) {
-    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
-                    "the actual value.";
+    ABSL_LOG(WARNING)
+        << "VideoHeader::frame_rate holds the target value and not "
+           "the actual value.";
   }
 
   if (calculator_->flush_last_packet_) {
-    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
-                    "ignored, because we are adding jitter.";
+    ABSL_LOG(WARNING)
+        << "PacketResamplerCalculatorOptions.flush_last_packet is "
+           "ignored, because we are adding jitter.";
   }
 
   const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
@@ -411,7 +418,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Process(
       // Note, if the stream is upsampling, this could lead to the same packet
       // being emitted twice.  Upsampling and jitter doesn't make much sense
       // but does technically work.
-      LOG_FIRST_N(WARNING, 2)
+      ABSL_LOG_FIRST_N(WARNING, 2)
           << "Adding jitter is not very useful when upsampling.";
     }
 
@@ -499,13 +506,15 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
 
   if (resampler_options.output_header() !=
       PacketResamplerCalculatorOptions::NONE) {
-    LOG(WARNING) << "VideoHeader::frame_rate holds the target value and not "
-                    "the actual value.";
+    ABSL_LOG(WARNING)
+        << "VideoHeader::frame_rate holds the target value and not "
+           "the actual value.";
  }
 
   if (calculator_->flush_last_packet_) {
-    LOG(WARNING) << "PacketResamplerCalculatorOptions.flush_last_packet is "
-                    "ignored, because we are adding jitter.";
+    ABSL_LOG(WARNING)
+        << "PacketResamplerCalculatorOptions.flush_last_packet is "
+           "ignored, because we are adding jitter.";
   }
 
   const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
@@ -555,7 +564,7 @@ absl::Status JitterWithoutReflectionStrategy::Process(CalculatorContext* cc) {
   if (calculator_->frame_time_usec_ <
       (cc->InputTimestamp() - calculator_->last_packet_.Timestamp()).Value()) {
-    LOG_FIRST_N(WARNING, 2)
+    ABSL_LOG_FIRST_N(WARNING, 2)
         << "Adding jitter is not very useful when upsampling.";
   }
 
diff --git a/mediapipe/calculators/core/packet_resampler_calculator.h b/mediapipe/calculators/core/packet_resampler_calculator.h
index fbecdb0e7..f26dc2ca4 100644
--- a/mediapipe/calculators/core/packet_resampler_calculator.h
+++ b/mediapipe/calculators/core/packet_resampler_calculator.h
@@ -13,7 +13,6 @@
 #include "mediapipe/framework/deps/random_base.h"
 #include "mediapipe/framework/formats/video_stream_header.h"
 #include "mediapipe/framework/port/integral_types.h"
-#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/status_macros.h"
diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc
index f02da0d18..d80793da4 100644
--- a/mediapipe/calculators/core/packet_resampler_calculator_test.cc
+++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc
@@ -51,9 +51,9 @@ class SimpleRunner : public CalculatorRunner {
 
   virtual ~SimpleRunner() {}
 
-  void SetInput(const std::vector<int64>& timestamp_list) {
+  void SetInput(const std::vector<int64_t>& timestamp_list) {
     MutableInputs()->Index(0).packets.clear();
-    for (const int64 ts : timestamp_list) {
+    for (const int64_t ts : timestamp_list) {
       MutableInputs()->Index(0).packets.push_back(
           Adopt(new std::string(absl::StrCat("Frame #", ts)))
               .At(Timestamp(ts)));
@@ -72,8 +72,8 @@
   }
 
   void CheckOutputTimestamps(
-      const std::vector<int64>& expected_frames,
-      const std::vector<int64>& expected_timestamps) const {
+      const std::vector<int64_t>& expected_frames,
+      const std::vector<int64_t>& expected_timestamps) const {
     EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
     EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
     int count = 0;
@@ -112,7 +112,7 @@ MATCHER_P2(PacketAtTimestamp, payload, timestamp,
     *result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
     return false;
   }
-  int64 actual_payload = arg.template Get<int64>();
+  int64_t actual_payload = arg.template Get<int64_t>();
   if (actual_payload != payload) {
     *result_listener << "with incorrect payload = " << actual_payload;
     return false;
@@ -137,18 +137,18 @@ class ReproducibleJitterWithReflectionStrategyForTesting
   //
   // An EXPECT will fail if sequence is less than the number requested during
   // processing.
-  static std::vector<uint64> random_sequence;
+  static std::vector<uint64_t> random_sequence;
 
  protected:
-  virtual uint64 GetNextRandom(uint64 n) {
+  virtual uint64_t GetNextRandom(uint64_t n) {
     EXPECT_LT(sequence_index_, random_sequence.size());
     return random_sequence[sequence_index_++] % n;
   }
 
 private:
-  int32 sequence_index_ = 0;
+  int32_t sequence_index_ = 0;
 };
 
-std::vector<uint64>
+std::vector<uint64_t>
     ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
 
 // PacketResamplerCalculator child class which injects a specified stream
@@ -469,7 +469,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
         }
       )pb"));
 
-  for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
+  for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) {
     runner.MutableInputs()->Tag(kDataTag).packets.push_back(
         Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
   }
diff --git a/mediapipe/calculators/core/packet_thinner_calculator.cc b/mediapipe/calculators/core/packet_thinner_calculator.cc
index 35cd966ea..0bc5cc16d 100644
--- a/mediapipe/calculators/core/packet_thinner_calculator.cc
+++ b/mediapipe/calculators/core/packet_thinner_calculator.cc
@@ -17,6 +17,7 @@
 #include <cmath>  // for ceil
 #include <memory>
 
+#include "absl/log/absl_check.h"
 #include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
 #include "mediapipe/framework/calculator_context.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -160,8 +161,8 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
   thinner_type_ = options.thinner_type();
   // This check enables us to assume only two thinner types exist in Process()
-  CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ||
-        thinner_type_ == PacketThinnerCalculatorOptions::SYNC)
+  ABSL_CHECK(thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ||
+             thinner_type_ == PacketThinnerCalculatorOptions::SYNC)
       << "Unsupported thinner type.";
 
   if (thinner_type_ == PacketThinnerCalculatorOptions::ASYNC) {
@@ -177,7 +178,8 @@
   } else {
     period_ = TimestampDiff(options.period());
   }
-  CHECK_LT(TimestampDiff(0), period_) << "Specified period must be positive.";
+  ABSL_CHECK_LT(TimestampDiff(0), period_)
+      << "Specified period must be positive.";
 
   if (options.has_start_time()) {
     start_time_ = Timestamp(options.start_time());
@@ -189,7 +191,7 @@
   end_time_ =
       options.has_end_time() ? Timestamp(options.end_time()) : Timestamp::Max();
-  CHECK_LT(start_time_, end_time_)
+  ABSL_CHECK_LT(start_time_, end_time_)
       << "Invalid PacketThinner: start_time must be earlier than end_time";
 
   sync_output_timestamps_ = options.sync_output_timestamps();
@@ -232,7 +234,7 @@ absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) {
   // Emit any saved packets before quitting.
   if (!saved_packet_.IsEmpty()) {
     // Only sync thinner should have saved packets.
-    CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_);
+    ABSL_CHECK_EQ(PacketThinnerCalculatorOptions::SYNC, thinner_type_);
     if (sync_output_timestamps_) {
       cc->Outputs().Index(0).AddPacket(
           saved_packet_.At(NearestSyncTimestamp(saved_packet_.Timestamp())));
@@ -269,7 +271,7 @@ absl::Status PacketThinnerCalculator::SyncThinnerProcess(
     const Timestamp saved_sync = NearestSyncTimestamp(saved);
     const Timestamp now = cc->InputTimestamp();
     const Timestamp now_sync = NearestSyncTimestamp(now);
-    CHECK_LE(saved_sync, now_sync);
+    ABSL_CHECK_LE(saved_sync, now_sync);
     if (saved_sync == now_sync) {
       // Saved Packet is in same interval as current packet.
       // Replace saved packet with current if it is at least as
@@ -295,7 +297,7 @@
 }
 
 Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const {
-  CHECK_NE(start_time_, Timestamp::Unset())
+  ABSL_CHECK_NE(start_time_, Timestamp::Unset())
       << "Method only valid for sync thinner calculator.";
 
   // Computation is done using int64 arithmetic.  No easy way to avoid
@@ -303,12 +305,12 @@
   const int64_t now64 = now.Value();
   const int64_t start64 = start_time_.Value();
   const int64_t period64 = period_.Value();
-  CHECK_LE(0, period64);
+  ABSL_CHECK_LE(0, period64);
 
   // Round now64 to its closest interval (units of period64).
   int64_t sync64 =
       (now64 - start64 + period64 / 2) / period64 * period64 + start64;
-  CHECK_LE(abs(now64 - sync64), period64 / 2)
+  ABSL_CHECK_LE(abs(now64 - sync64), period64 / 2)
      << "start64: " << start64 << "; now64: " << now64
      << "; sync64: " << sync64;
diff --git a/mediapipe/calculators/core/packet_thinner_calculator_test.cc b/mediapipe/calculators/core/packet_thinner_calculator_test.cc
index 09de0ca70..69c008395 100644
--- a/mediapipe/calculators/core/packet_thinner_calculator_test.cc
+++ b/mediapipe/calculators/core/packet_thinner_calculator_test.cc
@@ -16,6 +16,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
 #include "mediapipe/calculators/core/packet_thinner_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -70,7 +71,7 @@ class SimpleRunner : public CalculatorRunner {
   }
 
   double GetFrameRate() const {
-    CHECK(!Outputs().Index(0).header.IsEmpty());
+    ABSL_CHECK(!Outputs().Index(0).header.IsEmpty());
     return Outputs().Index(0).header.Get<VideoHeader>().frame_rate;
   }
 };
diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc
index d67e6c061..36ee0f2d7 100644
--- a/mediapipe/calculators/core/previous_loopback_calculator.cc
+++ b/mediapipe/calculators/core/previous_loopback_calculator.cc
@@ -123,7 +123,10 @@ class PreviousLoopbackCalculator : public Node {
           // However, LOOP packet is empty.
           kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1);
         } else {
-          kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
+          // Avoids sending leftovers to a stream that's already closed.
+          if (!kPrevLoop(cc).IsClosed()) {
+            kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp));
+          }
         }
         loop_packets_.pop_front();
         main_packet_specs_.pop_front();
diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc
index 563417669..d8c358909 100644
--- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc
+++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc
@@ -43,8 +43,8 @@ constexpr char kDisallowTag[] = "DISALLOW";
 
 // Returns the timestamp values for a vector of Packets.
 // TODO: puth this kind of test util in a common place.
-std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
-  std::vector<int64> result;
+std::vector<int64_t> TimestampValues(const std::vector<Packet>& packets) {
+  std::vector<int64_t> result;
   for (const Packet& packet : packets) {
     result.push_back(packet.Timestamp().Value());
   }
@@ -371,7 +371,7 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
   for (int main_ts = 0; main_ts < 50; ++main_ts) {
     send_packet("in", main_ts);
     MP_EXPECT_OK(graph_.WaitUntilIdle());
-    std::vector<int64> ts_values = TimestampValues(outputs);
+    std::vector<int64_t> ts_values = TimestampValues(outputs);
     EXPECT_EQ(ts_values.size(), main_ts + 1);
     for (int j = 0; j < main_ts + 1; ++j) {
       EXPECT_EQ(ts_values[j], j);
diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc
index 026048b79..5b2a73fd3 100644
--- a/mediapipe/calculators/core/sequence_shift_calculator.cc
+++ b/mediapipe/calculators/core/sequence_shift_calculator.cc
@@ -14,6 +14,7 @@
 
 #include <deque>
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/core/sequence_shift_calculator.pb.h"
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -101,7 +102,7 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
     kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
     packet_cache_.pop_front();
   } else if (emit_empty_packets_before_first_packet_) {
-    LOG(FATAL) << "Not supported yet";
+    ABSL_LOG(FATAL) << "Not supported yet";
   }
 
   // Store current packet for later output.
   packet_cache_.push_back(kIn(cc).packet());
diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc
index ed89889df..311f7d815 100644
--- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc
+++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc
@@ -121,7 +121,7 @@ absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
   if (cc->Outputs().HasTag(kTagAtTimestamp)) {
     RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
         << "For AT_TIMESTAMP tag, 2 input side packets are required.";
-    cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>();
+    cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>();
   } else {
     RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
         << "Same number of input side packets and output streams is required.";
@@ -178,8 +178,8 @@ absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
           .AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
     }
   } else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
-    int64 timestamp =
-        cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>();
+    int64_t timestamp =
+        cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64_t>();
     for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
       cc->Outputs()
           .Get(output_tag_, i)
diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc
index b76722de9..67fc38ce9 100644
--- a/mediapipe/calculators/core/split_vector_calculator.cc
+++ b/mediapipe/calculators/core/split_vector_calculator.cc
@@ -18,6 +18,7 @@
 
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/rect.pb.h"
@@ -86,4 +87,12 @@
 typedef SplitVectorCalculator<uint64_t, false> SplitUint64tVectorCalculator;
 REGISTER_CALCULATOR(SplitUint64tVectorCalculator);
 
 typedef SplitVectorCalculator<float, false> SplitFloatVectorCalculator;
 REGISTER_CALCULATOR(SplitFloatVectorCalculator);
 
+typedef SplitVectorCalculator<mediapipe::Image, false>
+    SplitImageVectorCalculator;
+REGISTER_CALCULATOR(SplitImageVectorCalculator);
+
+typedef SplitVectorCalculator<std::array<float, 16>, false>
+    SplitAffineMatrixVectorCalculator;
+REGISTER_CALCULATOR(SplitAffineMatrixVectorCalculator);
+
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/core/stream_to_side_packet_calculator.cc b/mediapipe/calculators/core/stream_to_side_packet_calculator.cc
index 9dc25142a..72e812255 100644
--- a/mediapipe/calculators/core/stream_to_side_packet_calculator.cc
+++ b/mediapipe/calculators/core/stream_to_side_packet_calculator.cc
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/timestamp.h"
 
 namespace mediapipe {
+namespace api2 {
 
 // A calculator that takes a packet of an input stream and converts it to an
 // output side packet. This calculator only works under the assumption that the
@@ -28,21 +30,21 @@ namespace mediapipe {
 //   input_stream: "stream"
 //   output_side_packet: "side_packet"
 // }
-class StreamToSidePacketCalculator : public mediapipe::CalculatorBase {
+class StreamToSidePacketCalculator : public Node {
  public:
-  static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
-    cc->Inputs().Index(0).SetAny();
-    cc->OutputSidePackets().Index(0).SetAny();
-    return absl::OkStatus();
-  }
+  static constexpr Input<AnyType>::Optional kIn{""};
+  static constexpr SideOutput<SameType<kIn>> kOut{""};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 
   absl::Status Process(mediapipe::CalculatorContext* cc) override {
-    mediapipe::Packet& packet = cc->Inputs().Index(0).Value();
-    cc->OutputSidePackets().Index(0).Set(
-        packet.At(mediapipe::Timestamp::Unset()));
+    kOut(cc).Set(
+        kIn(cc).packet().As<AnyType>().At(mediapipe::Timestamp::Unset()));
     return absl::OkStatus();
   }
 };
-REGISTER_CALCULATOR(StreamToSidePacketCalculator);
+MEDIAPIPE_REGISTER_NODE(StreamToSidePacketCalculator);
+
+}  // namespace api2
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD
index 6de54189f..259012d64 100644
--- a/mediapipe/calculators/image/BUILD
+++ b/mediapipe/calculators/image/BUILD
@@ -97,6 +97,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:source_location",
         "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/log:absl_check",
     ],
     alwayslink = 1,
 )
@@ -125,6 +126,7 @@ cc_library(
         "//mediapipe/framework/port:opencv_imgcodecs",
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/log:absl_check",
     ],
     alwayslink = 1,
 )
@@ -135,7 +137,6 @@ cc_library(
     deps = [
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_frame_opencv",
-        "//mediapipe/framework/port:opencv_imgcodecs",
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:status",
     ],
@@ -152,11 +153,11 @@ cc_library(
         "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/formats:image_frame_opencv",
-        "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:vector",
+        "@com_google_absl//absl/log:absl_log",
     ] + select({
         "//mediapipe/gpu:disable_gpu": [],
         "//conditions:default": [
@@ -203,6 +204,7 @@ cc_library(
         "//mediapipe/framework/port:opencv_imgproc",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:vector",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/strings",
     ] + select({
         "//mediapipe/gpu:disable_gpu": [],
@@ -262,9 +264,12 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
         "//mediapipe/gpu:scale_mode_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
     ] + select({
         "//mediapipe/gpu:disable_gpu": [],
         "//conditions:default": [
+            "//mediapipe/gpu:gl_base_hdr",
             "//mediapipe/gpu:gl_calculator_helper",
             "//mediapipe/gpu:gl_quad_renderer",
             "//mediapipe/gpu:gl_simple_shaders",
@@ -274,6 +279,36 @@
     alwayslink = 1,
 )
 
+cc_test(
+    name = "image_transformation_calculator_test",
+    srcs = ["image_transformation_calculator_test.cc"],
+    data = ["//mediapipe/calculators/image/testdata:test_images"],
+    tags = [
+        "desktop_only_test",
+    ],
+    deps = [
+        ":image_transformation_calculator",
"//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", + "//third_party:opencv", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "image_cropping_calculator", srcs = ["image_cropping_calculator.cc"], @@ -301,6 +336,7 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -317,6 +353,7 @@ cc_library( cc_test( name = "image_cropping_calculator_test", srcs = ["image_cropping_calculator_test.cc"], + tags = ["not_run:arm"], deps = [ ":image_cropping_calculator", ":image_cropping_calculator_cc_proto", @@ -396,6 +433,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -420,6 +458,8 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/util:image_frame_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@libyuv", ], @@ -625,9 +665,9 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", + "@com_google_absl//absl/log:absl_log", ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -650,6 +690,7 @@ cc_library( cc_test( name = "segmentation_smoothing_calculator_test", srcs = ["segmentation_smoothing_calculator_test.cc"], + tags = ["not_run:arm"], deps = [ ":image_clone_calculator", ":image_clone_calculator_cc_proto", @@ -664,6 +705,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/log:absl_log", ], ) @@ -686,6 +728,7 @@ cc_library( "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/gpu:shader_util", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -771,7 +814,10 @@ cc_test( "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", ], - tags = ["desktop_only_test"], + tags = [ + "desktop_only_test", + "not_run:arm", + ], deps = [ ":affine_transformation", ":image_transformation_calculator", diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc 
b/mediapipe/calculators/image/affine_transformation_runner_gl.cc index 006416916..0fe2d2744 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -20,6 +20,7 @@ #include "Eigen/Core" #include "Eigen/Geometry" #include "Eigen/LU" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -53,6 +54,10 @@ bool IsMatrixVerticalFlipNeeded(GpuOrigin::Mode gpu_origin) { #endif // __APPLE__ case GpuOrigin::TOP_LEFT: return false; + default: + ABSL_LOG(ERROR) << "Incorrect GpuOrigin: " + << static_cast(gpu_origin); + return true; } } @@ -384,6 +389,8 @@ class GlTextureWarpAffineRunner glActiveTexture(GL_TEXTURE0); glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + return absl::OkStatus(); } diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index 6bb43dc00..3d364ad93 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_replace.h" #include "mediapipe/calculators/image/bilateral_filter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -112,7 +113,7 @@ class BilateralFilterCalculator : public CalculatorBase { REGISTER_CALCULATOR(BilateralFilterCalculator); absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { @@ -183,8 +184,8 @@ absl::Status BilateralFilterCalculator::Open(CalculatorContext* cc) { sigma_color_ = options_.sigma_color(); sigma_space_ = options_.sigma_space(); - CHECK_GE(sigma_color_, 0.0); - CHECK_GE(sigma_space_, 0.0); + ABSL_CHECK_GE(sigma_color_, 0.0); + ABSL_CHECK_GE(sigma_space_, 0.0); if (!use_gpu_) sigma_color_ *= 255.0; if (use_gpu_) { diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index 4781f1ea1..f8f018363 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/log/absl_check.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -25,8 +26,8 @@ namespace mediapipe { namespace { void SetColorChannel(int channel, uint8 value, cv::Mat* mat) { - CHECK(mat->depth() == CV_8U); - CHECK(channel < mat->channels()); + ABSL_CHECK(mat->depth() == CV_8U); + ABSL_CHECK(channel < mat->channels()); const int step = mat->channels(); for (int r = 0; r < mat->rows; ++r) { uint8* row_ptr = mat->ptr(r); diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 6776da7c8..9eb3e6808 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_log.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -202,8 +203,9 @@ absl::Status ImageCroppingCalculator::ValidateBorderModeForGPU( switch (options.border_mode()) { case mediapipe::ImageCroppingCalculatorOptions::BORDER_ZERO: - LOG(WARNING) << "BORDER_ZERO mode is not supported by GPU " - << "implementation and will fall back into BORDER_REPLICATE"; + ABSL_LOG(WARNING) + << "BORDER_ZERO mode is not supported by GPU " + << "implementation and will fall back into BORDER_REPLICATE"; break; case mediapipe::ImageCroppingCalculatorOptions::BORDER_REPLICATE: break; diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index dbf8f7337..8c6f715a0 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/status/status.h" #include "mediapipe/calculators/image/image_transformation_calculator.pb.h" #include "mediapipe/calculators/image/rotation_mode.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -27,6 +28,7 @@ #include "mediapipe/gpu/scale_mode.pb.h" #if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_simple_shaders.h" @@ -60,42 +62,42 @@ constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { switch (rotation) { - case mediapipe::RotationMode_Mode_UNKNOWN: - case mediapipe::RotationMode_Mode_ROTATION_0: + case mediapipe::RotationMode::UNKNOWN: + case mediapipe::RotationMode::ROTATION_0: return 0; - case mediapipe::RotationMode_Mode_ROTATION_90: + case mediapipe::RotationMode::ROTATION_90: return 90; - case mediapipe::RotationMode_Mode_ROTATION_180: + case mediapipe::RotationMode::ROTATION_180: return 180; - case mediapipe::RotationMode_Mode_ROTATION_270: + case mediapipe::RotationMode::ROTATION_270: return 270; } } mediapipe::RotationMode_Mode DegreesToRotationMode(int degrees) { switch (degrees) { case 0: - return mediapipe::RotationMode_Mode_ROTATION_0; + return mediapipe::RotationMode::ROTATION_0; case 90: - return mediapipe::RotationMode_Mode_ROTATION_90; + return mediapipe::RotationMode::ROTATION_90; case 180: - return mediapipe::RotationMode_Mode_ROTATION_180; + return mediapipe::RotationMode::ROTATION_180; case 270: - return mediapipe::RotationMode_Mode_ROTATION_270; + return mediapipe::RotationMode::ROTATION_270; default: - return mediapipe::RotationMode_Mode_UNKNOWN; + return mediapipe::RotationMode::UNKNOWN; } } mediapipe::ScaleMode_Mode ParseScaleMode( mediapipe::ScaleMode_Mode scale_mode, mediapipe::ScaleMode_Mode default_mode) { switch (scale_mode) { - case mediapipe::ScaleMode_Mode_DEFAULT: + case mediapipe::ScaleMode::DEFAULT: return default_mode; - case mediapipe::ScaleMode_Mode_STRETCH: + case mediapipe::ScaleMode::STRETCH: return scale_mode; - case mediapipe::ScaleMode_Mode_FIT: + case mediapipe::ScaleMode::FIT: return scale_mode; - case mediapipe::ScaleMode_Mode_FILL_AND_CROP: + case mediapipe::ScaleMode::FILL_AND_CROP: return scale_mode; default: return default_mode; @@ -208,6 +210,8 @@ class ImageTransformationCalculator : public CalculatorBase { bool use_gpu_ = false; cv::Scalar padding_color_; + ImageTransformationCalculatorOptions::InterpolationMode interpolation_mode_; + #if !MEDIAPIPE_DISABLE_GPU GlCalculatorHelper gpu_helper_; std::unique_ptr rgb_renderer_; @@ -343,6 +347,11 @@ absl::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { options_.padding_color().green(), options_.padding_color().blue()); + interpolation_mode_ = options_.interpolation_mode(); + if (options_.interpolation_mode() == + ImageTransformationCalculatorOptions::DEFAULT) { + interpolation_mode_ = ImageTransformationCalculatorOptions::LINEAR; + } if (use_gpu_) { #if !MEDIAPIPE_DISABLE_GPU // Let the helper access the GL context information. 
@@ -457,26 +466,48 @@ absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) {
   ComputeOutputDimensions(input_width, input_height, &output_width,
                           &output_height);
 
+  int opencv_interpolation_mode = cv::INTER_LINEAR;
   if (output_width_ > 0 && output_height_ > 0) {
     cv::Mat scaled_mat;
-    if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) {
-      int scale_flag =
-          input_mat.cols > output_width_ && input_mat.rows > output_height_
-              ? cv::INTER_AREA
-              : cv::INTER_LINEAR;
+    if (scale_mode_ == mediapipe::ScaleMode::STRETCH) {
+      if (interpolation_mode_ == ImageTransformationCalculatorOptions::LINEAR) {
+        // Use INTER_AREA for downscaling if interpolation mode is set to
+        // LINEAR.
+        if (input_mat.cols > output_width_ && input_mat.rows > output_height_) {
+          opencv_interpolation_mode = cv::INTER_AREA;
+
+        } else {
+          opencv_interpolation_mode = cv::INTER_LINEAR;
+        }
+      } else {
+        opencv_interpolation_mode = cv::INTER_NEAREST;
+      }
       cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_),
-                 0, 0, scale_flag);
+                 0, 0, opencv_interpolation_mode);
     } else {
       const float scale =
           std::min(static_cast<float>(output_width_) / input_width,
                    static_cast<float>(output_height_) / input_height);
       const int target_width = std::round(input_width * scale);
       const int target_height = std::round(input_height * scale);
-      int scale_flag = scale < 1.0f ? cv::INTER_AREA : cv::INTER_LINEAR;
-      if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
+
+      if (interpolation_mode_ == ImageTransformationCalculatorOptions::LINEAR) {
+        // Use INTER_AREA for downscaling if interpolation mode is set to
+        // LINEAR.
+        if (scale < 1.0f) {
+          opencv_interpolation_mode = cv::INTER_AREA;
+        } else {
+          opencv_interpolation_mode = cv::INTER_LINEAR;
+        }
+      } else {
+        opencv_interpolation_mode = cv::INTER_NEAREST;
+      }
+
+      if (scale_mode_ == mediapipe::ScaleMode::FIT) {
         cv::Mat intermediate_mat;
         cv::resize(input_mat, intermediate_mat,
-                   cv::Size(target_width, target_height), 0, 0, scale_flag);
+                   cv::Size(target_width, target_height), 0, 0,
+                   opencv_interpolation_mode);
         const int top = (output_height_ - target_height) / 2;
         const int bottom = output_height_ - target_height - top;
         const int left = (output_width_ - target_width) / 2;
@@ -488,7 +519,7 @@ absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) {
                            padding_color_);
       } else {
         cv::resize(input_mat, scaled_mat, cv::Size(target_width, target_height),
-                   0, 0, scale_flag);
+                   0, 0, opencv_interpolation_mode);
         output_width = target_width;
         output_height = target_height;
       }
@@ -514,17 +545,17 @@ absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) {
     cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
   } else {
     switch (rotation_) {
-      case mediapipe::RotationMode_Mode_UNKNOWN:
-      case mediapipe::RotationMode_Mode_ROTATION_0:
+      case mediapipe::RotationMode::UNKNOWN:
+      case mediapipe::RotationMode::ROTATION_0:
         rotated_mat = input_mat;
         break;
-      case mediapipe::RotationMode_Mode_ROTATION_90:
+      case mediapipe::RotationMode::ROTATION_90:
         cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
         break;
-      case mediapipe::RotationMode_Mode_ROTATION_180:
+      case mediapipe::RotationMode::ROTATION_180:
         cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
         break;
-      case mediapipe::RotationMode_Mode_ROTATION_270:
+      case mediapipe::RotationMode::ROTATION_270:
         cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
         break;
     }
@@ -561,7 +592,7 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
   ComputeOutputDimensions(input_width, input_height, &output_width,
                           &output_height);
-  if (scale_mode_ == mediapipe::ScaleMode_Mode_FILL_AND_CROP) {
+  if (scale_mode_ == mediapipe::ScaleMode::FILL_AND_CROP) {
     const float scale =
         std::min(static_cast<float>(output_width_) / input_width,
                  static_cast<float>(output_height_) / input_height);
@@ -628,6 +659,12 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
   glActiveTexture(GL_TEXTURE1);
   glBindTexture(src1.target(), src1.name());
 
+  if (interpolation_mode_ == ImageTransformationCalculatorOptions::NEAREST) {
+    // TODO: revert texture params.
+    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+  }
+
   MP_RETURN_IF_ERROR(renderer->GlRender(
       src1.width(), src1.height(), dst.width(), dst.height(), scale_mode,
       rotation, flip_horizontally_, flip_vertically_,
@@ -652,8 +689,8 @@ void ImageTransformationCalculator::ComputeOutputDimensions(
   if (output_width_ > 0 && output_height_ > 0) {
     *output_width = output_width_;
     *output_height = output_height_;
-  } else if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 ||
-             rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) {
+  } else if (rotation_ == mediapipe::RotationMode::ROTATION_90 ||
+             rotation_ == mediapipe::RotationMode::ROTATION_270) {
     *output_width = input_height;
     *output_height = input_width;
   } else {
@@ -666,9 +703,9 @@ void ImageTransformationCalculator::ComputeOutputLetterboxPadding(
     int input_width, int input_height, int output_width, int output_height,
     std::array<float, 4>* padding) {
   padding->fill(0.f);
-  if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) {
-    if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 ||
-        rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) {
+  if (scale_mode_ == mediapipe::ScaleMode::FIT) {
+    if (rotation_ == mediapipe::RotationMode::ROTATION_90 ||
+        rotation_ == mediapipe::RotationMode::ROTATION_270) {
       std::swap(input_width, input_height);
     }
     const float input_aspect_ratio =
diff --git a/mediapipe/calculators/image/image_transformation_calculator.proto b/mediapipe/calculators/image/image_transformation_calculator.proto
index 16f60fcbc..0e2453a46 100644
--- a/mediapipe/calculators/image/image_transformation_calculator.proto
+++ b/mediapipe/calculators/image/image_transformation_calculator.proto
@@ -54,4 +54,15 @@ message ImageTransformationCalculatorOptions {
   // The color for the padding. This option is only used when the scale mode is
   // FIT. Default is black. This is for CPU only.
   optional Color padding_color = 8;
+
+  // Interpolation method to use. Note that on CPU when LINEAR is specified,
+  // INTER_LINEAR is used for upscaling and INTER_AREA is used for downscaling.
+  enum InterpolationMode {
+    DEFAULT = 0;
+    LINEAR = 1;
+    NEAREST = 2;
+  }
+
+  // Mode DEFAULT will use LINEAR interpolation.
+  optional InterpolationMode interpolation_mode = 9;
 }
diff --git a/mediapipe/calculators/image/image_transformation_calculator_test.cc b/mediapipe/calculators/image/image_transformation_calculator_test.cc
new file mode 100644
index 000000000..48828cc70
--- /dev/null
+++ b/mediapipe/calculators/image/image_transformation_calculator_test.cc
@@ -0,0 +1,315 @@
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/flags/flag.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/image_format.pb.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/googletest.h"
+#include "third_party/OpenCV/core.hpp"  // IWYU pragma: keep
+#include "third_party/OpenCV/core/mat.hpp"
+
+namespace mediapipe {
+
+namespace {
+
+absl::flat_hash_set<unsigned char> computeUniqueValues(const cv::Mat& mat) {
+  // Compute the unique values in cv::Mat
+  absl::flat_hash_set<unsigned char> unique_values;
+  for (int i = 0; i < mat.rows; i++) {
+    for (int j = 0; j < mat.cols; j++) {
+      unique_values.insert(mat.at<unsigned char>(i, j));
+    }
+  }
+  return unique_values;
+}
+
+TEST(ImageTransformationCalculatorTest, NearestNeighborResizing) {
+  cv::Mat input_mat;
+  cv::cvtColor(cv::imread(file::JoinPath("./",
+                                         "/mediapipe/calculators/"
+                                         "image/testdata/binary_mask.png")),
+               input_mat, cv::COLOR_BGR2GRAY);
+  Packet input_image_packet = MakePacket<ImageFrame>(
+      ImageFormat::GRAY8, input_mat.size().width, input_mat.size().height);
+  input_mat.copyTo(formats::MatView(&(input_image_packet.Get<ImageFrame>())));
+
+  std::vector<std::pair<int, int>> output_dims{
+      {256, 333}, {512, 512}, {1024, 1024}};
+
+  for (auto& output_dim : output_dims) {
+    Packet input_output_dim_packet =
+        MakePacket<std::pair<int, int>>(output_dim);
+    std::vector<std::string> scale_modes{"FIT", "STRETCH"};
+    for (const auto& scale_mode : scale_modes) {
+      CalculatorGraphConfig::Node node_config =
+          ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
+              absl::Substitute(R"(
+          calculator: "ImageTransformationCalculator"
+          input_stream: "IMAGE:input_image"
+          input_stream: "OUTPUT_DIMENSIONS:image_size"
+          output_stream: "IMAGE:output_image"
+          options: {
+            [mediapipe.ImageTransformationCalculatorOptions.ext]: {
+              scale_mode: $0
+              interpolation_mode: NEAREST
+            }
+          })",
+                               scale_mode));
+
+      CalculatorRunner runner(node_config);
+      runner.MutableInputs()->Tag("IMAGE").packets.push_back(
+          input_image_packet.At(Timestamp(0)));
+      runner.MutableInputs()
+          ->Tag("OUTPUT_DIMENSIONS")
+          .packets.push_back(input_output_dim_packet.At(Timestamp(0)));
+
+      MP_ASSERT_OK(runner.Run());
+      const auto& outputs = runner.Outputs();
+      ASSERT_EQ(outputs.NumEntries(), 1);
+      const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets;
+      ASSERT_EQ(packets.size(), 1);
+      const auto& result = packets[0].Get<ImageFrame>();
+      ASSERT_EQ(output_dim.first, result.Width());
+      ASSERT_EQ(output_dim.second, result.Height());
+
+      auto unique_input_values = computeUniqueValues(input_mat);
+      auto unique_output_values =
+          computeUniqueValues(formats::MatView(&result));
+      EXPECT_THAT(unique_input_values,
+                  ::testing::ContainerEq(unique_output_values));
+    }
+  }
+}
+
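These tests exploit a property of nearest-neighbor sampling: every output pixel is copied verbatim from some input pixel, so resizing may drop values but never synthesizes new ones, e.g. upscaling two pixels [a, b] to three gives [a, a, b] (or [a, b, b]) and the set of values is unchanged. Comparing the sets of unique pixel values before and after is therefore a reasonable correctness check here, whereas LINEAR interpolation would blend neighboring pixels into new values and fail it.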
+TEST(ImageTransformationCalculatorTest, + NearestNeighborResizingWorksForFloatInput) { + cv::Mat input_mat; + cv::cvtColor(cv::imread(file::JoinPath("./", + "/mediapipe/calculators/" + "image/testdata/binary_mask.png")), + input_mat, cv::COLOR_BGR2GRAY); + Packet input_image_packet = MakePacket<ImageFrame>( + ImageFormat::VEC32F1, input_mat.size().width, input_mat.size().height); + cv::Mat packet_mat_view = + formats::MatView(&(input_image_packet.Get<ImageFrame>())); + input_mat.convertTo(packet_mat_view, CV_32FC1, 1 / 255.f); + + std::vector<std::pair<int, int>> output_dims{ + {256, 333}, {512, 512}, {1024, 1024}}; + + for (auto& output_dim : output_dims) { + Packet input_output_dim_packet = + MakePacket<std::pair<int, int>>(output_dim); + std::vector<std::string> scale_modes{"FIT", "STRETCH"}; + for (const auto& scale_mode : scale_modes) { + CalculatorGraphConfig::Node node_config = + ParseTextProtoOrDie<CalculatorGraphConfig::Node>( + absl::Substitute(R"( + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:input_image" + input_stream: "OUTPUT_DIMENSIONS:image_size" + output_stream: "IMAGE:output_image" + options: { + [mediapipe.ImageTransformationCalculatorOptions.ext]: { + scale_mode: $0 + interpolation_mode: NEAREST + } + })", + scale_mode)); + + CalculatorRunner runner(node_config); + runner.MutableInputs()->Tag("IMAGE").packets.push_back( + input_image_packet.At(Timestamp(0))); + runner.MutableInputs() + ->Tag("OUTPUT_DIMENSIONS") + .packets.push_back(input_output_dim_packet.At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()); + const auto& outputs = runner.Outputs(); + ASSERT_EQ(outputs.NumEntries(), 1); + const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets; + ASSERT_EQ(packets.size(), 1); + const auto& result = packets[0].Get<ImageFrame>(); + ASSERT_EQ(output_dim.first, result.Width()); + ASSERT_EQ(output_dim.second, result.Height()); + + auto unique_input_values = computeUniqueValues<float>(packet_mat_view); + auto unique_output_values = + computeUniqueValues<float>(formats::MatView(&result)); + EXPECT_THAT(unique_input_values, + ::testing::ContainerEq(unique_output_values)); + } + } +} + +TEST(ImageTransformationCalculatorTest, NearestNeighborResizingGpu) { + cv::Mat input_mat; + cv::cvtColor(cv::imread(file::JoinPath("./", + "/mediapipe/calculators/" + "image/testdata/binary_mask.png")), + input_mat, cv::COLOR_BGR2RGBA); + + std::vector<std::pair<int, int>> output_dims{ + {256, 333}, {512, 512}, {1024, 1024}}; + + for (auto& output_dim : output_dims) { + std::vector<std::string> scale_modes{"FIT"}; //, "STRETCH"}; + for (const auto& scale_mode : scale_modes) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie<CalculatorGraphConfig>( + absl::Substitute(R"( + input_stream: "input_image" + input_stream: "image_size" + output_stream: "output_image" + + node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "input_image" + output_stream: "input_image_gpu" + } + + node { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:input_image_gpu" + input_stream: "OUTPUT_DIMENSIONS:image_size" + output_stream: "IMAGE_GPU:output_image_gpu" + options: { + [mediapipe.ImageTransformationCalculatorOptions.ext]: { + scale_mode: $0 + interpolation_mode: NEAREST + } + } + } + node { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "output_image_gpu" + output_stream: "output_image" + })", + scale_mode)); + ImageFrame input_image(ImageFormat::SRGBA, input_mat.size().width, + input_mat.size().height); + input_mat.copyTo(formats::MatView(&input_image)); + + std::vector<Packet> output_image_packets; + tool::AddVectorSink("output_image", &graph_config, &output_image_packets); + + CalculatorGraph
graph(graph_config); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", + MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "image_size", + MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_image_packets, testing::SizeIs(1)); + + const auto& output_image = output_image_packets[0].Get<ImageFrame>(); + ASSERT_EQ(output_dim.first, output_image.Width()); + ASSERT_EQ(output_dim.second, output_image.Height()); + + auto unique_input_values = computeUniqueValues<uint8_t>(input_mat); + auto unique_output_values = + computeUniqueValues<uint8_t>(formats::MatView(&output_image)); + EXPECT_THAT(unique_input_values, + ::testing::ContainerEq(unique_output_values)); + } + } +} + +TEST(ImageTransformationCalculatorTest, + NearestNeighborResizingWorksForFloatTexture) { + cv::Mat input_mat; + cv::cvtColor(cv::imread(file::JoinPath("./", + "/mediapipe/calculators/" + "image/testdata/binary_mask.png")), + input_mat, cv::COLOR_BGR2GRAY); + Packet input_image_packet = MakePacket<ImageFrame>( + ImageFormat::VEC32F1, input_mat.size().width, input_mat.size().height); + cv::Mat packet_mat_view = + formats::MatView(&(input_image_packet.Get<ImageFrame>())); + input_mat.convertTo(packet_mat_view, CV_32FC1, 1 / 255.f); + + std::vector<std::pair<int, int>> output_dims{ + {256, 333}, {512, 512}, {1024, 1024}}; + + for (auto& output_dim : output_dims) { + std::vector<std::string> scale_modes{"FIT"}; //, "STRETCH"}; + for (const auto& scale_mode : scale_modes) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie<CalculatorGraphConfig>( + absl::Substitute(R"( + input_stream: "input_image" + input_stream: "image_size" + output_stream: "output_image" + + node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "input_image" + output_stream: "input_image_gpu" + } + + node { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE_GPU:input_image_gpu" + input_stream: "OUTPUT_DIMENSIONS:image_size" + output_stream: "IMAGE_GPU:output_image_gpu" + options: { + [mediapipe.ImageTransformationCalculatorOptions.ext]: { + scale_mode: $0 + interpolation_mode: NEAREST + } + } + } + node { + calculator: "GpuBufferToImageFrameCalculator" + input_stream: "output_image_gpu" + output_stream: "output_image" + })", + scale_mode)); + + std::vector<Packet> output_image_packets; + tool::AddVectorSink("output_image", &graph_config, &output_image_packets); + + CalculatorGraph graph(graph_config); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", input_image_packet.At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "image_size", + MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0)))); + + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_THAT(output_image_packets, testing::SizeIs(1)); + + const auto& output_image = output_image_packets[0].Get<ImageFrame>(); + ASSERT_EQ(output_dim.first, output_image.Width()); + ASSERT_EQ(output_dim.second, output_image.Height()); + + auto unique_input_values = computeUniqueValues<float>(packet_mat_view); + auto unique_output_values = + computeUniqueValues<float>(formats::MatView(&output_image)); + EXPECT_THAT(unique_input_values, + ::testing::ContainerEq(unique_output_values)); + } + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc index 93ec9435f..0308b9b8c 100644 --- a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc +++
b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/log/absl_check.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -61,7 +62,7 @@ absl::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) { absl::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) { const ImageFrame& image_frame = cc->Inputs().Index(0).Get<ImageFrame>(); - CHECK_EQ(1, image_frame.ByteDepth()); + ABSL_CHECK_EQ(1, image_frame.ByteDepth()); std::unique_ptr<std::string> encoded_result = absl::make_unique<std::string>(); diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc index d8a3cb93b..1d4f980fe 100644 --- a/mediapipe/calculators/image/scale_image_calculator.cc +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -18,6 +18,8 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "libyuv/scale.h" @@ -293,7 +295,7 @@ absl::Status ScaleImageCalculator::InitializeFrameInfo(CalculatorContext* cc) { header->width = output_width_; header->height = output_height_; header->format = output_format_; - LOG(INFO) << "OUTPUTTING HEADER on stream"; + ABSL_LOG(INFO) << "OUTPUTTING HEADER on stream"; cc->Outputs() .Tag("VIDEO_HEADER") .Add(header.release(), Timestamp::PreStream()); @@ -393,10 +395,11 @@ absl::Status ScaleImageCalculator::Open(CalculatorContext* cc) { .SetHeader(Adopt(output_header.release())); has_header_ = true; } else { - LOG(WARNING) << "Stream had a VideoHeader which didn't have sufficient " - "information. " - "Dropping VideoHeader and trying to deduce needed " - "information."; + ABSL_LOG(WARNING) + << "Stream had a VideoHeader which didn't have sufficient " + "information.
" + "Dropping VideoHeader and trying to deduce needed " + "information."; input_width_ = 0; input_height_ = 0; if (!options_.has_input_format()) { @@ -507,7 +510,7 @@ absl::Status ScaleImageCalculator::ValidateImageFrame( absl::Status ScaleImageCalculator::ValidateYUVImage(CalculatorContext* cc, const YUVImage& yuv_image) { - CHECK_EQ(input_format_, ImageFormat::YCBCR420P); + ABSL_CHECK_EQ(input_format_, ImageFormat::YCBCR420P); if (!has_header_) { if (input_width_ != yuv_image.width() || input_height_ != yuv_image.height()) { diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc index 86a53ffc5..77b7c0ece 100644 --- a/mediapipe/calculators/image/scale_image_utils.cc +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -18,6 +18,7 @@ #include +#include "absl/log/absl_check.h" #include "absl/strings/str_split.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" @@ -40,10 +41,10 @@ absl::Status FindCropDimensions(int input_width, int input_height, // const std::string& max_aspect_ratio, // int* crop_width, int* crop_height, // int* col_start, int* row_start) { - CHECK(crop_width); - CHECK(crop_height); - CHECK(col_start); - CHECK(row_start); + ABSL_CHECK(crop_width); + ABSL_CHECK(crop_height); + ABSL_CHECK(col_start); + ABSL_CHECK(row_start); double min_aspect_ratio_q = 0.0; double max_aspect_ratio_q = 0.0; @@ -83,8 +84,8 @@ absl::Status FindCropDimensions(int input_width, int input_height, // } } - CHECK_LE(*crop_width, input_width); - CHECK_LE(*crop_height, input_height); + ABSL_CHECK_LE(*crop_width, input_width); + ABSL_CHECK_LE(*crop_height, input_height); return absl::OkStatus(); } @@ -96,8 +97,8 @@ absl::Status FindOutputDimensions(int input_width, // bool preserve_aspect_ratio, // int scale_to_multiple_of, // int* output_width, int* output_height) { - CHECK(output_width); - CHECK(output_height); + ABSL_CHECK(output_width); + ABSL_CHECK(output_height); if (target_max_area > 0 && input_width * input_height > target_max_area) { preserve_aspect_ratio = true; diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc index 81732f904..1194412a6 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc @@ -15,13 +15,13 @@ #include #include +#include "absl/log/absl_log.h" #include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" @@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator); absl::Status SegmentationSmoothingCalculator::GetContract( CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); cc->Inputs().Tag(kCurrentMaskTag).Set(); cc->Inputs().Tag(kPreviousMaskTag).Set(); @@ -273,7 +273,7 @@ absl::Status SegmentationSmoothingCalculator::RenderGpu(CalculatorContext* cc) { const auto& previous_frame = cc->Inputs().Tag(kPreviousMaskTag).Get(); if (previous_frame.format() != current_frame.format()) { - LOG(ERROR) << 
"Warning: mixing input format types. "; + ABSL_LOG(ERROR) << "Warning: mixing input format types. "; } auto previous_texture = gpu_helper_.CreateSourceTexture(previous_frame); diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc index eeb812cb7..0f5152fc3 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator_test.cc @@ -14,6 +14,7 @@ #include +#include "absl/log/absl_log.h" #include "mediapipe/calculators/image/segmentation_smoothing_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" @@ -169,7 +170,7 @@ void RunTest(bool use_gpu, float mix_ratio, cv::Mat& test_result) { } } } else { - LOG(ERROR) << "invalid ratio"; + ABSL_LOG(ERROR) << "invalid ratio"; } } diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index e20621e8d..d451cd21c 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -14,13 +14,13 @@ #include +#include "absl/log/absl_log.h" #include "mediapipe/calculators/image/set_alpha_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" @@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase { REGISTER_CALCULATOR(SetAlphaCalculator); absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; @@ -268,7 +268,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get(); const cv::Mat input_mat = formats::MatView(&input_frame); if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) { - LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; + ABSL_LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; } // Setup destination image @@ -328,7 +328,7 @@ absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { cc->Inputs().Tag(kInputFrameTagGpu).Get(); if (!(input_frame.format() == mediapipe::GpuBufferFormat::kBGRA32 || input_frame.format() == mediapipe::GpuBufferFormat::kRGB24)) { - LOG(ERROR) << "Only RGB or RGBA input image supported"; + ABSL_LOG(ERROR) << "Only RGB or RGBA input image supported"; } auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); diff --git a/mediapipe/calculators/image/testdata/BUILD b/mediapipe/calculators/image/testdata/BUILD index da192b513..89df807c7 100644 --- a/mediapipe/calculators/image/testdata/BUILD +++ b/mediapipe/calculators/image/testdata/BUILD @@ -18,6 +18,7 @@ licenses(["notice"]) filegroup( name = "test_images", srcs = [ + "binary_mask.png", "dino.jpg", "dino_quality_50.jpg", "dino_quality_80.jpg", diff --git a/mediapipe/calculators/image/testdata/binary_mask.png b/mediapipe/calculators/image/testdata/binary_mask.png new file mode 100644 index 
000000000..aa38e5d3f Binary files /dev/null and b/mediapipe/calculators/image/testdata/binary_mask.png differ diff --git a/mediapipe/calculators/image/yuv_to_image_calculator.cc b/mediapipe/calculators/image/yuv_to_image_calculator.cc index e84eee74e..6a82877c3 100644 --- a/mediapipe/calculators/image/yuv_to_image_calculator.cc +++ b/mediapipe/calculators/image/yuv_to_image_calculator.cc @@ -38,7 +38,7 @@ std::string FourCCToString(libyuv::FourCC fourcc) { buf[0] = (fourcc >> 24) & 0xff; buf[1] = (fourcc >> 16) & 0xff; buf[2] = (fourcc >> 8) & 0xff; - buf[3] = (fourcc)&0xff; + buf[3] = (fourcc) & 0xff; buf[4] = 0; return std::string(buf); } diff --git a/mediapipe/calculators/internal/BUILD b/mediapipe/calculators/internal/BUILD index a92a2f252..a5d82e134 100644 --- a/mediapipe/calculators/internal/BUILD +++ b/mediapipe/calculators/internal/BUILD @@ -31,12 +31,14 @@ mediapipe_proto_library( cc_library( name = "callback_packet_calculator", srcs = ["callback_packet_calculator.cc"], + hdrs = ["callback_packet_calculator.h"], visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_registry", "//mediapipe/framework:output_side_packet", + "@com_google_absl//absl/status", ], alwayslink = 1, ) diff --git a/mediapipe/calculators/internal/callback_packet_calculator.cc b/mediapipe/calculators/internal/callback_packet_calculator.cc index cc153483e..aa86c0617 100644 --- a/mediapipe/calculators/internal/callback_packet_calculator.cc +++ b/mediapipe/calculators/internal/callback_packet_calculator.cc @@ -11,10 +11,12 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/calculators/internal/callback_packet_calculator.h" #include #include +#include "absl/status/status.h" #include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" // NOLINT #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_registry.h" @@ -39,64 +41,55 @@ void DumpPostStreamPacket(Packet* post_stream_packet, const Packet& packet) { *post_stream_packet = packet; } } + } // namespace -// Creates a callback which takes a packet and stores it either in a -// vector of packets or stores only the packet at PostStream timestamp. -// The kind of callback is controlled by an option. The callback is -// a std::function and is directly usable by CallbackCalculator. -// Since the options for the packet generator include a serialized pointer -// value, the resulting callback is only valid on the original machine -// while that pointer is still alive. 
-class CallbackPacketCalculator : public CalculatorBase { - public: - static absl::Status GetContract(CalculatorContract* cc) { - const auto& options = cc->Options<CallbackPacketCalculatorOptions>(); - switch (options.type()) { - case CallbackPacketCalculatorOptions::VECTOR_PACKET: - case CallbackPacketCalculatorOptions::POST_STREAM_PACKET: - cc->OutputSidePackets() - .Index(0) - .Set<std::function<void(const Packet&)>>(); - break; - default: - return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "Invalid type of callback to produce."; - } - return absl::OkStatus(); - } - - absl::Status Open(CalculatorContext* cc) override { - const auto& options = cc->Options<CallbackPacketCalculatorOptions>(); - void* ptr; - if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) { +absl::Status CallbackPacketCalculator::GetContract(CalculatorContract* cc) { + const auto& options = cc->Options<CallbackPacketCalculatorOptions>(); + switch (options.type()) { + case CallbackPacketCalculatorOptions::VECTOR_PACKET: + case CallbackPacketCalculatorOptions::POST_STREAM_PACKET: + cc->OutputSidePackets() + .Index(0) + .Set<std::function<void(const Packet&)>>(); + break; + default: return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "Stored pointer value in options is invalid."; - } - switch (options.type()) { - case CallbackPacketCalculatorOptions::VECTOR_PACKET: - cc->OutputSidePackets().Index(0).Set( - MakePacket<std::function<void(const Packet&)>>(std::bind( - &DumpToVector, reinterpret_cast<std::vector<Packet>*>(ptr), - std::placeholders::_1))); - break; - case CallbackPacketCalculatorOptions::POST_STREAM_PACKET: - cc->OutputSidePackets().Index(0).Set( - MakePacket<std::function<void(const Packet&)>>( - std::bind(&DumpPostStreamPacket, reinterpret_cast<Packet*>(ptr), - std::placeholders::_1))); - break; - default: - return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) - << "Invalid type to dump into."; - } - return absl::OkStatus(); + << "Invalid type of callback to produce."; } - - absl::Status Process(CalculatorContext* cc) override { - return absl::OkStatus(); + return absl::OkStatus(); +} - absl::Status Process(CalculatorContext* cc) override { - return absl::OkStatus(); +absl::Status CallbackPacketCalculator::Open(CalculatorContext* cc) { + const auto& options = cc->Options<CallbackPacketCalculatorOptions>(); + void* ptr; + if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) { + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Stored pointer value in options is invalid."; } - } -}; + switch (options.type()) { + case CallbackPacketCalculatorOptions::VECTOR_PACKET: + cc->OutputSidePackets().Index(0).Set( + MakePacket<std::function<void(const Packet&)>>(std::bind( + &DumpToVector, reinterpret_cast<std::vector<Packet>*>(ptr), + std::placeholders::_1))); + break; + case CallbackPacketCalculatorOptions::POST_STREAM_PACKET: + cc->OutputSidePackets().Index(0).Set( + MakePacket<std::function<void(const Packet&)>>( + std::bind(&DumpPostStreamPacket, reinterpret_cast<Packet*>(ptr), + std::placeholders::_1))); + break; + default: + return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) + << "Invalid type to dump into."; + } + return absl::OkStatus(); +} + +absl::Status CallbackPacketCalculator::Process(CalculatorContext* cc) { + return absl::OkStatus(); +} REGISTER_CALCULATOR(CallbackPacketCalculator); diff --git a/mediapipe/calculators/internal/callback_packet_calculator.h b/mediapipe/calculators/internal/callback_packet_calculator.h new file mode 100644 index 000000000..e0b170e36 --- /dev/null +++ b/mediapipe/calculators/internal/callback_packet_calculator.h @@ -0,0 +1,39 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_INTERNAL_CALLBACK_PACKET_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_INTERNAL_CALLBACK_PACKET_CALCULATOR_H_ + +#include "absl/status/status.h" +#include "mediapipe/framework/calculator_base.h" + +namespace mediapipe { + +// Creates a callback which takes a packet and stores it either in a +// vector of packets or stores only the packet at PostStream timestamp. +// The kind of callback is controlled by an option. The callback is +// a std::function and is directly usable by CallbackCalculator. +// Since the options for the packet generator include a serialized pointer +// value, the resulting callback is only valid on the original machine +// while that pointer is still alive. +class CallbackPacketCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_INTERNAL_CALLBACK_PACKET_CALCULATOR_H_ diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 5d52dda0f..e74ecffe6 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -87,6 +87,7 @@ cc_library( "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/util:time_series_util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -181,6 +182,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", ], alwayslink = 1, @@ -198,6 +200,7 @@ cc_test( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/log:absl_check", "@org_tensorflow//tensorflow/lite/c:common", ], ) @@ -228,7 +231,6 @@ cc_library( "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -280,7 +282,6 @@ cc_library( "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", ], alwayslink = 1, ) @@ -394,7 +395,7 @@ mediapipe_proto_library( # If you want to have precise control of which implementations to include (e.g. for strict binary # size concerns), depend on those implementations directly, and do not depend on # :inference_calculator. -# In all cases, use "InferenceCalulator" in your graphs. +# In all cases, use "InferenceCalculator" in your graphs. 
cc_library_with_tflite( name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], @@ -447,6 +448,7 @@ cc_library( "//mediapipe/framework/deps:file_path", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/util/tflite:tflite_gpu_runner", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -476,6 +478,7 @@ cc_library( "//mediapipe/gpu:gpu_buffer", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/util/tflite:config", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", @@ -622,6 +625,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", + "//mediapipe/gpu:gpu_origin_proto", ], ) @@ -651,10 +655,21 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/gpu:gpu_buffer_format", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:str_format", ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": ["tensor_converter_calculator_gpu_deps"], + }) + select({ + "//mediapipe:apple": [ + "//third_party/apple_frameworks:MetalKit", + ], + "//conditions:default": [], }), alwayslink = 1, ) @@ -696,6 +711,7 @@ cc_test( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/memory", @@ -734,6 +750,8 @@ cc_library( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ] + selects.with_or({ @@ -790,6 +808,7 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) @@ -982,6 +1001,8 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/gpu:gpu_origin_cc_proto", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [":image_to_tensor_calculator_gpu_deps"], @@ -1052,6 +1073,7 @@ cc_test( "testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", "testdata/image_to_tensor/noop_except_range.png", ], + tags = ["not_run:arm"], deps = [ ":image_to_tensor_calculator", ":image_to_tensor_converter", @@ -1073,6 +1095,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/util:image_test_utils", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1200,6 +1223,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", 
"//mediapipe/gpu:shader_util", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], }), diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index 47617b375..eaf593a69 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -282,18 +283,23 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { if (options.has_volume_gain_db()) { gain_ = pow(10, options.volume_gain_db() / 20.0); } - RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ - !kAudioIn(cc).Header().IsEmpty()) - << "Must either specify the time series header of the \"AUDIO\" stream " - "or have the \"SAMPLE_RATE\" stream connected."; - if (!kAudioIn(cc).Header().IsEmpty()) { - mediapipe::TimeSeriesHeader input_header; - MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid( - kAudioIn(cc).Header(), &input_header)); - if (stream_mode_) { - MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate())); - } else { - source_sample_rate_ = input_header.sample_rate(); + if (options.has_source_sample_rate()) { + source_sample_rate_ = options.source_sample_rate(); + } else { + RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ + !kAudioIn(cc).Header().IsEmpty()) + << "Must either specify the time series header of the \"AUDIO\" stream " + "or have the \"SAMPLE_RATE\" stream connected."; + if (!kAudioIn(cc).Header().IsEmpty()) { + mediapipe::TimeSeriesHeader input_header; + MP_RETURN_IF_ERROR( + mediapipe::time_series_util::FillTimeSeriesHeaderIfValid( + kAudioIn(cc).Header(), &input_header)); + if (stream_mode_) { + MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate())); + } else { + source_sample_rate_ = input_header.sample_rate(); + } } } AppendZerosToSampleBuffer(padding_samples_before_); @@ -343,7 +349,7 @@ absl::Status AudioToTensorCalculator::Process(CalculatorContext* cc) { return absl::InvalidArgumentError( "The audio data should be stored in column-major."); } - CHECK(channels_match || mono_output); + ABSL_CHECK(channels_match || mono_output); const Matrix& input = channels_match ? input_frame // Mono mixdown. : input_frame.colwise().mean(); @@ -452,7 +458,7 @@ absl::Status AudioToTensorCalculator::SetupStreamingResampler( } void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) { - CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`. + ABSL_CHECK_GE(num_samples, 0); // Ensured by `UpdateContract`. if (num_samples == 0) { return; } diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index 5b7d61bcb..948c82a36 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions { // The volume gain, measured in dB. // Scale the input audio amplitude by 10^(volume_gain_db/20). optional double volume_gain_db = 12; + + // The source number of samples per second (hertz) of the input audio buffers. 
+ optional double source_sample_rate = 13; } diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc b/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc index b56122805..12db1493c 100644 --- a/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc @@ -22,7 +22,6 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" @@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors( input_tensors.reserve(kNumInputTensorsForBert); for (int i = 0; i < kNumInputTensorsForBert; ++i) { input_tensors.push_back( - {Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})}); + {Tensor::ElementType::kInt32, + Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)}); } std::memcpy(input_tensors[input_ids_tensor_index_] .GetCpuWriteView() diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto b/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto index 5dc9815a1..b2dc5578f 100644 --- a/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc b/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc index 5797cc31c..6c5e5cc4f 100644 --- a/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc +++ b/mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -65,7 +66,7 @@ template <typename T> Tensor MakeTensor(std::initializer_list<int> shape, std::initializer_list<T> values) { Tensor tensor(TensorElementType<T>::value, shape); - CHECK_EQ(values.size(), tensor.shape().num_elements()) + ABSL_CHECK_EQ(values.size(), tensor.shape().num_elements()) << "The size of `values` is incompatible with `shape`"; absl::c_copy(values, tensor.GetCpuWriteView().buffer<T>()); return tensor; diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index 344e12da6..26fb1d868 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -16,6 +16,7 @@ #include #include +#include "absl/log/absl_log.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" @@ -284,9 +285,9 @@ class ImageToTensorCalculator : public Node { cc, GetBorderMode(options_.border_mode()), GetOutputTensorType(/*uses_gpu=*/false, params_))); #else - LOG(FATAL) << "Cannot create image to tensor CPU converter since " - "MEDIAPIPE_DISABLE_OPENCV is defined and " - "MEDIAPIPE_ENABLE_HALIDE is not defined."; + ABSL_LOG(FATAL) << "Cannot create image to tensor CPU converter since " + "MEDIAPIPE_DISABLE_OPENCV
is defined and " + "MEDIAPIPE_ENABLE_HALIDE is not defined."; #endif // !MEDIAPIPE_DISABLE_HALIDE } } diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 409b8623c..7017c1e3a 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -18,6 +18,7 @@ #include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/strings/substitute.h" @@ -205,7 +206,7 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { } else if (image_channels == 1) { return ImageFormat::GRAY8; } - CHECK(false) << "Unsupported input image channles: " << image_channels; + ABSL_CHECK(false) << "Unsupported input image channles: " << image_channels; } Packet MakeImageFramePacket(cv::Mat input) { diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc index 093f50d76..6f6f6f11c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc @@ -95,7 +95,8 @@ absl::Status FrameBufferProcessor::Convert(const mediapipe::Image& input, static_cast(range_max) == 255); } - auto input_frame = input.GetGpuBuffer().GetReadView(); + auto input_frame = + input.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView(); const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); FrameBuffer::Dimension output_dimension{/*width=*/output_shape.dims[2], diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 165df8970..465e7e0bc 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -22,6 +22,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" @@ -259,7 +260,7 @@ class GlProcessor : public ImageToTensorConverter { // error. So in that case, we'll grab the transpose of our original matrix // and send that instead. const auto gl_context = mediapipe::GlContext::GetCurrent(); - LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread."; + ABSL_LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread."; if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) { GetTransposedRotatedSubRectToRectTransformMatrix( sub_rect, texture.width(), texture.height(), flip_horizontaly, diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 78a0039bc..82f4ec80a 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -88,6 +88,20 @@ message InferenceCalculatorOptions { // serialized model is invalid or missing. optional string serialized_model_dir = 7; + enum CacheWritingBehavior { + // Do not write any caches. + NO_WRITE = 0; + + // Try to write caches, log on failure. + TRY_WRITE = 1; + + // Write caches or return an error if write fails. 
+ WRITE_OR_ERROR = 2; + } + // Specifies how GPU caches are written to disk. + optional CacheWritingBehavior cache_writing_behavior = 10 + [default = WRITE_OR_ERROR]; + // Unique token identifying the model. Used in conjunction with // "serialized_model_dir". It is the caller's responsibility to ensure // there is no clash of the tokens. diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 8aee46185..28ea45de0 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -26,6 +27,7 @@ #include "mediapipe/util/tflite/tflite_gpu_runner.h" #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) +#include "absl/log/absl_log.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/filesystem.h" @@ -68,13 +70,21 @@ class InferenceCalculatorGlAdvancedImpl const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& gpu_delegate_options); absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; - absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; + // Writes caches to disk based on |cache_writing_behavior_|. + absl::Status SaveGpuCachesBasedOnBehavior( + tflite::gpu::TFLiteGPURunner* gpu_runner) const; + bool UseSerializedModel() const { return use_serialized_model_; } private: + // Writes caches to disk, returns error on failure. + absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; + bool use_kernel_caching_ = false; std::string cached_kernel_filename_; bool use_serialized_model_ = false; std::string serialized_model_path_; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::CacheWritingBehavior + cache_writing_behavior_; }; // Helper class that wraps everything related to GPU inference acceleration. @@ -150,8 +160,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( } absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() { - MP_RETURN_IF_ERROR( - on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get())); return gpu_helper_.RunInGlContext([this]() -> absl::Status { tflite_gpu_runner_.reset(); return absl::OkStatus(); @@ -226,9 +234,15 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner( tflite_gpu_runner_->GetOutputShapes()[i].c}; } + if (on_disk_cache_helper_.UseSerializedModel()) { + tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel(); + } + MP_RETURN_IF_ERROR( on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get())); - return tflite_gpu_runner_->Build(); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); + return on_disk_cache_helper_.SaveGpuCachesBasedOnBehavior( + tflite_gpu_runner_.get()); } #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) @@ -257,9 +271,36 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( mediapipe::file::JoinPath(gpu_delegate_options.serialized_model_dir(), gpu_delegate_options.model_token()); } + cache_writing_behavior_ = gpu_delegate_options.has_cache_writing_behavior() + ? 
gpu_delegate_options.cache_writing_behavior() + : mediapipe::InferenceCalculatorOptions:: + Delegate::Gpu::WRITE_OR_ERROR; return absl::OkStatus(); } +absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper:: + SaveGpuCachesBasedOnBehavior( + tflite::gpu::TFLiteGPURunner* gpu_runner) const { + switch (cache_writing_behavior_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::NO_WRITE: + return absl::OkStatus(); + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::TRY_WRITE: { + auto status = SaveGpuCaches(gpu_runner); + if (!status.ok()) { + ABSL_LOG_FIRST_N(WARNING, 1) << "Failed to save gpu caches: " << status; + } + return absl::OkStatus(); + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::WRITE_OR_ERROR: + return SaveGpuCaches(gpu_runner); + default: + ABSL_LOG_FIRST_N(ERROR, 1) + << "Unknown cache writing behavior: " + << static_cast<int>(cache_writing_behavior_); + return absl::InvalidArgumentError("Unknown cache writing behavior."); + } +} + absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { @@ -314,6 +355,12 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( return absl::OkStatus(); } +absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper:: + SaveGpuCachesBasedOnBehavior( + tflite::gpu::TFLiteGPURunner* gpu_runner) const { + return absl::OkStatus(); +} + absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::ReadGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index fba18a81c..253091a8a 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" @@ -74,7 +75,7 @@ tflite::gpu::BHWC BhwcFromTensorShape(const Tensor::Shape& shape) { break; default: // Handles 0 and >4. - LOG(FATAL) + ABSL_LOG(FATAL) << "Dimensions size must be in range [1,4] for GPU inference, but " << shape.dims.size() << " is provided"; } diff --git a/mediapipe/calculators/tensor/inference_calculator_test.cc b/mediapipe/calculators/tensor/inference_calculator_test.cc index 3662af391..2e75bb976 100644 --- a/mediapipe/calculators/tensor/inference_calculator_test.cc +++ b/mediapipe/calculators/tensor/inference_calculator_test.cc @@ -16,7 +16,7 @@ #include #include -#include "absl/log/check.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc index a2b8a9285..b727f179d 100644 --- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc @@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run( CalculatorContext* cc, const std::vector<Tensor>& input_tensors) { // Read CPU input into tensors. RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size()); + + // If the input tensors have dynamic shape, then the tensors need to be + // resized and reallocated before we can copy the tensor values.
+ bool resized_tensor_shapes = false; + for (int i = 0; i < input_tensors.size(); ++i) { + if (input_tensors[i].shape().is_dynamic) { + interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims); + resized_tensor_shapes = true; + } + } + // Reallocation is needed for memory sanity. + if (resized_tensor_shapes) interpreter_->AllocateTensors(); + for (int i = 0; i < input_tensors.size(); ++i) { const TfLiteType input_tensor_type = interpreter_->tensor(interpreter_->inputs()[i])->type; diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc b/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc index 92a5f0266..8276462ff 100644 --- a/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc @@ -20,7 +20,6 @@ #include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/port.h" @@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) { // not found in the tokenizer vocab. std::vector<Tensor> result; result.push_back( - {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})}); + {Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})}); std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(), input_tokens.data(), input_tokens.size() * sizeof(int32_t)); kTensorsOut(cc).Send(std::move(result)); diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto b/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto index 793067a80..ef7ad0472 100644 --- a/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index c1bd92968..af5b71d3a 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -12,9 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" @@ -22,7 +27,8 @@ #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/util/resource_util.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/gpu_origin.pb.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" @@ -43,12 +49,36 @@ #endif // !MEDIAPIPE_DISABLE_GPU namespace { + constexpr int kWorkgroupSize = 8; // Block size for GPU shader. // Commonly used to compute the number of blocks to launch in a kernel.
int NumGroups(const int size, const int group_size) { // NOLINT return (size + group_size - 1) / group_size; } +absl::StatusOr<bool> ShouldFlipVertically( + const mediapipe::TensorConverterCalculatorOptions& options) { + if (!options.has_gpu_origin()) { + return options.flip_vertically(); + } + + switch (options.gpu_origin()) { + case mediapipe::GpuOrigin::TOP_LEFT: + return false; + case mediapipe::GpuOrigin::DEFAULT: + case mediapipe::GpuOrigin::CONVENTIONAL: + // TOP_LEFT on Metal, BOTTOM_LEFT on OpenGL. +#ifdef __APPLE__ + return false; +#else + return true; +#endif + default: + return absl::InvalidArgumentError( + absl::StrFormat("Unhandled GPU origin %i", options.gpu_origin())); + } +} + typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrixXf; typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor> @@ -58,6 +88,7 @@ constexpr char kImageFrameTag[] = "IMAGE"; constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kMatrixTag[] = "MATRIX"; + } // namespace namespace mediapipe { @@ -378,16 +409,27 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { // Get input image sizes. const auto& input = cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>(); - mediapipe::ImageFormat::Format format = - mediapipe::ImageFormatForGpuBufferFormat(input.format()); + mediapipe::GpuBufferFormat format = input.format(); const bool include_alpha = (max_num_channels_ == 4); const bool single_channel = (max_num_channels_ == 1); - if (!(format == mediapipe::ImageFormat::GRAY8 || - format == mediapipe::ImageFormat::SRGB || - format == mediapipe::ImageFormat::SRGBA)) - RET_CHECK_FAIL() << "Unsupported GPU input format."; - if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) - RET_CHECK_FAIL() << "Num input channels is less than desired output."; + + RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 || + format == mediapipe::GpuBufferFormat::kRGB24 || + format == mediapipe::GpuBufferFormat::kRGBA32 || + format == mediapipe::GpuBufferFormat::kRGBAFloat128 || + format == mediapipe::GpuBufferFormat::kRGBAHalf64 || + format == mediapipe::GpuBufferFormat::kGrayFloat32 || + format == mediapipe::GpuBufferFormat::kGrayHalf16 || + format == mediapipe::GpuBufferFormat::kOneComponent8) + << "Unsupported GPU input format: " << static_cast<int>(format); + if (include_alpha) { + RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 || + format == mediapipe::GpuBufferFormat::kRGBA32 || + format == mediapipe::GpuBufferFormat::kRGBAFloat128 || + format == mediapipe::GpuBufferFormat::kRGBAHalf64) + << "Num input channels is less than desired output, input format: " + << static_cast<int>(format); + } #if MEDIAPIPE_METAL_ENABLED id<MTLDevice> device = gpu_helper_.mtlDevice; @@ -582,7 +624,7 @@ absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc) { if (options.has_output_tensor_float_range()) { output_range_.emplace(options.output_tensor_float_range().min(), options.output_tensor_float_range().max()); - CHECK_GT(output_range_->second, output_range_->first); + ABSL_CHECK_GT(output_range_->second, output_range_->first); } // Custom div and sub values. @@ -593,16 +635,16 @@ absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc) { } // Get y-flip mode. - flip_vertically_ = options.flip_vertically(); + ASSIGN_OR_RETURN(flip_vertically_, ShouldFlipVertically(options)); // Get row_major_matrix mode. row_major_matrix_ = options.row_major_matrix(); // Get desired way to handle input channels.
max_num_channels_ = options.max_num_channels(); - CHECK_GE(max_num_channels_, 1); - CHECK_LE(max_num_channels_, 4); - CHECK_NE(max_num_channels_, 2); + ABSL_CHECK_GE(max_num_channels_, 1); + ABSL_CHECK_LE(max_num_channels_, 4); + ABSL_CHECK_NE(max_num_channels_, 2); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.proto b/mediapipe/calculators/tensor/tensor_converter_calculator.proto index 97c2154a0..194dd417e 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.proto +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.proto @@ -3,6 +3,7 @@ syntax = "proto2"; package mediapipe; import "mediapipe/framework/calculator.proto"; +import "mediapipe/gpu/gpu_origin.proto"; // Full Example: // @@ -43,8 +44,14 @@ message TensorConverterCalculatorOptions { // with a coordinate system where the origin is at the bottom-left corner // (e.g., in OpenGL) whereas the ML model expects an image with a top-left // origin. + // Prefer gpu_origin over this field. optional bool flip_vertically = 2 [default = false]; + // Determines when the input image should be flipped vertically. + // See GpuOrigin.Mode for more information. + // If unset, falls back to flip_vertically for backwards compatibility. + optional GpuOrigin.Mode gpu_origin = 10; + // Controls how many channels of the input image get passed through to the // tensor. Valid values are 1,3,4 only. Ignored for iOS GPU. optional int32 max_num_channels = 3 [default = 3]; diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc index 2cfbd3d1e..b3df01522 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator_test.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include +#include #include #include "absl/memory/memory.h" @@ -24,8 +27,10 @@ #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" // NOLINT #include "mediapipe/framework/tool/validate_type.h" @@ -40,7 +45,6 @@ constexpr char kTransposeOptionsString[] = } // namespace using RandomEngine = std::mt19937_64; -using testing::Eq; const uint32_t kSeed = 1234; const int kNumSizes = 8; const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2}, @@ -110,12 +114,12 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixColMajor) { // Wait until the calculator done processing. MP_ASSERT_OK(graph_->WaitUntilIdle()); - EXPECT_EQ(1, output_packets.size()); + ASSERT_EQ(output_packets.size(), 1); // Get and process results. 
const std::vector<Tensor>& tensor_vec = output_packets[0].Get<std::vector<Tensor>>(); - EXPECT_EQ(1, tensor_vec.size()); + ASSERT_EQ(tensor_vec.size(), 1); const Tensor* tensor = &tensor_vec[0]; EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); @@ -127,7 +131,7 @@ auto tensor_buffer = view.buffer<float>(); for (int i = 0; i < num_rows * num_columns; ++i) { const float expected = uniform_dist(random); - EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; + EXPECT_FLOAT_EQ(tensor_buffer[i], expected) << "at i = " << i; } // Fully close graph at end, otherwise calculator+tensors are destroyed @@ -172,12 +176,12 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixRowMajor) { // Wait until the calculator done processing. MP_ASSERT_OK(graph_->WaitUntilIdle()); - EXPECT_EQ(1, output_packets.size()); + ASSERT_EQ(output_packets.size(), 1); // Get and process results. const std::vector<Tensor>& tensor_vec = output_packets[0].Get<std::vector<Tensor>>(); - EXPECT_EQ(1, tensor_vec.size()); + ASSERT_EQ(tensor_vec.size(), 1); const Tensor* tensor = &tensor_vec[0]; EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); @@ -189,7 +193,7 @@ auto tensor_buffer = view.buffer<float>(); for (int i = 0; i < num_rows * num_columns; ++i) { const float expected = uniform_dist(random); - EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i; + EXPECT_EQ(tensor_buffer[i], expected) << "at i = " << i; } // Fully close graph at end, otherwise calculator+tensors are destroyed @@ -239,12 +243,12 @@ TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) { // Get and process results. const std::vector<Tensor>& tensor_vec = output_packets[0].Get<std::vector<Tensor>>(); - EXPECT_EQ(1, tensor_vec.size()); + ASSERT_EQ(tensor_vec.size(), 1); const Tensor* tensor = &tensor_vec[0]; EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type()); auto view = tensor->GetCpuReadView(); - EXPECT_FLOAT_EQ(67.0f, *view.buffer<float>()); + EXPECT_FLOAT_EQ(*view.buffer<float>(), 67.0f); // Fully close graph at end, otherwise calculator+tensors are destroyed // after calling WaitUntilDone(). @@ -259,25 +263,22 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) { for (std::pair<float, float> range : range_values) { CalculatorGraph graph; CalculatorGraphConfig graph_config = - mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( - absl::Substitute(R"( - input_stream: "input_image" - node { - calculator: "TensorConverterCalculator" - input_stream: "IMAGE:input_image" - output_stream: "TENSORS:tensor" - options { - [mediapipe.TensorConverterCalculatorOptions.ext] { - output_tensor_float_range { - min: $0 - max: $1 + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute( + R"pb( + input_stream: "input_image" + node { + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + output_tensor_float_range { min: $0 max: $1 } + } + } + } - } - } - } - )", - /*$0=*/range.first, - /*$1=*/range.second)); + )pb", + /*$0=*/range.first, + /*$1=*/range.second)); std::vector<Packet> output_packets; tool::AddVectorSink("tensor", &graph_config, &output_packets); @@ -292,26 +293,23 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) { // Wait until the calculator finishes processing. MP_ASSERT_OK(graph.WaitUntilIdle()); - EXPECT_THAT(output_packets.size(), Eq(1)); + ASSERT_EQ(output_packets.size(), 1); // Get and process results.
const std::vector& tensor_vec = output_packets[0].Get>(); - EXPECT_THAT(tensor_vec.size(), Eq(1)); + ASSERT_EQ(tensor_vec.size(), 1); const Tensor* tensor = &tensor_vec[0]; // Calculate the expected normalized value: - float normalized_value = + float expected_value = range.first + (200 * (range.second - range.first)) / 255.0; - EXPECT_THAT(tensor->element_type(), Eq(Tensor::ElementType::kFloat32)); + EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32); auto view = tensor->GetCpuReadView(); - float dataf = *view.buffer(); - EXPECT_THAT( - normalized_value, - testing::FloatNear(dataf, 2.0f * std::abs(dataf) * - std::numeric_limits::epsilon())); + float actual_value = *view.buffer(); + EXPECT_FLOAT_EQ(actual_value, expected_value); // Fully close graph at end, otherwise calculator+tensors are destroyed // after calling WaitUntilDone(). @@ -320,4 +318,113 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) { } } +TEST_F(TensorConverterCalculatorTest, FlipVertically) { + CalculatorGraph graph; + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input_image" + node { + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + flip_vertically: true + output_tensor_float_range { min: 0 max: 255 } + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.StartRun({})); + auto input_image = absl::make_unique(ImageFormat::GRAY8, 1, 2); + cv::Mat mat = mediapipe::formats::MatView(input_image.get()); + constexpr uint8_t kY0Value = 100; + constexpr uint8_t kY1Value = 200; + mat.at(0, 0) = kY0Value; + mat.at(1, 0) = kY1Value; // Note: y, x! + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_image", Adopt(input_image.release()).At(Timestamp(0)))); + + // Wait until the calculator finishes processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(output_packets.size(), 1); + + // Get and process results. + const std::vector& tensor_vec = + output_packets[0].Get>(); + ASSERT_EQ(tensor_vec.size(), 1); + + const Tensor* tensor = &tensor_vec[0]; + + EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32); + const float* dataf = tensor->GetCpuReadView().buffer(); + EXPECT_EQ(static_cast(roundf(dataf[0])), kY1Value); // Y0, Y1 flipped! + EXPECT_EQ(static_cast(roundf(dataf[1])), kY0Value); + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("input_image")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +TEST_F(TensorConverterCalculatorTest, GpuOriginOverridesFlipVertically) { + CalculatorGraph graph; + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input_image" + node { + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.TensorConverterCalculatorOptions.ext] { + flip_vertically: true + gpu_origin: TOP_LEFT + output_tensor_float_range { min: 0 max: 255 } + } + } + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("tensor", &graph_config, &output_packets); + + // Run the graph. 
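+  // Same protocol as the FlipVertically test above: initialize and start the
+  // graph, push a single two-pixel GRAY8 frame, wait until idle to inspect
+  // the output, then close the stream and drain the graph.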
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  MP_ASSERT_OK(graph.StartRun({}));
+  auto input_image = absl::make_unique<ImageFrame>(ImageFormat::GRAY8, 1, 2);
+  cv::Mat mat = mediapipe::formats::MatView(input_image.get());
+  constexpr uint8_t kY0Value = 100;
+  constexpr uint8_t kY1Value = 200;
+  mat.at<uint8_t>(0, 0) = kY0Value;
+  mat.at<uint8_t>(1, 0) = kY1Value;  // Note: y, x!
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input_image", Adopt(input_image.release()).At(Timestamp(0))));
+
+  // Wait until the calculator finishes processing.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  ASSERT_EQ(output_packets.size(), 1);
+
+  // Get and process results.
+  const std::vector<Tensor>& tensor_vec =
+      output_packets[0].Get<std::vector<Tensor>>();
+  ASSERT_EQ(tensor_vec.size(), 1);
+
+  const Tensor* tensor = &tensor_vec[0];
+
+  EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
+  const float* dataf = tensor->GetCpuReadView().buffer<float>();
+  EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY0Value);  // Not flipped!
+  EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY1Value);
+
+  // Fully close graph at end, otherwise calculator+tensors are destroyed
+  // after calling WaitUntilDone().
+  MP_ASSERT_OK(graph.CloseInputStream("input_image"));
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc
index c8dd0e2a0..6d42226b9 100644
--- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc
@@ -15,6 +15,7 @@
 #include
 #include
+#include "absl/log/absl_log.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
@@ -83,7 +84,7 @@ void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
 void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
                                int num_boxes, float* raw_anchors) {
-  CHECK_EQ(anchors.size(), num_boxes);
+  ABSL_CHECK_EQ(anchors.size(), num_boxes);
   int box = 0;
   for (const auto& anchor : anchors) {
     raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center();
@@ -256,6 +257,7 @@ class TensorsToDetectionsCalculator : public Node {
   bool gpu_inited_ = false;
   bool gpu_input_ = false;
+  bool gpu_has_enough_work_groups_ = true;
   bool anchors_init_ = false;
 };
 MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
@@ -291,7 +293,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
 absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
   auto output_detections = absl::make_unique<std::vector<Detection>>();
   bool gpu_processing = false;
-  if (CanUseGpu()) {
+  if (CanUseGpu() && gpu_has_enough_work_groups_) {
     // Use GPU processing only if at least one input tensor is already on GPU
     // (to avoid CPU->GPU overhead).
     for (const auto& tensor : *kInTensors(cc)) {
@@ -321,11 +323,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
     RET_CHECK(!has_custom_box_indices_);
   }

-  if (gpu_processing) {
-    if (!gpu_inited_) {
-      MP_RETURN_IF_ERROR(GpuInit(cc));
+  if (gpu_processing && !gpu_inited_) {
+    auto status = GpuInit(cc);
+    if (status.ok()) {
       gpu_inited_ = true;
+    } else if (status.code() == absl::StatusCode::kFailedPrecondition) {
+      // For an initialization error caused by a hardware limitation, fall
+      // back to CPU processing.
+      ABSL_LOG(WARNING) << status.message();
+    } else {
+      // For other errors, let the error propagate.
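+      // Unlike the hardware-limitation case above, other initialization
+      // errors abort the calculator run instead of falling back to CPU.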
+      return status;
     }
+  }
+  if (gpu_processing && gpu_inited_) {
     MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
   } else {
     MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
@@ -346,17 +357,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
   // TODO: Add flexible input tensor size handling.
   auto raw_box_tensor =
       &input_tensors[tensor_mapping_.detections_tensor_index()];
-  RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
   RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+  if (raw_box_tensor->shape().dims.size() == 3) {
+    // Tensors from CPU inference have dim 3.
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+  } else if (raw_box_tensor->shape().dims.size() == 4) {
+    // Tensors from GPU inference have dim 4. For GPU-to-CPU fallback support,
+    // we allow tensors with 4 dims.
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
+  } else {
+    return absl::InvalidArgumentError(
+        "The dimensions of the box Tensor must be 3 or 4.");
+  }
   auto raw_score_tensor =
       &input_tensors[tensor_mapping_.scores_tensor_index()];
-  RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+  if (raw_score_tensor->shape().dims.size() == 3) {
+    // Tensors from CPU inference have dim 3.
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+  } else if (raw_score_tensor->shape().dims.size() == 4) {
+    // Tensors from GPU inference have dim 4. For GPU-to-CPU fallback support,
+    // we allow tensors with 4 dims.
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
+  } else {
+    return absl::InvalidArgumentError(
+        "The dimensions of the score Tensor must be 3 or 4.");
+  }
   auto raw_box_view = raw_box_tensor->GetCpuReadView();
   auto raw_boxes = raw_box_view.buffer<float>();
   auto raw_scores_view = raw_score_tensor->GetCpuReadView();
@@ -634,7 +669,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
                                  output_detections));
 #else
-  LOG(ERROR) << "GPU input on non-Android not supported yet.";
+  ABSL_LOG(ERROR) << "GPU input on non-Android not supported yet.";
 #endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
   return absl::OkStatus();
 }
@@ -669,18 +704,18 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
   num_boxes_ = options_.num_boxes();
   num_coords_ = options_.num_coords();
   box_output_format_ = GetBoxFormat(options_);
-  CHECK_NE(options_.max_results(), 0)
+  ABSL_CHECK_NE(options_.max_results(), 0)
       << "The maximum number of the top-scored detection results must be "
         "non-zero.";
   max_results_ = options_.max_results();

  // Currently only support 2D when num_values_per_keypoint equals to 2.
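+  // Each box is laid out as kNumCoordsPerBox coordinates followed by
+  // num_keypoints() * num_values_per_keypoint() keypoint values; the check
+  // below pins num_coords_ to exactly that sum.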
- CHECK_EQ(options_.num_values_per_keypoint(), 2); + ABSL_CHECK_EQ(options_.num_values_per_keypoint(), 2); // Check if the output size is equal to the requested boxes and keypoints. - CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + - kNumCoordsPerBox, - num_coords_); + ABSL_CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + + kNumCoordsPerBox, + num_coords_); if (kSideInIgnoreClasses(cc).IsConnected()) { RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty()); @@ -1111,15 +1146,21 @@ void main() { int max_wg_size; // typically <= 1024 glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &max_wg_size); // y-dim - CHECK_LT(num_classes_, max_wg_size) - << "# classes must be < " << max_wg_size; + gpu_has_enough_work_groups_ = num_classes_ < max_wg_size; + if (!gpu_has_enough_work_groups_) { + return absl::FailedPreconditionError(absl::StrFormat( + "Hardware limitation: Processing will be done on CPU, because " + "num_classes %d exceeds the max work_group size %d.", + num_classes_, max_wg_size)); + } // TODO support better filtering. if (class_index_set_.is_allowlist) { - CHECK_EQ(class_index_set_.values.size(), - IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) + ABSL_CHECK_EQ(class_index_set_.values.size(), + IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) << "Only all classes >= class 0 or >= class 1"; } else { - CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1) + ABSL_CHECK_EQ(class_index_set_.values.size(), + IsClassIndexAllowed(0) ? 0 : 1) << "Only ignore class 0 is allowed"; } @@ -1340,11 +1381,12 @@ kernel void scoreKernel( // TODO support better filtering. if (class_index_set_.is_allowlist) { - CHECK_EQ(class_index_set_.values.size(), - IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) + ABSL_CHECK_EQ(class_index_set_.values.size(), + IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1) << "Only all classes >= class 0 or >= class 1"; } else { - CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1) + ABSL_CHECK_EQ(class_index_set_.values.size(), + IsClassIndexAllowed(0) ? 0 : 1) << "Only ignore class 0 is allowed"; } @@ -1370,7 +1412,13 @@ kernel void scoreKernel( Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2}); // # filter classes supported is hardware dependent. 
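+  // Metal counterpart of the GL work-group check above: if num_classes_ does
+  // not fit in a single threadgroup, GpuInit() reports kFailedPrecondition
+  // and Process() falls back to the CPU path.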
 int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
-  CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
+  gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+  if (!gpu_has_enough_work_groups_) {
+    return absl::FailedPreconditionError(absl::StrFormat(
+        "Hardware limitation: Processing will be done on CPU, because "
+        "num_classes %d exceeds the max work_group size %d.",
+        num_classes_, max_wg_size));
+  }
 }
 #endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
diff --git a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc
index a1cc4e202..5942f234d 100644
--- a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc
@@ -142,7 +142,7 @@ absl::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) {
   RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
   int num_values = input_tensors[0].shape().num_elements();
   const int num_dimensions = num_values / num_landmarks_;
-  CHECK_GT(num_dimensions, 0);
+  ABSL_CHECK_GT(num_dimensions, 0);
   auto view = input_tensors[0].GetCpuReadView();
   auto raw_landmarks = view.buffer<float>();
diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD
index 0b30689eb..5f5f51657 100644
--- a/mediapipe/calculators/tensorflow/BUILD
+++ b/mediapipe/calculators/tensorflow/BUILD
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
+# Placeholder: load py_proto_library
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library", "mediapipe_proto_library")

 licenses(["notice"])
@@ -314,6 +315,7 @@ cc_library(
         "//mediapipe/framework/formats:time_series_header_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/log:absl_check",
     ] + select({
         "//conditions:default": [
             "@org_tensorflow//tensorflow/core:framework",
@@ -366,18 +368,18 @@ cc_library(
     name = "pack_media_sequence_calculator",
     srcs = ["pack_media_sequence_calculator.cc"],
     deps = [
+        ":pack_media_sequence_calculator_cc_proto",
         "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
-        "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:location",
         "//mediapipe/framework/formats:location_opencv",
         "//mediapipe/framework/port:opencv_imgcodecs",
         "//mediapipe/framework/port:ret_check",
-        "//mediapipe/framework/port:status",
         "//mediapipe/util/sequence:media_sequence",
         "//mediapipe/util/sequence:media_sequence_util",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@org_tensorflow//tensorflow/core:protos_all_cc",
     ],
@@ -400,6 +402,21 @@ cc_library(
 # compile your binary with the flag TENSORFLOW_PROTOS=lite.
 cc_library(
     name = "tensorflow_inference_calculator_no_envelope_loader",
+    deps = [
+        ":tensorflow_inference_calculator_for_boq",
+    ],
+    alwayslink = 1,
+)
+
+# This dependency removes the following 3 targets because they failed the Boq
+# conformance test:
+#
+#   tensorflow_jellyfish_deps
+#   jfprof_lib
+#   xprofilez_with_server
+#
+# If you need them, please consider tensorflow_inference_calculator_no_envelope_loader.
+cc_library( + name = "tensorflow_inference_calculator_for_boq", srcs = ["tensorflow_inference_calculator.cc"], deps = [ ":tensorflow_inference_calculator_cc_proto", @@ -414,7 +431,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -473,10 +490,10 @@ cc_library( "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:clock", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ @@ -504,10 +521,10 @@ cc_library( ":tensorflow_session_from_frozen_graph_generator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:clock", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ @@ -540,6 +557,7 @@ cc_library( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/cc/saved_model:constants", "@org_tensorflow//tensorflow/cc/saved_model:loader_lite", @@ -585,6 +603,24 @@ cc_library( # See yaqs/1092546221614039040 cc_library( name = "tensorflow_session_from_saved_model_generator_no_envelope_loader", + defines = select({ + "//mediapipe:android": ["__ANDROID__"], + "//conditions:default": [], + }), + deps = [ + ":tensorflow_session_from_saved_model_generator_for_boq", + ] + select({ + "//conditions:default": [ + "//learning/brain/frameworks/uptc/public:uptc_session_no_envelope_loader", + ], + }), + alwayslink = 1, +) + +# Same library as tensorflow_session_from_saved_model_generator without uptc_session, +# envelop_loader and remote_session dependencies. 
+cc_library( + name = "tensorflow_session_from_saved_model_generator_for_boq", srcs = ["tensorflow_session_from_saved_model_generator.cc"], defines = select({ "//mediapipe:android": ["__ANDROID__"], @@ -599,6 +635,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/cc/saved_model:constants", @@ -620,6 +657,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:framework", ], alwayslink = 1, @@ -634,6 +672,7 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", "@org_tensorflow//tensorflow/core:framework", ], alwayslink = 1, @@ -649,6 +688,7 @@ cc_library( "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", @@ -745,6 +785,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/util:audio_decoder_cc_proto", "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", ], @@ -759,6 +800,8 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:framework", ], alwayslink = 1, @@ -772,6 +815,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:framework", ], alwayslink = 1, @@ -785,6 +829,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:framework", ], alwayslink = 1, @@ -798,6 +843,8 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:protos_all_cc", ], alwayslink = 1, @@ -892,22 +939,24 @@ cc_test( srcs = ["pack_media_sequence_calculator_test.cc"], deps = [ ":pack_media_sequence_calculator", + ":pack_media_sequence_calculator_cc_proto", "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", - "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:gtest_main", 
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/util/sequence:media_sequence", - "@com_google_absl//absl/container:flat_hash_map", + "//mediapipe/util/sequence:media_sequence_util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:protos_all_cc", ], ) @@ -1049,6 +1098,7 @@ cc_test( linkstatic = 1, deps = [ ":tensor_to_image_frame_calculator", + ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:image_frame", @@ -1134,6 +1184,7 @@ cc_test( "//mediapipe/framework/port:rectangle", "//mediapipe/util:audio_decoder_cc_proto", "//mediapipe/util/sequence:media_sequence", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", @@ -1215,6 +1266,8 @@ cc_test( "//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:validate_type", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:direct_session", diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc index 32a0eb70b..bbd5cff3e 100644 --- a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/log/absl_check.h" #include "mediapipe/calculators/tensorflow/matrix_to_tensor_calculator_options.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" @@ -28,7 +29,7 @@ namespace mediapipe { namespace { absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, TimeSeriesHeader* header) { - CHECK(header); + ABSL_CHECK(header); if (header_packet.IsEmpty()) { return absl::UnknownError("No header found."); } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 34136440d..d87029143 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -12,21 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/match.h" +#include "absl/strings/strip.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" -#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status.h" #include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence_util.h" #include "tensorflow/core/example/example.pb.h" @@ -36,7 +38,11 @@ namespace mediapipe { const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; const char kImageTag[] = "IMAGE"; +const char kImageLabelPrefixTag[] = "IMAGE_LABEL_"; +const char kClipLabelPrefixTag[] = "CLIP_LABEL_"; const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; +const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_"; +const char kBytesContextFeaturePrefixTag[] = "BYTES_CONTEXT_FEATURE_"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; const char kIntFeaturePrefixTag[] = "INT_FEATURE_"; const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_"; @@ -44,6 +50,7 @@ const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; const char kBBoxTag[] = "BBOX"; const char kKeypointsTag[] = "KEYPOINTS"; const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION"; +const char kClipMediaIdTag[] = "CLIP_MEDIA_ID"; namespace tf = ::tensorflow; namespace mpms = mediapipe::mediasequence; @@ -55,16 +62,23 @@ namespace mpms = mediapipe::mediasequence; // context features can be supplied verbatim in the calculator's options. The // SequenceExample will conform to the description in media_sequence.h. // -// The supported input stream tags are "IMAGE", which stores the encoded -// images from the OpenCVImageEncoderCalculator, "FORWARD_FLOW_ENCODED", which -// stores the encoded optical flow from the same calculator, "BBOX" which stores -// bounding boxes from vector, and streams with the -// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's -// associated with the name ${NAME}. "KEYPOINTS" stores a map of 2D keypoints -// from flat_hash_map>>. "IMAGE_${NAME}", -// "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store prefixed versions of -// each stream, which allows for multiple image streams to be included. However, -// the default names are suppored by more tools. +// The supported input stream tags are: +// * "IMAGE", which stores the encoded images from the +// OpenCVImageEncoderCalculator, +// * "IMAGE_LABEL", which stores whole image labels from Detection, +// * "FORWARD_FLOW_ENCODED", which stores the encoded optical flow from the same +// calculator, +// * "BBOX" which stores bounding boxes from vector, +// * streams with the "FLOAT_FEATURE_${NAME}" pattern, which stores the values +// from vector's associated with the name ${NAME}, +// * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map>>, +// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string. +// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in +// mediapipe::Detection. 
+// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store +// prefixed versions of each stream, which allows for multiple image streams to +// be included. However, the default names are suppored by more tools. // // Example config: // node { @@ -100,6 +114,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); + if (cc->InputSidePackets().HasTag(kClipMediaIdTag)) { + cc->InputSidePackets().Tag(kClipMediaIdTag).Set(); + } if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) { cc->Inputs() @@ -112,6 +129,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { + if (absl::StartsWith(tag, kImageLabelPrefixTag)) { + cc->Inputs().Tag(tag).Set(); + continue; + } std::string key = ""; if (tag != kImageTag) { int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; @@ -150,9 +171,18 @@ class PackMediaSequenceCalculator : public CalculatorBase { } cc->Inputs().Tag(tag).Set>(); } + if (absl::StartsWith(tag, kClipLabelPrefixTag)) { + cc->Inputs().Tag(tag).Set(); + } if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) { cc->Inputs().Tag(tag).Set>(); } + if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) { + cc->Inputs().Tag(tag).Set>(); + } + if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag)) { + cc->Inputs().Tag(tag).Set>(); + } if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { cc->Inputs().Tag(tag).Set>(); } @@ -164,8 +194,8 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } - CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || - cc->OutputSidePackets().HasTag(kSequenceExampleTag)) + RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || + cc->OutputSidePackets().HasTag(kSequenceExampleTag)) << "Neither the output stream nor the output side packet is set to " "output the sequence example."; if (cc->Outputs().HasTag(kSequenceExampleTag)) { @@ -184,6 +214,11 @@ class PackMediaSequenceCalculator : public CalculatorBase { cc->InputSidePackets() .Tag(kSequenceExampleTag) .Get()); + if (cc->InputSidePackets().HasTag(kClipMediaIdTag) && + !cc->InputSidePackets().Tag(kClipMediaIdTag).IsEmpty()) { + clip_media_id_ = + cc->InputSidePackets().Tag(kClipMediaIdTag).Get(); + } const auto& context_features = cc->Options().context_feature_map(); @@ -197,8 +232,19 @@ class PackMediaSequenceCalculator : public CalculatorBase { replace_keypoints_ = false; if (cc->Options() .replace_data_instead_of_append()) { + // Clear the existing values under the same key. 
for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { + if (absl::StartsWith(tag, kImageLabelPrefixTag)) { + std::string key = + std::string(absl::StripPrefix(tag, kImageLabelPrefixTag)); + mpms::ClearImageLabelString(key, sequence_.get()); + mpms::ClearImageLabelConfidence(key, sequence_.get()); + if (!key.empty() || mpms::HasImageEncoded(*sequence_)) { + mpms::ClearImageTimestamp(key, sequence_.get()); + } + continue; + } std::string key = ""; if (tag != kImageTag) { int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; @@ -227,12 +273,41 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearBBoxNumRegions(key, sequence_.get()); mpms::ClearBBoxLabelString(key, sequence_.get()); mpms::ClearBBoxLabelIndex(key, sequence_.get()); + mpms::ClearBBoxLabelConfidence(key, sequence_.get()); mpms::ClearBBoxClassString(key, sequence_.get()); mpms::ClearBBoxClassIndex(key, sequence_.get()); mpms::ClearBBoxTrackString(key, sequence_.get()); mpms::ClearBBoxTrackIndex(key, sequence_.get()); mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get()); } + if (absl::StartsWith(tag, kClipLabelPrefixTag)) { + const std::string& key = tag.substr( + sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); + mpms::ClearClipLabelIndex(key, sequence_.get()); + mpms::ClearClipLabelString(key, sequence_.get()); + mpms::ClearClipLabelConfidence(key, sequence_.get()); + } + if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) { + const std::string& key = + tag.substr(sizeof(kFloatContextFeaturePrefixTag) / + sizeof(*kFloatContextFeaturePrefixTag) - + 1); + mpms::ClearContextFeatureFloats(key, sequence_.get()); + } + if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) { + const std::string& key = + tag.substr(sizeof(kIntsContextFeaturePrefixTag) / + sizeof(*kIntsContextFeaturePrefixTag) - + 1); + mpms::ClearContextFeatureInts(key, sequence_.get()); + } + if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag)) { + const std::string& key = + tag.substr(sizeof(kBytesContextFeaturePrefixTag) / + sizeof(*kBytesContextFeaturePrefixTag) - + 1); + mpms::ClearContextFeatureBytes(key, sequence_.get()); + } if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / sizeof(*kFloatFeaturePrefixTag) - @@ -343,6 +418,34 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (absl::StartsWith(tag, kImageTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; + if (absl::StartsWith(tag, kImageLabelPrefixTag)) { + std::string key = + std::string(absl::StripPrefix(tag, kImageLabelPrefixTag)); + const auto& detection = cc->Inputs().Tag(tag).Get(); + if (detection.label().empty()) continue; + RET_CHECK(detection.label_size() == detection.score_size()) + << "Wrong image label data format: " << detection.label_size() + << " vs " << detection.score_size(); + if (!detection.label_id().empty()) { + RET_CHECK(detection.label_id_size() == detection.label_size()) + << "Wrong image label ID format: " << detection.label_id_size() + << " vs " << detection.label_size(); + } + std::vector labels(detection.label().begin(), + detection.label().end()); + std::vector confidences(detection.score().begin(), + detection.score().end()); + std::vector ids(detection.label_id().begin(), + detection.label_id().end()); + if (!key.empty() || mpms::HasImageEncoded(*sequence_)) { + mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + } + mpms::AddImageLabelString(key, labels, 
sequence_.get()); + mpms::AddImageLabelConfidence(key, confidences, sequence_.get()); + if (!ids.empty()) mpms::AddImageLabelIndex(key, ids, sequence_.get()); + continue; + } if (tag != kImageTag) { int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; if (tag[tag_length] == '_') { @@ -393,6 +496,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearBBoxNumRegions(prefix, sequence_.get()); mpms::ClearBBoxLabelString(prefix, sequence_.get()); mpms::ClearBBoxLabelIndex(prefix, sequence_.get()); + mpms::ClearBBoxLabelConfidence(prefix, sequence_.get()); mpms::ClearBBoxClassString(prefix, sequence_.get()); mpms::ClearBBoxClassIndex(prefix, sequence_.get()); mpms::ClearBBoxTrackString(prefix, sequence_.get()); @@ -405,6 +509,33 @@ class PackMediaSequenceCalculator : public CalculatorBase { } replace_keypoints_ = false; } + if (absl::StartsWith(tag, kClipLabelPrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + const std::string& key = tag.substr( + sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); + const Detection& detection = cc->Inputs().Tag(tag).Get(); + if (detection.label().size() != detection.score().size()) { + return absl::InvalidArgumentError( + "Different size of detection.label and detection.score"); + } + // Allow empty label_ids, but if label_ids is not empty, it should have + // the same size as the label and score fields. + if (!detection.label_id().empty()) { + if (detection.label_id().size() != detection.label().size()) { + return absl::InvalidArgumentError( + "Different size of detection.label_id and detection.label"); + } + } + for (int i = 0; i < detection.label().size(); ++i) { + if (!detection.label_id().empty()) { + mpms::AddClipLabelIndex(key, detection.label_id(i), + sequence_.get()); + } + mpms::AddClipLabelString(key, detection.label(i), sequence_.get()); + mpms::AddClipLabelConfidence(key, detection.score(i), + sequence_.get()); + } + } if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = @@ -412,9 +543,36 @@ class PackMediaSequenceCalculator : public CalculatorBase { sizeof(*kFloatContextFeaturePrefixTag) - 1); RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream()); - mpms::SetContextFeatureFloats( - key, cc->Inputs().Tag(tag).Get>(), - sequence_.get()); + for (const auto& value : + cc->Inputs().Tag(tag).Get>()) { + mpms::AddContextFeatureFloats(key, value, sequence_.get()); + } + } + if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + const std::string& key = + tag.substr(sizeof(kIntsContextFeaturePrefixTag) / + sizeof(*kIntsContextFeaturePrefixTag) - + 1); + // To ensure only one packet is provided for this tag. + RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream()); + for (const auto& value : + cc->Inputs().Tag(tag).Get>()) { + mpms::AddContextFeatureInts(key, value, sequence_.get()); + } + } + if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + const std::string& key = + tag.substr(sizeof(kBytesContextFeaturePrefixTag) / + sizeof(*kBytesContextFeaturePrefixTag) - + 1); + // To ensure only one packet is provided for this tag. 
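+        // Timestamp::PostStream() is reserved for streams that carry exactly
+        // one packet, so this check rejects inputs that would otherwise
+        // append the same context feature more than once.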
+ RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream()); + for (const auto& value : + cc->Inputs().Tag(tag).Get>()) { + mpms::AddContextFeatureBytes(key, value, sequence_.get()); + } } if (absl::StartsWith(tag, kFloatFeaturePrefixTag) && !cc->Inputs().Tag(tag).IsEmpty()) { @@ -460,6 +618,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { } std::vector predicted_locations; std::vector predicted_class_strings; + std::vector predicted_class_confidences; std::vector predicted_label_ids; for (auto& detection : cc->Inputs().Tag(tag).Get>()) { @@ -488,6 +647,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (detection.label_id_size() > 0) { predicted_label_ids.push_back(detection.label_id(0)); } + if (detection.score_size() > 0) { + predicted_class_confidences.push_back(detection.score(0)); + } } } if (!predicted_locations.empty()) { @@ -501,6 +663,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (!predicted_label_ids.empty()) { mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get()); } + if (!predicted_class_confidences.empty()) { + mpms::AddBBoxLabelConfidence(key, predicted_class_confidences, + sequence_.get()); + } } } } @@ -548,10 +714,14 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } } + if (clip_media_id_.has_value()) { + mpms::SetClipMediaId(*clip_media_id_, sequence_.get()); + } return absl::OkStatus(); } std::unique_ptr sequence_; + std::optional clip_media_id_ = std::nullopt; std::map features_present_; bool replace_keypoints_; }; diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 752db621e..d9dc56e9c 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -12,28 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include +#include +#include +#include +#include -#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" -#include "absl/strings/numbers.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/detection.pb.h" -#include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" -#include "mediapipe/framework/port/gmock.h" -#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/util/sequence/media_sequence.h" +#include "mediapipe/util/sequence/media_sequence_util.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" namespace mediapipe { namespace { @@ -55,13 +59,22 @@ constexpr char kBytesFeatureTestTag[] = "BYTES_FEATURE_TEST"; constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; +constexpr char kIntsContextFeatureTestTag[] = "INTS_CONTEXT_FEATURE_TEST"; +constexpr char kIntsContextFeatureOtherTag[] = "INTS_CONTEXT_FEATURE_OTHER"; +constexpr char kBytesContextFeatureTestTag[] = "BYTES_CONTEXT_FEATURE_TEST"; +constexpr char kBytesContextFeatureOtherTag[] = "BYTES_CONTEXT_FEATURE_OTHER"; constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER"; constexpr char kIntFeatureTestTag[] = "INT_FEATURE_TEST"; +constexpr char kImageLabelTestTag[] = "IMAGE_LABEL_TEST"; +constexpr char kImageLabelOtherTag[] = "IMAGE_LABEL_OTHER"; constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID"; +constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST"; +constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER"; class PackMediaSequenceCalculatorTest : public ::testing::Test { protected: @@ -69,10 +82,14 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test { const tf::Features& features, const bool output_only_if_all_present, const bool replace_instead_of_append, - const bool output_as_zero_timestamp = false) { + const bool output_as_zero_timestamp = false, + const std::vector& input_side_packets = { + "SEQUENCE_EXAMPLE:input_sequence"}) { CalculatorGraphConfig::Node config; config.set_calculator("PackMediaSequenceCalculator"); - config.add_input_side_packet("SEQUENCE_EXAMPLE:input_sequence"); + for (const std::string& side_packet : input_side_packets) { + config.add_input_side_packet(side_packet); + } config.add_output_stream("SEQUENCE_EXAMPLE:output_sequence"); for (const std::string& stream : input_streams) { 
config.add_input_stream(stream); @@ -96,7 +113,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { mpms::SetClipMediaId(test_video_id, input_sequence.get()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); @@ -139,7 +157,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { mpms::SetClipMediaId(test_video_id, input_sequence.get()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); @@ -312,6 +331,76 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBytesLists) { } } +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) { + SetUpCalculator( + {"IMAGE_LABEL_TEST:test_labels", "IMAGE_LABEL_OTHER:test_labels2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + Detection detection1; + detection1.add_label(absl::StrCat("foo", 2 << i)); + detection1.add_label_id(i); + detection1.add_score(0.1 * i); + detection1.add_label(absl::StrCat("foo", 2 << i)); + detection1.add_label_id(i); + detection1.add_score(0.1 * i); + auto label_ptr1 = ::absl::make_unique(detection1); + runner_->MutableInputs() + ->Tag(kImageLabelTestTag) + .packets.push_back(Adopt(label_ptr1.release()).At(Timestamp(i))); + Detection detection2; + detection2.add_label(absl::StrCat("bar", 2 << i)); + detection2.add_score(0.2 * i); + detection2.add_label(absl::StrCat("bar", 2 << i)); + detection2.add_score(0.2 * i); + auto label_ptr2 = ::absl::make_unique(detection2); + runner_->MutableInputs() + ->Tag(kImageLabelOtherTag) + .packets.push_back(Adopt(label_ptr2.release()).At(Timestamp(i))); + } + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(num_timesteps, + mpms::GetImageTimestampSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelStringSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelConfidenceSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelStringSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelConfidenceSize("OTHER", output_sequence)); + for (int i = 0; i < num_timesteps; ++i) { + ASSERT_EQ(i, mpms::GetImageTimestampAt("TEST", output_sequence, i)); + ASSERT_THAT(mpms::GetImageLabelStringAt("TEST", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("foo", 2 << i)))); + ASSERT_THAT(mpms::GetImageLabelIndexAt("TEST", output_sequence, i), + ::testing::ElementsAreArray(std::vector(2, i))); + ASSERT_THAT(mpms::GetImageLabelConfidenceAt("TEST", output_sequence, i), + 
::testing::ElementsAreArray(std::vector(2, 0.1 * i))); + ASSERT_EQ(i, mpms::GetImageTimestampAt("OTHER", output_sequence, i)); + ASSERT_THAT(mpms::GetImageLabelStringAt("OTHER", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("bar", 2 << i)))); + ASSERT_THAT(mpms::GetImageLabelConfidenceAt("OTHER", output_sequence, i), + ::testing::ElementsAreArray(std::vector(2, 0.2 * i))); + } +} + TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) { SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true); auto input_sequence = ::absl::make_unique(); @@ -367,6 +456,315 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { testing::ElementsAre(4, 4)); } +TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextFloatLists) { + SetUpCalculator( + /*input_streams=*/{"FLOAT_CONTEXT_FEATURE_TEST:test", + "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, + /*features=*/{}, + /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true); + auto input_sequence = std::make_unique(); + mpms::SetContextFeatureFloats("TEST", {2, 3}, input_sequence.get()); + mpms::SetContextFeatureFloats("OTHER", {2, 4}, input_sequence.get()); + + const std::vector vf_1 = {5, 6}; + runner_->MutableInputs() + ->Tag(kFloatContextFeatureTestTag) + .packets.push_back( + MakePacket>(vf_1).At(Timestamp::PostStream())); + const std::vector vf_2 = {7, 8}; + runner_->MutableInputs() + ->Tag(kFloatContextFeatureOtherTag) + .packets.push_back( + MakePacket>(vf_2).At(Timestamp::PostStream())); + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_THAT(mpms::GetContextFeatureFloats("TEST", output_sequence), + testing::ElementsAre(5, 6)); + ASSERT_THAT(mpms::GetContextFeatureFloats("OTHER", output_sequence), + testing::ElementsAre(7, 8)); +} + +TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextFloatLists) { + SetUpCalculator( + /*input_streams=*/{"FLOAT_CONTEXT_FEATURE_TEST:test", + "FLOAT_CONTEXT_FEATURE_OTHER:test2"}, + /*features=*/{}, + /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/false); + auto input_sequence = std::make_unique(); + mpms::SetContextFeatureFloats("TEST", {2, 3}, input_sequence.get()); + mpms::SetContextFeatureFloats("OTHER", {2, 4}, input_sequence.get()); + + const std::vector vf_1 = {5, 6}; + runner_->MutableInputs() + ->Tag(kFloatContextFeatureTestTag) + .packets.push_back( + MakePacket>(vf_1).At(Timestamp::PostStream())); + const std::vector vf_2 = {7, 8}; + runner_->MutableInputs() + ->Tag(kFloatContextFeatureOtherTag) + .packets.push_back( + MakePacket>(vf_2).At(Timestamp::PostStream())); + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + EXPECT_THAT(mpms::GetContextFeatureFloats("TEST", output_sequence), + testing::ElementsAre(2, 3, 5, 6)); + EXPECT_THAT(mpms::GetContextFeatureFloats("OTHER", output_sequence), + testing::ElementsAre(2, 4, 7, 8)); +} + +TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextIntLists) { + SetUpCalculator( + 
+      /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
+                         "INTS_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+
+  const std::vector<int64_t> vi_1 = {2, 3};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureTestTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
+  const std::vector<int64_t> vi_2 = {2, 4};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureOtherTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
+              testing::ElementsAre(2, 3));
+  ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
+              testing::ElementsAre(2, 4));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextIntLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
+                         "INTS_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get());
+  mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get());
+
+  const std::vector<int64_t> vi_1 = {5, 6};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureTestTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
+  const std::vector<int64_t> vi_2 = {7, 8};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureOtherTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
+              testing::ElementsAre(5, 6));
+  ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
+              testing::ElementsAre(7, 8));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextIntLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
+                         "INTS_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/false);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get());
+  mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get());
+
+  const std::vector<int64_t> vi_1 = {5, 6};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureTestTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
+  const std::vector<int64_t> vi_2 = {7, 8};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureOtherTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
+              testing::ElementsAre(2, 3, 5, 6));
+  ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
+              testing::ElementsAre(2, 4, 7, 8));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextByteLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"BYTES_CONTEXT_FEATURE_TEST:test",
+                         "BYTES_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+
+  const std::vector<std::string> vb_1 = {"value_1", "value_2"};
+  runner_->MutableInputs()
+      ->Tag(kBytesContextFeatureTestTag)
+      .packets.push_back(MakePacket<std::vector<std::string>>(vb_1).At(
+          Timestamp::PostStream()));
+  const std::vector<std::string> vb_2 = {"value_3", "value_4"};
+  runner_->MutableInputs()
+      ->Tag(kBytesContextFeatureOtherTag)
+      .packets.push_back(MakePacket<std::vector<std::string>>(vb_2).At(
+          Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureBytes("TEST", output_sequence),
+              testing::ElementsAre("value_1", "value_2"));
+  ASSERT_THAT(mpms::GetContextFeatureBytes("OTHER", output_sequence),
+              testing::ElementsAre("value_3", "value_4"));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextByteLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"BYTES_CONTEXT_FEATURE_TEST:test",
+                         "BYTES_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  mpms::SetContextFeatureBytes("TEST", {"existing_value_1", "existing_value_2"},
+                               input_sequence.get());
+  mpms::SetContextFeatureBytes(
+      "OTHER", {"existing_value_3", "existing_value_4"}, input_sequence.get());
+
+  const std::vector<std::string> vb_1 = {"value_1", "value_2"};
+  runner_->MutableInputs()
+      ->Tag(kBytesContextFeatureTestTag)
+      .packets.push_back(MakePacket<std::vector<std::string>>(vb_1).At(
+          Timestamp::PostStream()));
+  const std::vector<std::string> vb_2 = {"value_3", "value_4"};
+  runner_->MutableInputs()
+      ->Tag(kBytesContextFeatureOtherTag)
+      .packets.push_back(MakePacket<std::vector<std::string>>(vb_2).At(
+          Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureBytes("TEST", output_sequence),
+              testing::ElementsAre("value_1", "value_2"));
+  ASSERT_THAT(mpms::GetContextFeatureBytes("OTHER", output_sequence),
+              testing::ElementsAre("value_3", "value_4"));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextByteLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"BYTES_CONTEXT_FEATURE_TEST:test",
+                         "BYTES_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/false);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  mpms::SetContextFeatureBytes("TEST", {"existing_value_1", "existing_value_2"},
+                               input_sequence.get());
+  mpms::SetContextFeatureBytes(
+      "OTHER", {"existing_value_3", "existing_value_4"}, input_sequence.get());
+
+  const std::vector<std::string> vb_1 = {"value_1", "value_2"};
+  runner_->MutableInputs()
+      ->Tag(kBytesContextFeatureTestTag)
+      .packets.push_back(MakePacket<std::vector<std::string>>(vb_1).At(
+          Timestamp::PostStream()));
+  const std::vector<std::string> vb_2 = {"value_3", "value_4"};
+  runner_->MutableInputs()
+      ->Tag(kBytesContextFeatureOtherTag)
+      .packets.push_back(MakePacket<std::vector<std::string>>(vb_2).At(
+          Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureBytes("TEST", output_sequence),
+              testing::ElementsAre("existing_value_1", "existing_value_2",
+                                   "value_1", "value_2"));
+  ASSERT_THAT(mpms::GetContextFeatureBytes("OTHER", output_sequence),
+              testing::ElementsAre("existing_value_3", "existing_value_4",
+                                   "value_3", "value_4"));
+}
+
 TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
   tf::Features context;
   (*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES");
@@ -378,7 +776,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
       Adopt(input_sequence.release());
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   auto image_ptr =
@@ -410,7 +809,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
 
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   std::string test_flow_string(bytes.begin(), bytes.end());
   OpenCvImageEncoderCalculatorResults encoded_flow;
   encoded_flow.set_encoded_image(test_flow_string);
@@ -526,6 +926,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
     auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
     ASSERT_EQ(0, class_indices[0]);
     ASSERT_EQ(1, class_indices[1]);
+    auto class_scores =
+        mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
+    ASSERT_FLOAT_EQ(0.5, class_scores[0]);
+    ASSERT_FLOAT_EQ(0.75, class_scores[1]);
   }
 }
 
@@ -618,7 +1022,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
   }
   cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   encoded_image.set_width(width);
@@ -667,6 +1072,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
     auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
     ASSERT_EQ(0, class_indices[0]);
     ASSERT_EQ(1, class_indices[1]);
+    auto class_scores =
+        mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
+    ASSERT_FLOAT_EQ(0.5, class_scores[0]);
+    ASSERT_FLOAT_EQ(0.75, class_scores[1]);
   }
 }
 
@@ -757,6 +1166,365 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
               testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
 }
 
+TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
+  SetUpCalculator(
+      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
+      /*features=*/{}, /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true);
+  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
+
+  Detection detection_1;
+  detection_1.add_label("label_1");
+  detection_1.add_label("label_2");
+  detection_1.add_label_id(1);
+  detection_1.add_label_id(2);
+  detection_1.add_score(0.1);
+  detection_1.add_score(0.2);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelTestTag)
+      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
+  // No label ID for detection_2.
+  Detection detection_2;
+  detection_2.add_label("label_3");
+  detection_2.add_label("label_4");
+  detection_2.add_score(0.3);
+  detection_2.add_score(0.4);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelOtherTag)
+      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence),
+              testing::ElementsAre("label_1", "label_2"));
+  ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
+              testing::ElementsAre(1, 2));
+  ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
+              testing::ElementsAre(0.1, 0.2));
+  ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
+              testing::ElementsAre("label_3", "label_4"));
+  ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
+  ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
+              testing::ElementsAre(0.3, 0.4));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest,
+       PackTwoClipLabels_DifferentLabelScoreSize) {
+  SetUpCalculator(
+      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
+      /*features=*/{}, /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true);
+  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
+
+  // 2 labels and 1 score in detection_1.
+  Detection detection_1;
+  detection_1.add_label("label_1");
+  detection_1.add_label("label_2");
+  detection_1.add_score(0.1);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelTestTag)
+      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
+  Detection detection_2;
+  detection_2.add_label("label_3");
+  detection_2.add_label("label_4");
+  detection_2.add_score(0.3);
+  detection_2.add_score(0.4);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelOtherTag)
+      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  ASSERT_THAT(
+      runner_->Run(),
+      testing::status::StatusIs(
+          absl::StatusCode::kInvalidArgument,
+          testing::HasSubstr(
+              "Different size of detection.label and detection.score")));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest,
+       PackTwoClipLabels_DifferentLabelIdSize) {
+  SetUpCalculator(
+      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
+      /*features=*/{}, /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true);
+  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
+
+  // 2 labels and 1 label_id in detection_1.
+  Detection detection_1;
+  detection_1.add_label("label_1");
+  detection_1.add_label("label_2");
+  detection_1.add_label_id(1);
+  detection_1.add_score(0.1);
+  detection_1.add_score(0.2);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelTestTag)
+      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
+  Detection detection_2;
+  detection_2.add_label("label_3");
+  detection_2.add_label("label_4");
+  detection_2.add_score(0.3);
+  detection_2.add_score(0.4);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelOtherTag)
+      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  ASSERT_THAT(
+      runner_->Run(),
+      testing::status::StatusIs(
+          absl::StatusCode::kInvalidArgument,
+          testing::HasSubstr(
+              "Different size of detection.label_id and detection.label")));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) {
+  // Replace existing clip/label/string and clip/label/confidence values for
+  // the prefixes.
+  SetUpCalculator(
+      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
+      /*features=*/{}, /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true);
+  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
+  mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"},
+                           input_sequence.get());
+  mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get());
+  mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"},
+                           input_sequence.get());
+  mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get());
+
+  Detection detection_1;
+  detection_1.add_label("label_1");
+  detection_1.add_label("label_2");
+  detection_1.add_label_id(1);
+  detection_1.add_label_id(2);
+  detection_1.add_score(0.9);
+  detection_1.add_score(0.8);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelTestTag)
+      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
+  Detection detection_2;
+  detection_2.add_label("label_3");
+  detection_2.add_label("label_4");
+  detection_2.add_label_id(3);
+  detection_2.add_label_id(4);
+  detection_2.add_score(0.7);
+  detection_2.add_score(0.6);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelOtherTag)
+      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence),
+              testing::ElementsAre("label_1", "label_2"));
+  ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
+              testing::ElementsAre(1, 2));
+  ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
+              testing::ElementsAre(0.9, 0.8));
+  ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
+              testing::ElementsAre("label_3", "label_4"));
+  ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence),
+              testing::ElementsAre(3, 4));
+  ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
+              testing::ElementsAre(0.7, 0.6));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, AppendTwoClipLabels) {
+  // Append to the existing clip/label/string and clip/label/confidence values
+  // for the prefixes.
+  SetUpCalculator(
+      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
+      /*features=*/{}, /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/false);
+  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
+  mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"},
+                           input_sequence.get());
+  mpms::SetClipLabelIndex("TEST", {1, 2}, input_sequence.get());
+  mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get());
+  mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"},
+                           input_sequence.get());
+  mpms::SetClipLabelIndex("OTHER", {3, 4}, input_sequence.get());
+  mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get());
+
+  Detection detection_1;
+  detection_1.add_label("label_1");
+  detection_1.add_label("label_2");
+  detection_1.add_label_id(9);
+  detection_1.add_label_id(8);
+  detection_1.add_score(0.9);
+  detection_1.add_score(0.8);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelTestTag)
+      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
+  Detection detection_2;
+  detection_2.add_label("label_3");
+  detection_2.add_label("label_4");
+  detection_2.add_label_id(7);
+  detection_2.add_label_id(6);
+  detection_2.add_score(0.7);
+  detection_2.add_score(0.6);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelOtherTag)
+      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(
+      mpms::GetClipLabelString("TEST", output_sequence),
+      testing::ElementsAre("old_label_1", "old_label_2", "label_1", "label_2"));
+  ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
+              testing::ElementsAre(1, 2, 9, 8));
+  ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
+              testing::ElementsAre(0.1, 0.2, 0.9, 0.8));
+  ASSERT_THAT(
+      mpms::GetClipLabelString("OTHER", output_sequence),
+      testing::ElementsAre("old_label_3", "old_label_4", "label_3", "label_4"));
+  ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence),
+              testing::ElementsAre(3, 4, 7, 6));
+  ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
+              testing::ElementsAre(0.3, 0.4, 0.7, 0.6));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest,
+       DifferentClipLabelScoreAndConfidenceSize) {
+  SetUpCalculator(
+      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
+      /*features=*/{}, /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true);
+  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
+
+  Detection detection_1;
+  // 2 labels and 1 score.
+  detection_1.add_label("label_1");
+  detection_1.add_label("label_2");
+  detection_1.add_score(0.1);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelTestTag)
+      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
+  Detection detection_2;
+  detection_2.add_label("label_3");
+  detection_2.add_label("label_4");
+  detection_2.add_score(0.3);
+  detection_2.add_score(0.4);
+  runner_->MutableInputs()
+      ->Tag(kClipLabelOtherTag)
+      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  ASSERT_THAT(runner_->Run(),
+              testing::status::StatusIs(absl::StatusCode::kInvalidArgument));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) {
+  SetUpCalculator(
+      /*input_streams=*/{"FLOAT_FEATURE_TEST:test",
+                         "FLOAT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true,
+      /*output_as_zero_timestamp=*/false, /*input_side_packets=*/
+      {"SEQUENCE_EXAMPLE:input_sequence", "CLIP_MEDIA_ID:video_id"});
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  const std::string test_video_id = "test_video_id";
+
+  int num_timesteps = 2;
+  for (int i = 0; i < num_timesteps; ++i) {
+    auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
+    runner_->MutableInputs()
+        ->Tag(kFloatFeatureTestTag)
+        .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
+    vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
+    runner_->MutableInputs()
+        ->Tag(kFloatFeatureOtherTag)
+        .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
+  }
+
+  runner_->MutableSidePackets()->Tag(kClipMediaIdTag) =
+      MakePacket<std::string>(test_video_id);
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, ReplaceClipMediaId) {
+  SetUpCalculator(
+      /*input_streams=*/{"FLOAT_FEATURE_TEST:test",
+                         "FLOAT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/true,
+      /*output_as_zero_timestamp=*/false, /*input_side_packets=*/
+      {"SEQUENCE_EXAMPLE:input_sequence", "CLIP_MEDIA_ID:video_id"});
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  const std::string existing_video_id = "existing_video_id";
+  mpms::SetClipMediaId(existing_video_id, input_sequence.get());
+  const std::string test_video_id = "test_video_id";
+
+  int num_timesteps = 2;
+  for (int i = 0; i < num_timesteps; ++i) {
+    auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
+    runner_->MutableInputs()
+        ->Tag(kFloatFeatureTestTag)
+        .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
+    vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
+    runner_->MutableInputs()
+        ->Tag(kFloatFeatureOtherTag)
+        .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
+  }
+
+  runner_->MutableSidePackets()->Tag(kClipMediaIdTag) =
+      MakePacket<std::string>(test_video_id).At(Timestamp(0));
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
+}
+
 TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
   SetUpCalculator(
       {"FORWARD_FLOW_ENCODED:flow", "FLOAT_FEATURE_I3D_FLOW:feature"}, {},
@@ -767,7 +1535,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
 
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   std::string test_flow_string(bytes.begin(), bytes.end());
   OpenCvImageEncoderCalculatorResults encoded_flow;
   encoded_flow.set_encoded_image(test_flow_string);
@@ -813,7 +1582,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
   mpms::SetClipMediaId(test_video_id, input_sequence.get());
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   std::string test_flow_string(bytes.begin(), bytes.end());
   OpenCvImageEncoderCalculatorResults encoded_flow;
   encoded_flow.set_encoded_image(test_flow_string);
@@ -970,7 +1740,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
   auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   encoded_image.set_width(2);
@@ -1021,7 +1792,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
   auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(
+      cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1}));
   OpenCvImageEncoderCalculatorResults encoded_image;
   encoded_image.set_encoded_image(bytes.data(), bytes.size());
   int height = 2;
@@ -1057,6 +1829,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
     mpms::AddBBoxNumRegions(-1, input_sequence.get());
     mpms::AddBBoxLabelString({"anything"}, input_sequence.get());
     mpms::AddBBoxLabelIndex({-1}, input_sequence.get());
+    mpms::AddBBoxLabelConfidence({-1}, input_sequence.get());
    mpms::AddBBoxClassString({"anything"}, input_sequence.get());
     mpms::AddBBoxClassIndex({-1}, input_sequence.get());
     mpms::AddBBoxTrackString({"anything"}, input_sequence.get());
diff --git a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc
index ad87297a9..8b938a868 100644
--- a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -99,10 +100,11 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase {
       }
     }
     if (remove_dims_.empty()) {
-      LOG(ERROR) << "TensorSqueezeDimensionsCalculator is squeezing input with "
-                    "no single-dimensions. Calculator will be a no-op.";
-      LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape "
-                 << tensor_shape.DebugString();
+      ABSL_LOG(ERROR)
+          << "TensorSqueezeDimensionsCalculator is squeezing input with "
+             "no single-dimensions. Calculator will be a no-op.";
+      ABSL_LOG(ERROR) << "Input to TensorSqueezeDimensionsCalculator has shape "
+                      << tensor_shape.DebugString();
     }
   }
 };
diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc
index 34e397b32..3b4d53813 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc
@@ -14,6 +14,7 @@
 #include
 
+#include "absl/log/absl_check.h"
 #include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -65,6 +66,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
 
  private:
   float scale_factor_;
+  bool scale_per_frame_min_max_;
 };
 
 REGISTER_CALCULATOR(TensorToImageFrameCalculator);
@@ -88,6 +90,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
 
 absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
   scale_factor_ = cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
+  scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
+                                 .scale_per_frame_min_max();
   cc->SetOffset(TimestampDiff(0));
   return absl::OkStatus();
 }
@@ -96,7 +100,7 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
   const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
   int32_t depth = 1;
   if (input_tensor.dims() != 2) {  // Depth is 1 for 2D tensors.
-    CHECK(3 == input_tensor.dims())
+    ABSL_CHECK(3 == input_tensor.dims())
         << "Only 2 or 3-D Tensors can be converted to frames. Instead got: "
         << input_tensor.dims();
     depth = input_tensor.dim_size(2);
@@ -109,16 +113,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
   auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
   const int32_t total_size = height * width * depth;
 
+  if (scale_per_frame_min_max_) {
+    RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
+        << "Setting scale_per_frame_min_max requires FLOAT input tensors.";
+  }
   ::std::unique_ptr<const ImageFrame> output;
   if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
     // Allocate buffer with alignments.
     std::unique_ptr<uint8_t[]> buffer(
         new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
     auto data = input_tensor.flat<float>().data();
+    float min = 1e23;
+    float max = -1e23;
+    if (scale_per_frame_min_max_) {
+      for (int i = 0; i < total_size; ++i) {
+        float d = scale_factor_ * data[i];
+        if (d < min) {
+          min = d;
+        }
+        if (d > max) {
+          max = d;
+        }
+      }
+    }
     for (int i = 0; i < total_size; ++i) {
-      float d = scale_factor_ * data[i];
-      if (d < 0) d = 0;
-      if (d > 255) d = 255;
+      float d = data[i];
+      if (scale_per_frame_min_max_) {
+        d = 255 * (d - min) / (max - min + 1e-9);
+      } else {
+        d = scale_factor_ * d;
+        if (d < 0) d = 0;
+        if (d > 255) d = 255;
+      }
       buffer[i] = d;
     }
     output = ::absl::make_unique<ImageFrame>(
diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto
index 3410068d0..c60448c16 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto
+++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.proto
@@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
   // Multiples floating point tensor outputs by this value before converting to
   // uint8. This is useful for converting from range [0, 1] to [0, 255]
   optional float scale_factor = 1 [default = 1.0];
+
+  // If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
+  // per frame. This overrides any explicit scale_factor.
+  optional bool scale_per_frame_min_max = 2 [default = false];
 }
diff --git a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc
index aee9fee9b..13255ac4e 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator_test.cc
@@ -11,7 +11,9 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include
 
+#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
 template <class TypeParam>
 class TensorToImageFrameCalculatorTest : public ::testing::Test {
  protected:
-  void SetUpRunner() {
+  void SetUpRunner(bool scale_per_frame_min_max = false) {
     CalculatorGraphConfig::Node config;
     config.set_calculator("TensorToImageFrameCalculator");
     config.add_input_stream("TENSOR:input_tensor");
     config.add_output_stream("IMAGE:output_image");
+    config.mutable_options()
+        ->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
+        ->set_scale_per_frame_min_max(scale_per_frame_min_max);
     runner_ = absl::make_unique<CalculatorRunner>(config);
   }
 
@@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
   }
 }
 
+TYPED_TEST(TensorToImageFrameCalculatorTest,
+           Converts3DTensorToImageFrame2DGrayWithScaling) {
+  this->SetUpRunner(true);
+  auto& runner = this->runner_;
+  constexpr int kWidth = 16;
+  constexpr int kHeight = 8;
+  const tf::TensorShape tensor_shape{kHeight, kWidth};
+  auto tensor = absl::make_unique<tf::Tensor>(
+      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
+  auto tensor_vec = tensor->template flat<TypeParam>().data();
+
+  // Writing sequence of integers as floats which we want normalized.
+  tensor_vec[0] = 255;
+  for (int i = 1; i < kWidth * kHeight; ++i) {
+    tensor_vec[i] = 200;
+  }
+
+  const int64_t time = 1234;
+  runner->MutableInputs()->Tag(kTensor).packets.push_back(
+      Adopt(tensor.release()).At(Timestamp(time)));
+
+  if (!std::is_same<TypeParam, float>::value) {
+    EXPECT_FALSE(runner->Run().ok());
+    return;  // Short circuit because does not apply to other types.
+  } else {
+    EXPECT_TRUE(runner->Run().ok());
+    const std::vector<Packet>& output_packets =
+        runner->Outputs().Tag(kImage).packets;
+    EXPECT_EQ(1, output_packets.size());
+    EXPECT_EQ(time, output_packets[0].Timestamp().Value());
+    const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
+    EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
+    EXPECT_EQ(kWidth, output_image.Width());
+    EXPECT_EQ(kHeight, output_image.Height());
+
+    EXPECT_EQ(255, output_image.PixelData()[0]);
+    for (int i = 1; i < kWidth * kHeight; ++i) {
+      const uint8_t pixel_value = output_image.PixelData()[i];
+      ASSERT_EQ(0, pixel_value);
+    }
+  }
+}
+
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc
index 081e0c83a..dc3d97844 100644
--- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc
@@ -15,6 +15,7 @@
 // Calculator converts from one-dimensional Tensor of DT_FLOAT to Matrix
 // OR from (batched) two-dimensional Tensor of DT_FLOAT to Matrix.
 
+#include "absl/log/absl_check.h"
 #include "mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/matrix.h"
@@ -36,7 +37,7 @@ constexpr char kReference[] = "REFERENCE";
 
 absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet,
                                          TimeSeriesHeader* header) {
-  CHECK(header);
+  ABSL_CHECK(header);
   if (header_packet.IsEmpty()) {
     return absl::UnknownError("No header found.");
   }
@@ -191,7 +192,7 @@ absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) {
       << "Tensor stream packet does not contain a Tensor.";
   const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get<tf::Tensor>();
-  CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
+  ABSL_CHECK(1 == input_tensor.dims() || 2 == input_tensor.dims())
       << "Only 1-D or 2-D Tensors can be converted to matrices.";
   const int32_t length = input_tensor.dim_size(input_tensor.dims() - 1);
   const int32_t width =
diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc
index 4a47b7d7f..84c32fed6 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc
@@ -20,6 +20,7 @@
 #include
 
 #include "absl/base/thread_annotations.h"
+#include "absl/log/absl_check.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_split.h"
 #include "absl/synchronization/mutex.h"
@@ -61,12 +62,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
 // overload GPU/TPU/...
 class SimpleSemaphore {
  public:
-  explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {}
+  explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {}
   SimpleSemaphore(const SimpleSemaphore&) = delete;
   SimpleSemaphore(SimpleSemaphore&&) = delete;
 
   // Acquires the semaphore by certain amount.
-  void Acquire(uint32 amount) {
+  void Acquire(uint32_t amount) {
     mutex_.Lock();
     while (count_ < amount) {
       cond_.Wait(&mutex_);
@@ -76,7 +77,7 @@ class SimpleSemaphore {
   }
 
   // Releases the semaphore by certain amount.
-  void Release(uint32 amount) {
+  void Release(uint32_t amount) {
     mutex_.Lock();
     count_ += amount;
     cond_.SignalAll();
@@ -84,7 +85,7 @@ class SimpleSemaphore {
   }
 
  private:
-  uint32 count_;
+  uint32_t count_;
   absl::Mutex mutex_;
   absl::CondVar cond_;
 };
@@ -488,7 +489,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
   // necessary.
   absl::Status OutputBatch(CalculatorContext* cc,
                            std::unique_ptr<InferenceState> inference_state) {
-    const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
+    const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow());
     std::vector<std::pair<std::string, tf::Tensor>> input_tensors;
 
     for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
@@ -515,7 +516,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
         tf::Tensor concated;
         const tf::Status concat_status =
             tf::tensor::Concat(keyed_tensors.second, &concated);
-        CHECK(concat_status.ok()) << concat_status.ToString();
+        ABSL_CHECK(concat_status.ok()) << concat_status.ToString();
         input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
                                    concated);
       }
@@ -544,7 +545,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
           get_session_run_throttle(options_.max_concurrent_session_runs());
       session_run_throttle->Acquire(1);
     }
-    const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
+    const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow());
     tf::Status tf_status;
     {
 #if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
@@ -562,7 +563,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
     // informative error message.
     RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
 
-    const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
+    const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow());
     cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
         ->IncrementBy(run_end_time - run_start_time);
     cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
@@ -597,7 +598,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
       std::vector<tf::Tensor> split_tensors;
       const tf::Status split_status =
           tf::tensor::Split(outputs[i], split_vector, &split_tensors);
-      CHECK(split_status.ok()) << split_status.ToString();
+      ABSL_CHECK(split_status.ok()) << split_status.ToString();
       // Loop over timestamps so that we don't copy the padding.
       for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
         tf::Tensor output_tensor(split_tensors[j]);
@@ -611,7 +612,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
     }
 
     // Get end time and report.
-    const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
+    const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow());
     cc->GetCounter(kTotalUsecsCounterSuffix)
         ->IncrementBy(end_time - start_time);
     cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
@@ -650,7 +651,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
 
   // The static singleton semaphore to throttle concurrent session runs.
   static SimpleSemaphore* get_session_run_throttle(
-      int32 max_concurrent_session_runs) {
+      int32_t max_concurrent_session_runs) {
     static SimpleSemaphore* session_run_throttle =
         new SimpleSemaphore(max_concurrent_session_runs);
     return session_run_throttle;
diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
index c93008373..708f1711e 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc
@@ -17,6 +17,8 @@
 #include
 
 #include "absl/flags/flag.h"
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -118,7 +120,7 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
   // Create tensor from Vector and add as a Packet to the provided tag as input.
   void AddVectorToInputsAsPacket(const std::vector<Packet>& packets,
                                  const std::string& tag) {
-    CHECK(!packets.empty())
+    ABSL_CHECK(!packets.empty())
         << "Please specify at least some data in the packet";
     auto packets_ptr = absl::make_unique<std::vector<Packet>>(packets);
     runner_->MutableInputs()->Tag(tag).packets.push_back(
@@ -586,12 +588,12 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
       runner_->Outputs().Tag(kMultipliedTag).packets;
   ASSERT_EQ(2, output_packets_mult.size());
   const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
-  LOG(INFO) << "timestamp: " << 0;
+  ABSL_LOG(INFO) << "timestamp: " << 0;
   auto expected_tensor = tf::test::AsTensor({3, 8, 15});
   tf::test::ExpectTensorEqual(tensor_mult, expected_tensor);
   const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
   auto expected_tensor1 = tf::test::AsTensor({9, 32, 75});
-  LOG(INFO) << "timestamp: " << 1;
+  ABSL_LOG(INFO) << "timestamp: " << 1;
   tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1);
 
   EXPECT_EQ(2, runner_
@@ -627,12 +629,12 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
       runner_->Outputs().Tag(kMultipliedTag).packets;
   ASSERT_EQ(2, output_packets_mult.size());
   const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
-  LOG(INFO) << "timestamp: " << 0;
+  ABSL_LOG(INFO) << "timestamp: " << 0;
   auto expected_tensor = tf::test::AsTensor({3, 4, 5});
   tf::test::ExpectTensorEqual(tensor_mult, expected_tensor);
   const tf::Tensor& tensor_mult1 = output_packets_mult[1].Get<tf::Tensor>();
   auto expected_tensor1 = tf::test::AsTensor({3, 4, 5});
-  LOG(INFO) << "timestamp: " << 1;
+  ABSL_LOG(INFO) << "timestamp: " << 1;
   tf::test::ExpectTensorEqual(tensor_mult1, expected_tensor1);
 
   EXPECT_EQ(2, runner_
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc
index 1bb2c41fc..358b50cd3 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc
@@ -23,12 +23,12 @@
 
 #include
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/deps/clock.h"
 #include "mediapipe/framework/deps/monotonic_clock.h"
-#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/tool/status_util.h"
@@ -156,8 +156,8 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
     cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
 
     const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow());
-    LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
-              << " microseconds.";
+    ABSL_LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
+                   << " microseconds.";
     return absl::OkStatus();
   }
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc
index dc39458da..e340a098b 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc
@@ -24,13 +24,13 @@
 
 #include
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/deps/clock.h"
 #include "mediapipe/framework/deps/monotonic_clock.h"
 #include "mediapipe/framework/port/file_helpers.h"
-#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/tool/status_util.h"
@@ -155,8 +155,8 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
     output_side_packets->Tag(kSessionTag) = Adopt(session.release());
 
     const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow());
-    LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
-              << " microseconds.";
+    ABSL_LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
+                   << " microseconds.";
     return absl::OkStatus();
   }
 };
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
index 18bddbbe3..4ca4cb8d6 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc
@@ -17,6 +17,7 @@
 #if !defined(__ANDROID__)
 #include "mediapipe/framework/port/file_helpers.h"
 #endif
+#include "absl/log/absl_log.h"
 #include "absl/strings/str_replace.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
@@ -69,7 +70,7 @@ const std::string MaybeConvertSignatureToTag(
                    [](unsigned char c) { return std::toupper(c); });
     output = absl::StrReplaceAll(
         output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
-    LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
+    ABSL_LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
     return name;
diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
index ee69ec56a..959622447 100644
--- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
+++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc
@@ -19,6 +19,7 @@
 #if !defined(__ANDROID__)
 #include "mediapipe/framework/port/file_helpers.h"
 #endif
+#include "absl/log/absl_log.h"
 #include "absl/strings/str_replace.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session.h"
 #include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h"
@@ -75,7 +76,7 @@ const std::string MaybeConvertSignatureToTag(
                    [](unsigned char c) { return std::toupper(c); });
     output = absl::StrReplaceAll(
         output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
-    LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
+    ABSL_LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
     return output;
   } else {
     return name;
diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc
index 0d1d4ca26..c77c0f3f8 100644
--- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc
+++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/log/absl_log.h"
 #include "absl/strings/match.h"
 #include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
 #include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h"
@@ -197,15 +198,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
     // timestamp and the associated feature. This information is used in process
     // to output batches of packets in order.
     timestamps_.clear();
-    int64 last_timestamp_seen = Timestamp::PreStream().Value();
+    int64_t last_timestamp_seen = Timestamp::PreStream().Value();
     first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
     for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
       if (absl::StrContains(map_kv.first, "/timestamp")) {
-        LOG(INFO) << "Found feature timestamps: " << map_kv.first
-                  << " with size: " << map_kv.second.feature_size();
-        int64 recent_timestamp = Timestamp::PreStream().Value();
+        ABSL_LOG(INFO) << "Found feature timestamps: " << map_kv.first
+                       << " with size: " << map_kv.second.feature_size();
+        int64_t recent_timestamp = Timestamp::PreStream().Value();
         for (int i = 0; i < map_kv.second.feature_size(); ++i) {
-          int64 next_timestamp =
+          int64_t next_timestamp =
               mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
           RET_CHECK_GT(next_timestamp, recent_timestamp)
               << "Timestamps must be sequential. If you're seeing this message "
@@ -309,8 +310,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
         audio_decoder_options->set_end_time(
            end_time + options.extra_padding_from_media_decoder());
       }
-      LOG(INFO) << "Created AudioDecoderOptions:\n"
-                << audio_decoder_options->DebugString();
+      ABSL_LOG(INFO) << "Created AudioDecoderOptions:\n"
+                     << audio_decoder_options->DebugString();
       cc->OutputSidePackets()
           .Tag(kAudioDecoderOptions)
           .Set(Adopt(audio_decoder_options.release()));
@@ -331,8 +332,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
             ->set_end_time(Timestamp::FromSeconds(end_time).Value());
       }
 
-      LOG(INFO) << "Created PacketResamplerOptions:\n"
-                << resampler_options->DebugString();
+      ABSL_LOG(INFO) << "Created PacketResamplerOptions:\n"
+                     << resampler_options->DebugString();
       cc->OutputSidePackets()
           .Tag(kPacketResamplerOptions)
           .Set(Adopt(resampler_options.release()));
@@ -351,7 +352,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
   absl::Status Process(CalculatorContext* cc) override {
     if (timestamps_.empty()) {
       // This occurs when we only have metadata to unpack.
-      LOG(INFO) << "only unpacking metadata because there are no timestamps.";
+      ABSL_LOG(INFO)
+          << "only unpacking metadata because there are no timestamps.";
       return tool::StatusStop();
     }
     // In Process(), we loop through timestamps on a reference stream and emit
@@ -361,8 +363,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
     // any particular call to Process(). At the every end, we output the
     // poststream packets. If we only have poststream packets,
     // last_timestamp_key_ will be empty.
-    int64 start_timestamp = 0;
-    int64 end_timestamp = 0;
+    int64_t start_timestamp = 0;
+    int64_t end_timestamp = 0;
     if (last_timestamp_key_.empty() || process_poststream_) {
       process_poststream_ = true;
       start_timestamp = Timestamp::PostStream().Value();
@@ -481,14 +483,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
   // Store a map from the keys for each stream to the timestamps for each
   // key. This allows us to identify which packets to output for each stream
   // for timestamps within a given time window.
-  std::map<std::string, std::vector<int64>> timestamps_;
+  std::map<std::string, std::vector<int64_t>> timestamps_;
   // Store the stream with the latest timestamp in the SequenceExample.
   std::string last_timestamp_key_;
   // Store the index of the current timestamp. Will be less than
   // timestamps_[last_timestamp_key_].size().
   int current_timestamp_index_;
   // Store the very first timestamp, so we output everything on the first frame.
-  int64 first_timestamp_seen_;
+  int64_t first_timestamp_seen_;
   // List of keypoint names.
   std::vector<std::string> keypoint_names_;
   // Default keypoint location when missing.
diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc
index addb4a27a..2fa70de39 100644
--- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/log/absl_log.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
 #include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
@@ -81,7 +82,7 @@ class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
     if (options != nullptr) {
       *config.mutable_options() = *options;
     }
-    LOG(INFO) << config.DebugString();
+    ABSL_LOG(INFO) << config.DebugString();
     runner_ = absl::make_unique<CalculatorRunner>(config);
   }
 
diff --git a/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc
index efb3037f8..12f2ade02 100644
--- a/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc
+++ b/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc
@@ -14,6 +14,8 @@
 
 #include
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/packet.h"
@@ -46,7 +48,7 @@ std::string GetQuantizedFeature(
           .Get(index)
           .bytes_list()
          .value();
-  CHECK_EQ(1, bytes_list.size());
+  ABSL_CHECK_EQ(1, bytes_list.size());
   return bytes_list.Get(0);
 }
 }  // namespace
@@ -149,8 +151,9 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase {
             .Set(MakePacket<int>(segment_size));
       }
     }
-    LOG(INFO) << "Reading the sequence example that contains yt8m id: "
-              << yt8m_id << ". Feature list length: " << feature_list_length_;
+    ABSL_LOG(INFO) << "Reading the sequence example that contains yt8m id: "
+                   << yt8m_id
+                   << ". Feature list length: " << feature_list_length_;
     return absl::OkStatus();
   }
 
diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc
index 28184a8ca..dd0991cbf 100644
--- a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc
+++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc
@@ -14,6 +14,7 @@
 //
 // Converts vector<float> (or vector<vector<float>>) to 1D (or 2D) tf::Tensor.
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_options.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -68,7 +69,7 @@ absl::Status VectorFloatToTensorCalculator::GetContract(
         // Output vector.
     );
   } else {
-    LOG(FATAL) << "input size not supported";
+    ABSL_LOG(FATAL) << "input size not supported";
   }
   RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
       << "Only one output stream is supported.";
@@ -125,7 +126,7 @@ absl::Status VectorFloatToTensorCalculator::Process(CalculatorContext* cc) {
     }
     cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
   } else {
-    LOG(FATAL) << "input size not supported";
+    ABSL_LOG(FATAL) << "input size not supported";
   }
   return absl::OkStatus();
 }
diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc
index aadce3615..a4f98d2e9 100644
--- a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc
@@ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test {
       }
     }
 
-    const int64 time = 1234;
+    const int64_t time = 1234;
     runner_->MutableInputs()->Index(0).packets.push_back(
         Adopt(input.release()).At(Timestamp(time)));
 
@@ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) {
     // 2^i can be represented exactly in floating point numbers if 'i' is small.
     input->at(i) = static_cast<float>(1 << i);
   }
-  const int64 time = 1234;
+  const int64_t time = 1234;
   runner_->MutableInputs()->Index(0).packets.push_back(
       Adopt(input.release()).At(Timestamp(time)));
 
diff --git a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc
index cb90276ae..f4a892027 100644
--- a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc
+++ b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc
@@ -15,6 +15,8 @@
 // Converts a single int or vector<int> or vector<vector<int>> to 1D (or 2D)
 // tf::Tensor.
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator_options.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -86,7 +88,7 @@ absl::Status VectorIntToTensorCalculator::GetContract(CalculatorContract* cc) {
       cc->Inputs().Tag(kVectorInt).Set<std::vector<int>>();
     }
   } else {
-    LOG(FATAL) << "input size not supported";
+    ABSL_LOG(FATAL) << "input size not supported";
   }
   RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
      << "Only one output stream is supported.";
@@ -113,11 +115,11 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
             .Get<std::vector<std::vector<int>>>();
 
     const int32_t rows = input.size();
-    CHECK_GE(rows, 1);
+    ABSL_CHECK_GE(rows, 1);
     const int32_t cols = input[0].size();
-    CHECK_GE(cols, 1);
+    ABSL_CHECK_GE(cols, 1);
     for (int i = 1; i < rows; ++i) {
-      CHECK_EQ(input[i].size(), cols);
+      ABSL_CHECK_EQ(input[i].size(), cols);
     }
     if (options_.transpose()) {
       tensor_shape = tf::TensorShape({cols, rows});
@@ -140,7 +142,7 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
             AssignMatrixValue(c, r, input[r][c], output.get());
             break;
           default:
-            LOG(FATAL) << "tensor data type is not supported.";
+            ABSL_LOG(FATAL) << "tensor data type is not supported.";
         }
       }
     }
@@ -158,7 +160,7 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
             AssignMatrixValue(r, c, input[r][c], output.get());
             break;
          default:
-            LOG(FATAL) << "tensor data type is not supported.";
+            ABSL_LOG(FATAL) << "tensor data type is not supported.";
         }
       }
     }
@@ -171,7 +173,7 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
     } else {
       input = cc->Inputs().Tag(kVectorInt).Value().Get<std::vector<int>>();
     }
-    CHECK_GE(input.size(), 1);
+    ABSL_CHECK_GE(input.size(), 1);
     const int32_t length = input.size();
     tensor_shape = tf::TensorShape({length});
     auto output = ::absl::make_unique<tf::Tensor>(options_.tensor_data_type(),
@@ -188,12 +190,12 @@ absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) {
           output->tensor()(i) = input.at(i);
           break;
         default:
-          LOG(FATAL) << "tensor data type is not supported.";
+          ABSL_LOG(FATAL) << "tensor data type is not supported.";
       }
     }
     cc->Outputs().Tag(kTensorOut).Add(output.release(), cc->InputTimestamp());
   } else {
-    LOG(FATAL) << "input size not supported";
+    ABSL_LOG(FATAL) << "input size not supported";
   }
   return absl::OkStatus();
 }
diff --git a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc
index 139511271..57ee553c5 100644
--- a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc
+++ b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc
@@ -15,6 +15,7 @@
 // Converts vector<std::string> (or vector<vector<std::string>>) to 1D (or 2D)
 // tf::Tensor.
 
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
@@ -69,7 +70,7 @@ absl::Status VectorStringToTensorCalculator::GetContract(
         // Input vector.
     );
   } else {
-    LOG(FATAL) << "input size not supported";
+    ABSL_LOG(FATAL) << "input size not supported";
   }
   RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
       << "Only one output stream is supported.";
@@ -129,7 +130,7 @@ absl::Status VectorStringToTensorCalculator::Process(CalculatorContext* cc) {
     }
     cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
   } else {
-    LOG(FATAL) << "input size not supported";
+    ABSL_LOG(FATAL) << "input size not supported";
   }
   return absl::OkStatus();
 }
diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD
index 333de2069..ed9f47a8b 100644
--- a/mediapipe/calculators/tflite/BUILD
+++ b/mediapipe/calculators/tflite/BUILD
@@ -103,6 +103,8 @@ cc_library(
         "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/log:absl_check",
+        "@com_google_absl//absl/log:absl_log",
     ],
     alwayslink = 1,
 )
@@ -196,10 +198,13 @@ cc_library(
     deps = [
         ":tflite_inference_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
         "//mediapipe/util/tflite:config",
         "//mediapipe/util/tflite:tflite_model_loader",
+        "@com_google_absl//absl/log:absl_check",
+        "@com_google_absl//absl/log:absl_log",
         "@com_google_absl//absl/memory",
         "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
@@ -275,6 +280,7 @@ cc_library(
         "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
         "//mediapipe/util:resource_util",
         "//mediapipe/util/tflite:config",
+        "@com_google_absl//absl/log:absl_check",
        "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
     ] + selects.with_or({
@@ -392,6 +398,8 @@ cc_library(
         "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/util/tflite:config",
+        "@com_google_absl//absl/log:absl_check",
+        "@com_google_absl//absl/log:absl_log",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@org_tensorflow//tensorflow/lite:framework",
@@ -428,6 +436,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/util:resource_util",
         "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@org_tensorflow//tensorflow/lite:framework",
@@ -456,6 +465,7 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/port:ret_check",
+        "@com_google_absl//absl/log:absl_check",
         "@org_tensorflow//tensorflow/lite:framework",
     ],
     alwayslink = 1,
diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc
index 5ed5a95dc..d5303d65c 100644
--- a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc
+++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc
@@ -16,6 +16,8 @@
 #include
 #include
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/object_detection/anchor.pb.h"
@@ -272,13 +274,13 @@ absl::Status SsdAnchorsCalculator::GenerateAnchors(
   if (options.feature_map_height_size()) {
     if (options.strides_size()) {
-      LOG(ERROR) << "Found feature map shapes. Strides will be ignored.";
+      ABSL_LOG(ERROR) << "Found feature map shapes. Strides will be ignored.";
     }
-    CHECK_EQ(options.feature_map_height_size(), kNumLayers);
-    CHECK_EQ(options.feature_map_height_size(),
-             options.feature_map_width_size());
+    ABSL_CHECK_EQ(options.feature_map_height_size(), kNumLayers);
+    ABSL_CHECK_EQ(options.feature_map_height_size(),
+                  options.feature_map_width_size());
   } else {
-    CHECK_EQ(options.strides_size(), kNumLayers);
+    ABSL_CHECK_EQ(options.strides_size(), kNumLayers);
   }
 
   if (options.multiscale_anchor_generation()) {
diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc
index ff6b2ff91..7188cbc59 100644
--- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc
@@ -15,6 +15,7 @@
 #include
 #include
 
+#include "absl/log/absl_check.h"
 #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -643,7 +644,7 @@ absl::Status TfLiteConverterCalculator::LoadOptions(CalculatorContext* cc) {
     if (options.has_output_tensor_float_range()) {
       output_range_.emplace(options.output_tensor_float_range().min(),
                             options.output_tensor_float_range().max());
-      CHECK_GT(output_range_->second, output_range_->first);
+      ABSL_CHECK_GT(output_range_->second, output_range_->first);
     }
 
     // Custom div and sub values.
@@ -661,9 +662,9 @@ absl::Status TfLiteConverterCalculator::LoadOptions(CalculatorContext* cc) {
 
     // Get desired way to handle input channels.
     max_num_channels_ = options.max_num_channels();
-    CHECK_GE(max_num_channels_, 1);
-    CHECK_LE(max_num_channels_, 4);
-    CHECK_NE(max_num_channels_, 2);
+    ABSL_CHECK_GE(max_num_channels_, 1);
+    ABSL_CHECK_LE(max_num_channels_, 4);
+    ABSL_CHECK_NE(max_num_channels_, 2);
 #if defined(MEDIAPIPE_IOS)
     if (cc->Inputs().HasTag(kGpuBufferTag))
       // Currently on iOS, tflite gpu input tensor must be 4 channels,
diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
index add9bb1a8..d875b6940 100644
--- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc
@@ -17,9 +17,12 @@
 #include
 #include
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "absl/memory/memory.h"
 #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/util/tflite/config.h"
 
@@ -109,8 +112,8 @@ std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
     edgetpu::EdgeTpuContext* edgetpu_context) {
   resolver->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
   std::unique_ptr<tflite::Interpreter> interpreter;
-  CHECK_EQ(tflite::InterpreterBuilder(model, *resolver)(&interpreter),
-           kTfLiteOk);
+  ABSL_CHECK_EQ(tflite::InterpreterBuilder(model, *resolver)(&interpreter),
+                kTfLiteOk);
   interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context);
   return interpreter;
 }
@@ -406,11 +409,12 @@ absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
   }
 
   if (use_advanced_gpu_api_ && !gpu_input_) {
-    LOG(WARNING) << "Cannot use advanced GPU APIs, input must be GPU buffers."
-                    "Falling back to the default TFLite API.";
+    ABSL_LOG(WARNING)
+        << "Cannot use advanced GPU APIs, input must be GPU buffers."
+           "Falling back to the default TFLite API.";
     use_advanced_gpu_api_ = false;
   }
-  CHECK(!use_advanced_gpu_api_ || gpu_inference_);
+  ABSL_CHECK(!use_advanced_gpu_api_ || gpu_inference_);
 
   MP_RETURN_IF_ERROR(LoadModel(cc));
 
@@ -802,9 +806,10 @@ absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner(
       const int tensor_idx = interpreter_->inputs()[i];
       interpreter_->SetTensorParametersReadWrite(tensor_idx, kTfLiteFloat32, "",
                                                  shape, quant);
-      CHECK(interpreter_->ResizeInputTensor(tensor_idx, shape) == kTfLiteOk);
+      ABSL_CHECK(interpreter_->ResizeInputTensor(tensor_idx, shape) ==
+                 kTfLiteOk);
     }
-    CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
+    ABSL_CHECK(interpreter_->AllocateTensors() == kTfLiteOk);
   }
 
   // Create and bind OpenGL buffers for outputs.
@@ -1053,7 +1058,7 @@ absl::Status TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) {
           gpu_data_in_[i]->shape.w * gpu_data_in_[i]->shape.c;
       // Input to model can be RGBA only.
       if (tensor->dims->data[3] != 4) {
-        LOG(WARNING) << "Please ensure input GPU tensor is 4 channels.";
+        ABSL_LOG(WARNING) << "Please ensure input GPU tensor is 4 channels.";
       }
       const std::string shader_source =
           absl::Substitute(R"(#include
diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc
index 4d28b91e9..98ab4b1da 100644
--- a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc
@@ -17,6 +17,7 @@
 #include
 
 #include "absl/container/node_hash_map.h"
+#include "absl/log/absl_check.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.pb.h"
@@ -172,7 +173,7 @@ absl::Status TfLiteTensorsToClassificationCalculator::Process(
 
   // Note that partial_sort will raise error when top_k_ >
   // classification_list->classification_size().
-  CHECK_GE(classification_list->classification_size(), top_k_);
+  ABSL_CHECK_GE(classification_list->classification_size(), top_k_);
   auto raw_classification_list = classification_list->mutable_classification();
   if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
     std::partial_sort(raw_classification_list->begin(),
diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
index 2ed62c46d..269661f73 100644
--- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
+++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
@@ -15,6 +15,8 @@
 #include
 #include
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
@@ -93,7 +95,7 @@ void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
 
 void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
                                int num_boxes, float* raw_anchors) {
-  CHECK_EQ(anchors.size(), num_boxes);
+  ABSL_CHECK_EQ(anchors.size(), num_boxes);
   int box = 0;
   for (const auto& anchor : anchors) {
     raw_anchors[box * kNumCoordsPerBox + 0] = anchor.y_center();
@@ -288,14 +290,14 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
   const TfLiteTensor* raw_score_tensor = &input_tensors[1];
 
   // TODO: Add flexible input tensor size handling.
-  CHECK_EQ(raw_box_tensor->dims->size, 3);
-  CHECK_EQ(raw_box_tensor->dims->data[0], 1);
-  CHECK_EQ(raw_box_tensor->dims->data[1], num_boxes_);
-  CHECK_EQ(raw_box_tensor->dims->data[2], num_coords_);
-  CHECK_EQ(raw_score_tensor->dims->size, 3);
-  CHECK_EQ(raw_score_tensor->dims->data[0], 1);
-  CHECK_EQ(raw_score_tensor->dims->data[1], num_boxes_);
-  CHECK_EQ(raw_score_tensor->dims->data[2], num_classes_);
+  ABSL_CHECK_EQ(raw_box_tensor->dims->size, 3);
+  ABSL_CHECK_EQ(raw_box_tensor->dims->data[0], 1);
+  ABSL_CHECK_EQ(raw_box_tensor->dims->data[1], num_boxes_);
+  ABSL_CHECK_EQ(raw_box_tensor->dims->data[2], num_coords_);
+  ABSL_CHECK_EQ(raw_score_tensor->dims->size, 3);
+  ABSL_CHECK_EQ(raw_score_tensor->dims->data[0], 1);
+  ABSL_CHECK_EQ(raw_score_tensor->dims->data[1], num_boxes_);
+  ABSL_CHECK_EQ(raw_score_tensor->dims->data[2], num_classes_);
   const float* raw_boxes = raw_box_tensor->data.f;
   const float* raw_scores = raw_score_tensor->data.f;
 
@@ -303,13 +305,13 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
   if (!anchors_init_) {
     if (input_tensors.size() == kNumInputTensorsWithAnchors) {
       const TfLiteTensor* anchor_tensor = &input_tensors[2];
-      CHECK_EQ(anchor_tensor->dims->size, 2);
-      CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_);
-      CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
+      ABSL_CHECK_EQ(anchor_tensor->dims->size, 2);
+      ABSL_CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_);
+      ABSL_CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
       const float* raw_anchors = anchor_tensor->data.f;
       ConvertRawValuesToAnchors(raw_anchors, num_boxes_, &anchors_);
     } else if (side_packet_anchors_) {
-      CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
+      ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
       anchors_ =
           cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
     } else {
@@ -409,7 +411,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
       CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer));
   if (!anchors_init_) {
     if (side_packet_anchors_) {
-
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); + ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); const auto& anchors = cc->InputSidePackets().Tag("ANCHORS").Get>(); std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); @@ -417,7 +419,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( MP_RETURN_IF_ERROR(gpu_data_->raw_anchors_buffer.Write( absl::MakeSpan(raw_anchors))); } else { - CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); + ABSL_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); MP_RETURN_IF_ERROR( CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer)); } @@ -477,7 +479,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( commandBuffer:[gpu_helper_ commandBuffer]]; if (!anchors_init_) { if (side_packet_anchors_) { - CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); + ABSL_CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); const auto& anchors = cc->InputSidePackets().Tag("ANCHORS").Get>(); std::vector raw_anchors(num_boxes_ * kNumCoordsPerBox); @@ -541,7 +543,7 @@ absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( output_detections)); #else - LOG(ERROR) << "GPU input on non-Android not supported yet."; + ABSL_LOG(ERROR) << "GPU input on non-Android not supported yet."; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE return absl::OkStatus(); } @@ -567,12 +569,12 @@ absl::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( num_coords_ = options_.num_coords(); // Currently only support 2D when num_values_per_keypoint equals to 2. - CHECK_EQ(options_.num_values_per_keypoint(), 2); + ABSL_CHECK_EQ(options_.num_values_per_keypoint(), 2); // Check if the output size is equal to the requested boxes and keypoints. - CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + - kNumCoordsPerBox, - num_coords_); + ABSL_CHECK_EQ(options_.num_keypoints() * options_.num_values_per_keypoint() + + kNumCoordsPerBox, + num_coords_); for (int i = 0; i < options_.ignore_classes_size(); ++i) { ignore_classes_.insert(options_.ignore_classes(i)); @@ -897,10 +899,11 @@ void main() { int max_wg_size; // typically <= 1024 glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &max_wg_size); // y-dim - CHECK_LT(num_classes_, max_wg_size) + ABSL_CHECK_LT(num_classes_, max_wg_size) << "# classes must be < " << max_wg_size; // TODO support better filtering. - CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + ABSL_CHECK_LE(ignore_classes_.size(), 1) + << "Only ignore class 0 is allowed"; // Shader program GlShader score_shader; @@ -1115,7 +1118,7 @@ kernel void scoreKernel( ignore_classes_.size() ? 1 : 0); // TODO support better filtering. - CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; + ABSL_CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; { // Shader program @@ -1147,7 +1150,8 @@ kernel void scoreKernel( options:MTLResourceStorageModeShared]; // # filter classes supported is hardware dependent. 
int max_wg_size = gpu_data_->score_program.maxTotalThreadsPerThreadgroup; - CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; + ABSL_CHECK_LT(num_classes_, max_wg_size) + << "# classes must be <" << max_wg_size; } #endif // MEDIAPIPE_TFLITE_GL_INFERENCE diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc index 1be83bbe1..6740f0afa 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/log/absl_check.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" @@ -199,7 +200,7 @@ absl::Status TfLiteTensorsToLandmarksCalculator::Process( num_values *= raw_tensor->dims->data[i]; } const int num_dimensions = num_values / num_landmarks_; - CHECK_GT(num_dimensions, 0); + ABSL_CHECK_GT(num_dimensions, 0); const float* raw_landmarks = raw_tensor->data.f; diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 22e6b0738..328b24fdb 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -183,9 +183,9 @@ cc_library( "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:timestamp", "//mediapipe/framework/deps:clock", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", ], @@ -248,11 +248,12 @@ cc_library( ":annotation_overlay_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:image_opencv", "//mediapipe/framework/formats:video_stream_header", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", @@ -260,6 +261,7 @@ cc_library( "//mediapipe/util:annotation_renderer", "//mediapipe/util:color_cc_proto", "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ] + select({ "//mediapipe/gpu:disable_gpu": [], @@ -267,6 +269,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_buffer_format", "//mediapipe/gpu:shader_util", ], }), @@ -374,9 +377,10 @@ cc_library( "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:location", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -675,6 +679,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/util:color_cc_proto", "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/log:absl_check", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -731,6 +736,7 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/util:color_cc_proto", "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -746,6 +752,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/util:color_cc_proto", "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) @@ -899,16 +906,77 @@ mediapipe_proto_library( cc_library( name = "landmarks_smoothing_calculator", srcs = ["landmarks_smoothing_calculator.cc"], + hdrs = ["landmarks_smoothing_calculator.h"], deps = [ ":landmarks_smoothing_calculator_cc_proto", + ":landmarks_smoothing_calculator_utils", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +cc_library( + name = "landmarks_smoothing_calculator_utils", + srcs = ["landmarks_smoothing_calculator_utils.cc"], + hdrs = ["landmarks_smoothing_calculator_utils.h"], + deps = [ + ":landmarks_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_context", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/util/filtering:one_euro_filter", "//mediapipe/util/filtering:relative_velocity_filter", - "@com_google_absl//absl/algorithm:container", + ], + alwayslink = 1, +) + +cc_test( + name = "landmarks_smoothing_calculator_utils_test", + size = "small", + srcs = ["landmarks_smoothing_calculator_utils_test.cc"], + deps = [ + ":landmarks_smoothing_calculator_utils", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_library( + name = "multi_landmarks_smoothing_calculator", + srcs = ["multi_landmarks_smoothing_calculator.cc"], + hdrs = ["multi_landmarks_smoothing_calculator.h"], + deps = [ + ":landmarks_smoothing_calculator_cc_proto", + ":landmarks_smoothing_calculator_utils", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +cc_library( + name = "multi_world_landmarks_smoothing_calculator", + srcs = ["multi_world_landmarks_smoothing_calculator.cc"], + hdrs = ["multi_world_landmarks_smoothing_calculator.h"], + deps = [ + ":landmarks_smoothing_calculator_cc_proto", + ":landmarks_smoothing_calculator_utils", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", ], alwayslink = 1, ) @@ -1088,6 +1156,7 @@ cc_library( "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -1148,6 +1217,7 @@ cc_library( "//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:status", "//mediapipe/util:rectangle_util", + "@com_google_absl//absl/log:absl_check", 
"@com_google_absl//absl/memory", ], alwayslink = 1, @@ -1285,12 +1355,14 @@ cc_library( srcs = ["flat_color_image_calculator.cc"], deps = [ ":flat_color_image_calculator_cc_proto", + "//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:ret_check", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -1417,6 +1489,7 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", ], alwayslink = 1, diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 6e0dc769b..33490359c 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -14,15 +14,17 @@ #include +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "mediapipe/calculators/util/annotation_overlay_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/image_opencv.h" #include "mediapipe/framework/formats/video_stream_header.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" @@ -35,6 +37,7 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/shader_util.h" #endif // !MEDIAPIPE_DISABLE_GPU @@ -45,6 +48,7 @@ namespace { constexpr char kVectorTag[] = "VECTOR"; constexpr char kGpuBufferTag[] = "IMAGE_GPU"; constexpr char kImageFrameTag[] = "IMAGE"; +constexpr char kImageTag[] = "UIMAGE"; // Universal Image enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -57,13 +61,16 @@ size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT constexpr uchar kAnnotationBackgroundColor = 2; // Grayscale value. // Future Image type. -inline bool HasImageTag(mediapipe::CalculatorContext* cc) { return false; } +inline bool HasImageTag(mediapipe::CalculatorContext* cc) { + return cc->Inputs().HasTag(kImageTag); +} } // namespace // A calculator for rendering data on images. // // Inputs: // 1. IMAGE or IMAGE_GPU (optional): An ImageFrame (or GpuBuffer), +// or UIMAGE (an Image). // containing the input image. // If output is CPU, and input isn't provided, the renderer creates a // blank canvas with the width, height and color provided in the options. @@ -76,6 +83,7 @@ inline bool HasImageTag(mediapipe::CalculatorContext* cc) { return false; } // // Output: // 1. IMAGE or IMAGE_GPU: A rendered ImageFrame (or GpuBuffer), +// or UIMAGE (an Image). // Note: Output types should match their corresponding input stream type. 
// // For CPU input frames, only SRGBA, SRGB and GRAY8 format are supported. The @@ -135,6 +143,9 @@ class AnnotationOverlayCalculator : public CalculatorBase { absl::Status CreateRenderTargetCpu(CalculatorContext* cc, std::unique_ptr& image_mat, ImageFormat::Format* target_format); + absl::Status CreateRenderTargetCpuImage(CalculatorContext* cc, + std::unique_ptr& image_mat, + ImageFormat::Format* target_format); template absl::Status CreateRenderTargetGpu(CalculatorContext* cc, std::unique_ptr& image_mat); @@ -172,30 +183,38 @@ class AnnotationOverlayCalculator : public CalculatorBase { REGISTER_CALCULATOR(AnnotationOverlayCalculator); absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; - if (cc->Inputs().HasTag(kImageFrameTag) && - cc->Inputs().HasTag(kGpuBufferTag)) { - return absl::InternalError("Cannot have multiple input images."); - } - if (cc->Inputs().HasTag(kGpuBufferTag) != - cc->Outputs().HasTag(kGpuBufferTag)) { - return absl::InternalError("GPU output must have GPU input."); - } + RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) + + cc->Inputs().HasTag(kGpuBufferTag) + + cc->Inputs().HasTag(kImageTag) <= + 1); + RET_CHECK(cc->Outputs().HasTag(kImageFrameTag) + + cc->Outputs().HasTag(kGpuBufferTag) + + cc->Outputs().HasTag(kImageTag) == + 1); // Input image to render onto copy of. Should be same type as output. #if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); - CHECK(cc->Outputs().HasTag(kGpuBufferTag)); + RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag)); use_gpu = true; } #endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); - CHECK(cc->Outputs().HasTag(kImageFrameTag)); + RET_CHECK(cc->Outputs().HasTag(kImageFrameTag)); + } + + if (cc->Inputs().HasTag(kImageTag)) { + cc->Inputs().Tag(kImageTag).Set(); + RET_CHECK(cc->Outputs().HasTag(kImageTag)); +#if !MEDIAPIPE_DISABLE_GPU + use_gpu = true; // Prepare GPU resources because images can come in on GPU. +#endif } // Data streams to render. @@ -220,6 +239,9 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs().Tag(kImageFrameTag).Set(); } + if (cc->Outputs().HasTag(kImageTag)) { + cc->Outputs().Tag(kImageTag).Set(); + } if (use_gpu) { #if !MEDIAPIPE_DISABLE_GPU @@ -252,9 +274,14 @@ absl::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { renderer_ = absl::make_unique(); renderer_->SetFlipTextVertically(options_.flip_text_vertically()); if (use_gpu_) renderer_->SetScaleFactor(options_.gpu_scale_factor()); + if (renderer_->GetScaleFactor() < 1.0 && HasImageTag(cc)) + ABSL_LOG(WARNING) + << "Annotation scale factor only supports GPU backed Image."; // Set the output header based on the input header (if present). - const char* tag = use_gpu_ ? kGpuBufferTag : kImageFrameTag; + const char* tag = HasImageTag(cc) ? kImageTag + : use_gpu_ ? 
kGpuBufferTag + : kImageFrameTag; if (image_frame_available_ && !cc->Inputs().Tag(tag).Header().IsEmpty()) { const auto& input_header = cc->Inputs().Tag(tag).Header().Get(); @@ -280,6 +307,12 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { return absl::OkStatus(); } + if (cc->Inputs().HasTag(kImageTag) && cc->Inputs().Tag(kImageTag).IsEmpty()) { + return absl::OkStatus(); + } + if (HasImageTag(cc)) { + use_gpu_ = cc->Inputs().Tag(kImageTag).Get().UsesGpu(); + } // Initialize render target, drawn with OpenCV. std::unique_ptr image_mat; @@ -289,10 +322,17 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { if (!gpu_initialized_) { MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (HasImageTag(cc)) { + return GlSetup(cc); + } return GlSetup(cc); })); gpu_initialized_ = true; } + if (HasImageTag(cc)) { + MP_RETURN_IF_ERROR( + (CreateRenderTargetGpu(cc, image_mat))); + } if (cc->Inputs().HasTag(kGpuBufferTag)) { MP_RETURN_IF_ERROR( (CreateRenderTargetGpu( @@ -300,6 +340,10 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { } #endif // !MEDIAPIPE_DISABLE_GPU } else { + if (cc->Outputs().HasTag(kImageTag)) { + MP_RETURN_IF_ERROR( + CreateRenderTargetCpuImage(cc, image_mat, &target_format)); + } if (cc->Outputs().HasTag(kImageFrameTag)) { MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); } @@ -339,6 +383,9 @@ absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { uchar* image_mat_ptr = image_mat->data; MP_RETURN_IF_ERROR( gpu_helper_.RunInGlContext([this, cc, image_mat_ptr]() -> absl::Status { + if (HasImageTag(cc)) { + return RenderToGpu(cc, image_mat_ptr); + } return RenderToGpu( cc, image_mat_ptr); })); @@ -381,6 +428,10 @@ absl::Status AnnotationOverlayCalculator::RenderToCpu( ImageFrame::kDefaultAlignmentBoundary); #endif // !MEDIAPIPE_DISABLE_GPU + if (HasImageTag(cc)) { + auto out = std::make_unique(std::move(output_frame)); + cc->Outputs().Tag(kImageTag).Add(out.release(), cc->InputTimestamp()); + } if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs() .Tag(kImageFrameTag) @@ -399,7 +450,8 @@ absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc, auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); auto output_texture = gpu_helper_.CreateDestinationTexture( - width_, height_, mediapipe::GpuBufferFormat::kBGRA32); + input_texture.width(), input_texture.height(), + mediapipe::GpuBufferFormat::kBGRA32); // Upload render target to GPU. { @@ -428,7 +480,7 @@ absl::Status AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc, } // Send out blended image as GPU packet. 
- auto output_frame = output_texture.GetFrame(); + auto output_frame = output_texture.template GetFrame(); cc->Outputs().Tag(Tag).Add(output_frame.release(), cc->InputTimestamp()); // Cleanup @@ -471,7 +523,7 @@ absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( auto input_mat = formats::MatView(&input_frame); if (input_frame.Format() == ImageFormat::GRAY8) { cv::Mat rgb_mat; - cv::cvtColor(input_mat, rgb_mat, CV_GRAY2RGB); + cv::cvtColor(input_mat, rgb_mat, cv::COLOR_GRAY2RGB); rgb_mat.copyTo(*image_mat); } else { input_mat.copyTo(*image_mat); @@ -487,6 +539,54 @@ absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( return absl::OkStatus(); } +absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpuImage( + CalculatorContext* cc, std::unique_ptr& image_mat, + ImageFormat::Format* target_format) { + if (image_frame_available_) { + const auto& input_frame = + cc->Inputs().Tag(kImageTag).Get(); + + int target_mat_type; + switch (input_frame.image_format()) { + case ImageFormat::SRGBA: + *target_format = ImageFormat::SRGBA; + target_mat_type = CV_8UC4; + break; + case ImageFormat::SRGB: + *target_format = ImageFormat::SRGB; + target_mat_type = CV_8UC3; + break; + case ImageFormat::GRAY8: + *target_format = ImageFormat::SRGB; + target_mat_type = CV_8UC3; + break; + default: + return absl::UnknownError("Unexpected image frame format."); + break; + } + + image_mat = absl::make_unique( + input_frame.height(), input_frame.width(), target_mat_type); + + auto input_mat = formats::MatView(&input_frame); + if (input_frame.image_format() == ImageFormat::GRAY8) { + cv::Mat rgb_mat; + cv::cvtColor(*input_mat, rgb_mat, cv::COLOR_GRAY2RGB); + rgb_mat.copyTo(*image_mat); + } else { + input_mat->copyTo(*image_mat); + } + } else { + image_mat = absl::make_unique( + options_.canvas_height_px(), options_.canvas_width_px(), CV_8UC3, + cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), + options_.canvas_color().b())); + *target_format = ImageFormat::SRGB; + } + + return absl::OkStatus(); +} + template absl::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( CalculatorContext* cc, std::unique_ptr& image_mat) { diff --git a/mediapipe/calculators/util/association_calculator.h b/mediapipe/calculators/util/association_calculator.h index 037ea838c..1cec63c80 100644 --- a/mediapipe/calculators/util/association_calculator.h +++ b/mediapipe/calculators/util/association_calculator.h @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "mediapipe/calculators/util/association_calculator.pb.h" #include "mediapipe/framework/calculator_context.h" @@ -72,7 +73,7 @@ class AssociationCalculator : public CalculatorBase { prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0); } options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>(); - CHECK_GE(options_.min_similarity_threshold(), 0); + ABSL_CHECK_GE(options_.min_similarity_threshold(), 0); return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 0c1d6892e..44b7a210f 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" 
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/resource_util.h" @@ -85,7 +86,8 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(options.label_map_path())); std::string label_map_string; - MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); + MP_RETURN_IF_ERROR( + mediapipe::GetResourceContents(string_path, &label_map_string)); std::istringstream stream(label_map_string); std::string line; diff --git a/mediapipe/calculators/util/detections_deduplicate_calculator.cc b/mediapipe/calculators/util/detections_deduplicate_calculator.cc index 2dfa09028..a31585b88 100644 --- a/mediapipe/calculators/util/detections_deduplicate_calculator.cc +++ b/mediapipe/calculators/util/detections_deduplicate_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/calculators/util/detections_to_render_data_calculator.cc b/mediapipe/calculators/util/detections_to_render_data_calculator.cc index 25d74ba68..73c2cb1d2 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -233,13 +234,13 @@ void DetectionsToRenderDataCalculator::AddLabels( const Detection& detection, const DetectionsToRenderDataCalculatorOptions& options, float text_line_height, RenderData* render_data) { - CHECK(detection.label().empty() || detection.label_id().empty() || - detection.label_size() == detection.label_id_size()) + ABSL_CHECK(detection.label().empty() || detection.label_id().empty() || + detection.label_size() == detection.label_id_size()) << "String or integer labels should be of same size. Or only one of them " "is present."; const auto num_labels = std::max(detection.label_size(), detection.label_id_size()); - CHECK_EQ(detection.score_size(), num_labels) + ABSL_CHECK_EQ(detection.score_size(), num_labels) << "Number of scores and labels should match for detection."; // Extracts all "label(_id),score" for the detection. 
@@ -361,9 +362,9 @@ void DetectionsToRenderDataCalculator::AddDetectionToRenderData( const Detection& detection, const DetectionsToRenderDataCalculatorOptions& options, RenderData* render_data) { - CHECK(detection.location_data().format() == LocationData::BOUNDING_BOX || - detection.location_data().format() == - LocationData::RELATIVE_BOUNDING_BOX) + ABSL_CHECK(detection.location_data().format() == LocationData::BOUNDING_BOX || + detection.location_data().format() == + LocationData::RELATIVE_BOUNDING_BOX) << "Only Detection with formats of BOUNDING_BOX or RELATIVE_BOUNDING_BOX " "are supported."; double text_line_height; diff --git a/mediapipe/calculators/util/flat_color_image_calculator.cc b/mediapipe/calculators/util/flat_color_image_calculator.cc index 71d3582c5..f3b9c184c 100644 --- a/mediapipe/calculators/util/flat_color_image_calculator.cc +++ b/mediapipe/calculators/util/flat_color_image_calculator.cc @@ -15,14 +15,13 @@ #include #include "absl/status/status.h" -#include "absl/strings/str_cat.h" #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" #include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/util/color.pb.h" namespace mediapipe { @@ -32,6 +31,7 @@ namespace { using ::mediapipe::api2::Input; using ::mediapipe::api2::Node; using ::mediapipe::api2::Output; +using ::mediapipe::api2::SideOutput; } // namespace // A calculator for generating an image filled with a single color. @@ -45,7 +45,8 @@ using ::mediapipe::api2::Output; // // Outputs: // IMAGE (Image) -// Image filled with the requested color. +// Image filled with the requested color. Can be either an output_stream +// or an output_side_packet. 
// // Example usage: // node { @@ -68,9 +69,10 @@ class FlatColorImageCalculator : public Node { public: static constexpr Input::Optional kInImage{"IMAGE"}; static constexpr Input::Optional kInColor{"COLOR"}; - static constexpr Output kOutImage{"IMAGE"}; + static constexpr Output::Optional kOutImage{"IMAGE"}; + static constexpr SideOutput::Optional kOutSideImage{"IMAGE"}; - MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage); + MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage, kOutSideImage); static absl::Status UpdateContract(CalculatorContract* cc) { const auto& options = cc->Options(); @@ -81,6 +83,13 @@ class FlatColorImageCalculator : public Node { RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color()) << "Either set COLOR input stream, or set through options"; + RET_CHECK(kOutImage(cc).IsConnected() ^ kOutSideImage(cc).IsConnected()) + << "Set IMAGE either as output stream, or as output side packet"; + + RET_CHECK(!kOutSideImage(cc).IsConnected() || + (options.has_output_height() && options.has_output_width())) + << "Set size through options, when setting IMAGE as output side packet"; + return absl::OkStatus(); } @@ -88,6 +97,9 @@ class FlatColorImageCalculator : public Node { absl::Status Process(CalculatorContext* cc) override; private: + std::optional> CreateOutputFrame( + CalculatorContext* cc); + bool use_dimension_from_option_ = false; bool use_color_from_option_ = false; }; @@ -96,10 +108,31 @@ MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator); absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) { use_dimension_from_option_ = !kInImage(cc).IsConnected(); use_color_from_option_ = !kInColor(cc).IsConnected(); + + if (!kOutImage(cc).IsConnected()) { + std::optional> output_frame = + CreateOutputFrame(cc); + if (output_frame.has_value()) { + kOutSideImage(cc).Set(Image(output_frame.value())); + } + } return absl::OkStatus(); } absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { + if (kOutImage(cc).IsConnected()) { + std::optional> output_frame = + CreateOutputFrame(cc); + if (output_frame.has_value()) { + kOutImage(cc).Send(Image(output_frame.value())); + } + } + + return absl::OkStatus(); +} + +std::optional> +FlatColorImageCalculator::CreateOutputFrame(CalculatorContext* cc) { const auto& options = cc->Options(); int output_height = -1; @@ -112,7 +145,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { output_height = input_image.height(); output_width = input_image.width(); } else { - return absl::OkStatus(); + return std::nullopt; } Color color; @@ -121,7 +154,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { } else if (!kInColor(cc).IsEmpty()) { color = kInColor(cc).Get(); } else { - return absl::OkStatus(); + return std::nullopt; } auto output_frame = std::make_shared(ImageFormat::SRGB, @@ -130,9 +163,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b())); - kOutImage(cc).Send(Image(output_frame)); - - return absl::OkStatus(); + return output_frame; } } // namespace mediapipe diff --git a/mediapipe/calculators/util/flat_color_image_calculator_test.cc b/mediapipe/calculators/util/flat_color_image_calculator_test.cc index 53c6de1b1..c09064bf2 100644 --- a/mediapipe/calculators/util/flat_color_image_calculator_test.cc +++ b/mediapipe/calculators/util/flat_color_image_calculator_test.cc @@ -113,6 +113,35 @@ TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) { } }
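With CreateOutputFrame factored out above, the calculator serves the same image either as a per-timestamp stream packet (Send in Process) or as a constant side packet set once in Open, and the two new RET_CHECKs in UpdateContract enforce that exactly one of the two outputs is wired and that a side-packet output carries explicit dimensions. The ProducesOutputSidePacket test added next exercises the side-packet path through CalculatorRunner; fetching it from a full CalculatorGraph would look roughly like the following sketch, where the config contents and the "out_packet" name are hypothetical, mirroring the test, while CalculatorGraph::GetOutputSidePacket is the standard accessor:

// Assumes `config` wires "IMAGE:out_packet" as an output side packet of
// FlatColorImageCalculator, as in the test below.
absl::Status ReadFlatColorSidePacket(
    const mediapipe::CalculatorGraphConfig& config) {
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
  // Side packets are timestamp-free and can be read once the graph is idle.
  ASSIGN_OR_RETURN(mediapipe::Packet packet,
                   graph.GetOutputSidePacket("out_packet"));
  const auto& image = packet.Get<mediapipe::Image>();
  (void)image;  // e.g. inspect image.width() / image.height() here.
  return absl::OkStatus();
}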
+TEST(FlatColorImageCalculatorTest, ProducesOutputSidePacket) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + output_side_packet: "IMAGE:out_packet" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + output_width: 1 + output_height: 1 + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + MP_ASSERT_OK(runner.Run()); + + const auto& image = runner.OutputSidePackets().Tag(kImageTag).Get(); + EXPECT_EQ(image.width(), 1); + EXPECT_EQ(image.height(), 1); + auto image_frame = image.GetImageFrameSharedPtr(); + const uint8_t* pixel_data = image_frame->PixelData(); + EXPECT_EQ(pixel_data[0], 100); + EXPECT_EQ(pixel_data[1], 200); + EXPECT_EQ(pixel_data[2], 255); +} + TEST(FlatColorImageCalculatorTest, FailureMissingDimension) { CalculatorRunner runner(R"pb( calculator: "FlatColorImageCalculator" @@ -206,5 +235,56 @@ TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) { HasSubstr("Either set COLOR input stream")); } +TEST(FlatColorImageCalculatorTest, FailureDuplicateOutputs) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + output_stream: "IMAGE:out_image" + output_side_packet: "IMAGE:out_packet" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + output_width: 1 + output_height: 1 + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + ASSERT_THAT( + runner.Run().message(), + HasSubstr("Set IMAGE either as output stream, or as output side packet")); +} + +TEST(FlatColorImageCalculatorTest, FailureSettingInputImageOnOutputSidePacket) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "IMAGE:image" + output_side_packet: "IMAGE:out_packet" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + auto image_frame = std::make_shared(ImageFormat::SRGB, + kImageWidth, kImageHeight); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kImageTag).packets.push_back( + MakePacket(image_frame).At(Timestamp(ts))); + } + ASSERT_THAT(runner.Run().message(), + HasSubstr("Set size through options, when setting IMAGE as " + "output side packet")); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index dcd76d47b..314640ed7 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -19,6 +19,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -114,7 +115,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { video_height_ = video_header.height; return absl::OkStatus(); } else { - CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT) + ABSL_CHECK_EQ(options_.location(), + LabelsToRenderDataCalculatorOptions::TOP_LEFT) << "Only TOP_LEFT is supported without VIDEO_PRESTREAM."; } @@ -144,7 +146,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kScoresTag)) { std::vector score_vector = cc->Inputs().Tag(kScoresTag).Get>(); - CHECK_EQ(label_vector.size(), score_vector.size()); + ABSL_CHECK_EQ(label_vector.size(), score_vector.size()); scores.resize(label_vector.size()); for (int i = 0; i < 
label_vector.size(); ++i) { scores[i] = score_vector[i]; diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.cc b/mediapipe/calculators/util/landmarks_refinement_calculator.cc index 8f734ac88..87394c6c5 100644 --- a/mediapipe/calculators/util/landmarks_refinement_calculator.cc +++ b/mediapipe/calculators/util/landmarks_refinement_calculator.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "mediapipe/calculators/util/landmarks_refinement_calculator.pb.h" #include "mediapipe/framework/api2/node.h" @@ -102,7 +103,8 @@ void RefineZ( ->set_z(z_average); } } else { - CHECK(false) << "Z refinement is either not specified or not supported"; + ABSL_CHECK(false) + << "Z refinement is either not specified or not supported"; } } diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index 7a92cfb7e..bc7504485 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -1,4 +1,4 @@ -// Copyright 2020 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,471 +12,105 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.h" + #include -#include "absl/algorithm/container.h" #include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" +#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/timestamp.h" -#include "mediapipe/util/filtering/one_euro_filter.h" -#include "mediapipe/util/filtering/relative_velocity_filter.h" namespace mediapipe { +namespace api2 { namespace { -constexpr char kNormalizedLandmarksTag[] = "NORM_LANDMARKS"; -constexpr char kLandmarksTag[] = "LANDMARKS"; -constexpr char kImageSizeTag[] = "IMAGE_SIZE"; -constexpr char kObjectScaleRoiTag[] = "OBJECT_SCALE_ROI"; -constexpr char kNormalizedFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; -constexpr char kFilteredLandmarksTag[] = "FILTERED_LANDMARKS"; - using ::mediapipe::NormalizedRect; -using mediapipe::OneEuroFilter; using ::mediapipe::Rect; -using mediapipe::RelativeVelocityFilter; - -void NormalizedLandmarksToLandmarks( - const NormalizedLandmarkList& norm_landmarks, const int image_width, - const int image_height, LandmarkList* landmarks) { - for (int i = 0; i < norm_landmarks.landmark_size(); ++i) { - const auto& norm_landmark = norm_landmarks.landmark(i); - - auto* landmark = landmarks->add_landmark(); - landmark->set_x(norm_landmark.x() * image_width); - landmark->set_y(norm_landmark.y() * image_height); - // Scale Z the same way as X (using image width). 
- landmark->set_z(norm_landmark.z() * image_width); - landmark->set_visibility(norm_landmark.visibility()); - landmark->set_presence(norm_landmark.presence()); - } -} - -void LandmarksToNormalizedLandmarks(const LandmarkList& landmarks, - const int image_width, - const int image_height, - NormalizedLandmarkList* norm_landmarks) { - for (int i = 0; i < landmarks.landmark_size(); ++i) { - const auto& landmark = landmarks.landmark(i); - - auto* norm_landmark = norm_landmarks->add_landmark(); - norm_landmark->set_x(landmark.x() / image_width); - norm_landmark->set_y(landmark.y() / image_height); - // Scale Z the same way as X (using image width). - norm_landmark->set_z(landmark.z() / image_width); - norm_landmark->set_visibility(landmark.visibility()); - norm_landmark->set_presence(landmark.presence()); - } -} - -// Estimate object scale to use its inverse value as velocity scale for -// RelativeVelocityFilter. If value will be too small (less than -// `options_.min_allowed_object_scale`) smoothing will be disabled and -// landmarks will be returned as is. -// Object scale is calculated as average between bounding box width and height -// with sides parallel to axis. -float GetObjectScale(const LandmarkList& landmarks) { - const auto& lm_minmax_x = absl::c_minmax_element( - landmarks.landmark(), - [](const auto& a, const auto& b) { return a.x() < b.x(); }); - const float x_min = lm_minmax_x.first->x(); - const float x_max = lm_minmax_x.second->x(); - - const auto& lm_minmax_y = absl::c_minmax_element( - landmarks.landmark(), - [](const auto& a, const auto& b) { return a.y() < b.y(); }); - const float y_min = lm_minmax_y.first->y(); - const float y_max = lm_minmax_y.second->y(); - - const float object_width = x_max - x_min; - const float object_height = y_max - y_min; - - return (object_width + object_height) / 2.0f; -} - -float GetObjectScale(const NormalizedRect& roi, const int image_width, - const int image_height) { - const float object_width = roi.width() * image_width; - const float object_height = roi.height() * image_height; - - return (object_width + object_height) / 2.0f; -} - -float GetObjectScale(const Rect& roi) { - return (roi.width() + roi.height()) / 2.0f; -} - -// Abstract class for various landmarks filters. -class LandmarksFilter { - public: - virtual ~LandmarksFilter() = default; - - virtual absl::Status Reset() { return absl::OkStatus(); } - - virtual absl::Status Apply(const LandmarkList& in_landmarks, - const absl::Duration& timestamp, - const absl::optional object_scale_opt, - LandmarkList* out_landmarks) = 0; -}; - -// Returns landmarks as is without smoothing. -class NoFilter : public LandmarksFilter { - public: - absl::Status Apply(const LandmarkList& in_landmarks, - const absl::Duration& timestamp, - const absl::optional object_scale_opt, - LandmarkList* out_landmarks) override { - *out_landmarks = in_landmarks; - return absl::OkStatus(); - } -}; - -// Please check RelativeVelocityFilter documentation for details. 
-class VelocityFilter : public LandmarksFilter { - public: - VelocityFilter(int window_size, float velocity_scale, - float min_allowed_object_scale, bool disable_value_scaling) - : window_size_(window_size), - velocity_scale_(velocity_scale), - min_allowed_object_scale_(min_allowed_object_scale), - disable_value_scaling_(disable_value_scaling) {} - - absl::Status Reset() override { - x_filters_.clear(); - y_filters_.clear(); - z_filters_.clear(); - return absl::OkStatus(); - } - - absl::Status Apply(const LandmarkList& in_landmarks, - const absl::Duration& timestamp, - const absl::optional object_scale_opt, - LandmarkList* out_landmarks) override { - // Get value scale as inverse value of the object scale. - // If value is too small smoothing will be disabled and landmarks will be - // returned as is. - float value_scale = 1.0f; - if (!disable_value_scaling_) { - const float object_scale = - object_scale_opt ? *object_scale_opt : GetObjectScale(in_landmarks); - if (object_scale < min_allowed_object_scale_) { - *out_landmarks = in_landmarks; - return absl::OkStatus(); - } - value_scale = 1.0f / object_scale; - } - - // Initialize filters once. - MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); - - // Filter landmarks. Every axis of every landmark is filtered separately. - for (int i = 0; i < in_landmarks.landmark_size(); ++i) { - const auto& in_landmark = in_landmarks.landmark(i); - - auto* out_landmark = out_landmarks->add_landmark(); - *out_landmark = in_landmark; - out_landmark->set_x( - x_filters_[i].Apply(timestamp, value_scale, in_landmark.x())); - out_landmark->set_y( - y_filters_[i].Apply(timestamp, value_scale, in_landmark.y())); - out_landmark->set_z( - z_filters_[i].Apply(timestamp, value_scale, in_landmark.z())); - } - - return absl::OkStatus(); - } - - private: - // Initializes filters for the first time or after Reset. If initialized then - // check the size. - absl::Status InitializeFiltersIfEmpty(const int n_landmarks) { - if (!x_filters_.empty()) { - RET_CHECK_EQ(x_filters_.size(), n_landmarks); - RET_CHECK_EQ(y_filters_.size(), n_landmarks); - RET_CHECK_EQ(z_filters_.size(), n_landmarks); - return absl::OkStatus(); - } - - x_filters_.resize(n_landmarks, - RelativeVelocityFilter(window_size_, velocity_scale_)); - y_filters_.resize(n_landmarks, - RelativeVelocityFilter(window_size_, velocity_scale_)); - z_filters_.resize(n_landmarks, - RelativeVelocityFilter(window_size_, velocity_scale_)); - - return absl::OkStatus(); - } - - int window_size_; - float velocity_scale_; - float min_allowed_object_scale_; - bool disable_value_scaling_; - - std::vector x_filters_; - std::vector y_filters_; - std::vector z_filters_; -}; - -// Please check OneEuroFilter documentation for details. -class OneEuroFilterImpl : public LandmarksFilter { - public: - OneEuroFilterImpl(double frequency, double min_cutoff, double beta, - double derivate_cutoff, float min_allowed_object_scale, - bool disable_value_scaling) - : frequency_(frequency), - min_cutoff_(min_cutoff), - beta_(beta), - derivate_cutoff_(derivate_cutoff), - min_allowed_object_scale_(min_allowed_object_scale), - disable_value_scaling_(disable_value_scaling) {} - - absl::Status Reset() override { - x_filters_.clear(); - y_filters_.clear(); - z_filters_.clear(); - return absl::OkStatus(); - } - - absl::Status Apply(const LandmarkList& in_landmarks, - const absl::Duration& timestamp, - const absl::optional object_scale_opt, - LandmarkList* out_landmarks) override { - // Initialize filters once. 
- MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); - - // Get value scale as inverse value of the object scale. - // If value is too small smoothing will be disabled and landmarks will be - // returned as is. - float value_scale = 1.0f; - if (!disable_value_scaling_) { - const float object_scale = - object_scale_opt ? *object_scale_opt : GetObjectScale(in_landmarks); - if (object_scale < min_allowed_object_scale_) { - *out_landmarks = in_landmarks; - return absl::OkStatus(); - } - value_scale = 1.0f / object_scale; - } - - // Filter landmarks. Every axis of every landmark is filtered separately. - for (int i = 0; i < in_landmarks.landmark_size(); ++i) { - const auto& in_landmark = in_landmarks.landmark(i); - - auto* out_landmark = out_landmarks->add_landmark(); - *out_landmark = in_landmark; - out_landmark->set_x( - x_filters_[i].Apply(timestamp, value_scale, in_landmark.x())); - out_landmark->set_y( - y_filters_[i].Apply(timestamp, value_scale, in_landmark.y())); - out_landmark->set_z( - z_filters_[i].Apply(timestamp, value_scale, in_landmark.z())); - } - - return absl::OkStatus(); - } - - private: - // Initializes filters for the first time or after Reset. If initialized then - // check the size. - absl::Status InitializeFiltersIfEmpty(const int n_landmarks) { - if (!x_filters_.empty()) { - RET_CHECK_EQ(x_filters_.size(), n_landmarks); - RET_CHECK_EQ(y_filters_.size(), n_landmarks); - RET_CHECK_EQ(z_filters_.size(), n_landmarks); - return absl::OkStatus(); - } - - for (int i = 0; i < n_landmarks; ++i) { - x_filters_.push_back( - OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_)); - y_filters_.push_back( - OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_)); - z_filters_.push_back( - OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_)); - } - - return absl::OkStatus(); - } - - double frequency_; - double min_cutoff_; - double beta_; - double derivate_cutoff_; - double min_allowed_object_scale_; - bool disable_value_scaling_; - - std::vector x_filters_; - std::vector y_filters_; - std::vector z_filters_; -}; +using ::mediapipe::landmarks_smoothing::GetObjectScale; +using ::mediapipe::landmarks_smoothing::InitializeLandmarksFilter; +using ::mediapipe::landmarks_smoothing::LandmarksFilter; +using ::mediapipe::landmarks_smoothing::LandmarksToNormalizedLandmarks; +using ::mediapipe::landmarks_smoothing::NormalizedLandmarksToLandmarks; } // namespace -// A calculator to smooth landmarks over time. -// -// Inputs: -// NORM_LANDMARKS: A NormalizedLandmarkList of landmarks you want to smooth. -// IMAGE_SIZE: A std::pair represention of image width and height. -// Required to perform all computations in absolute coordinates to avoid any -// influence of normalized values. -// OBJECT_SCALE_ROI (optional): A NormRect or Rect (depending on the format of -// input landmarks) used to determine the object scale for some of the -// filters. If not provided - object scale will be calculated from -// landmarks. -// -// Outputs: -// NORM_FILTERED_LANDMARKS: A NormalizedLandmarkList of smoothed landmarks. 
-// -// Example config: -// node { -// calculator: "LandmarksSmoothingCalculator" -// input_stream: "NORM_LANDMARKS:pose_landmarks" -// input_stream: "IMAGE_SIZE:image_size" -// input_stream: "OBJECT_SCALE_ROI:roi" -// output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered" -// options: { -// [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { -// velocity_filter: { -// window_size: 5 -// velocity_scale: 10.0 -// } -// } -// } -// } -// -class LandmarksSmoothingCalculator : public CalculatorBase { +class LandmarksSmoothingCalculatorImpl + : public NodeImpl { public: - static absl::Status GetContract(CalculatorContract* cc); - absl::Status Open(CalculatorContext* cc) override; - absl::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override { + ASSIGN_OR_RETURN(landmarks_filter_, + InitializeLandmarksFilter( + cc->Options())); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // Check that landmarks are not empty and reset the filter if so. + // Don't emit an empty packet for this timestamp. + if ((kInNormLandmarks(cc).IsConnected() && + kInNormLandmarks(cc).IsEmpty()) || + (kInLandmarks(cc).IsConnected() && kInLandmarks(cc).IsEmpty())) { + MP_RETURN_IF_ERROR(landmarks_filter_->Reset()); + return absl::OkStatus(); + } + + const auto& timestamp = + absl::Microseconds(cc->InputTimestamp().Microseconds()); + + if (kInNormLandmarks(cc).IsConnected()) { + const auto& in_norm_landmarks = kInNormLandmarks(cc).Get(); + + int image_width; + int image_height; + std::tie(image_width, image_height) = kImageSize(cc).Get(); + + absl::optional object_scale; + if (kObjectScaleRoi(cc).IsConnected() && !kObjectScaleRoi(cc).IsEmpty()) { + auto& roi = kObjectScaleRoi(cc).Get(); + object_scale = GetObjectScale(roi, image_width, image_height); + } + + auto in_landmarks = absl::make_unique(); + NormalizedLandmarksToLandmarks(in_norm_landmarks, image_width, + image_height, *in_landmarks.get()); + + auto out_landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(landmarks_filter_->Apply( + *in_landmarks, timestamp, object_scale, *out_landmarks)); + + auto out_norm_landmarks = absl::make_unique(); + LandmarksToNormalizedLandmarks(*out_landmarks, image_width, image_height, + *out_norm_landmarks.get()); + + kOutNormLandmarks(cc).Send(std::move(out_norm_landmarks)); + } else { + const auto& in_landmarks = kInLandmarks(cc).Get(); + + absl::optional object_scale; + if (kObjectScaleRoi(cc).IsConnected() && !kObjectScaleRoi(cc).IsEmpty()) { + auto& roi = kObjectScaleRoi(cc).Get(); + object_scale = GetObjectScale(roi); + } + + auto out_landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(landmarks_filter_->Apply( + in_landmarks, timestamp, object_scale, *out_landmarks)); + + kOutLandmarks(cc).Send(std::move(out_landmarks)); + } + + return absl::OkStatus(); + } private: std::unique_ptr landmarks_filter_; }; -REGISTER_CALCULATOR(LandmarksSmoothingCalculator); - -absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) { - if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) { - cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); - cc->Inputs().Tag(kImageSizeTag).Set>(); - cc->Outputs() - .Tag(kNormalizedFilteredLandmarksTag) - .Set(); - - if (cc->Inputs().HasTag(kObjectScaleRoiTag)) { - cc->Inputs().Tag(kObjectScaleRoiTag).Set(); - } - } else { - cc->Inputs().Tag(kLandmarksTag).Set(); - cc->Outputs().Tag(kFilteredLandmarksTag).Set(); - - if (cc->Inputs().HasTag(kObjectScaleRoiTag)) { - 
cc->Inputs().Tag(kObjectScaleRoiTag).Set(); - } - } - - return absl::OkStatus(); -} - -absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { - cc->SetOffset(TimestampDiff(0)); - - // Pick landmarks filter. - const auto& options = cc->Options(); - if (options.has_no_filter()) { - landmarks_filter_ = absl::make_unique(); - } else if (options.has_velocity_filter()) { - landmarks_filter_ = absl::make_unique( - options.velocity_filter().window_size(), - options.velocity_filter().velocity_scale(), - options.velocity_filter().min_allowed_object_scale(), - options.velocity_filter().disable_value_scaling()); - } else if (options.has_one_euro_filter()) { - landmarks_filter_ = absl::make_unique( - options.one_euro_filter().frequency(), - options.one_euro_filter().min_cutoff(), - options.one_euro_filter().beta(), - options.one_euro_filter().derivate_cutoff(), - options.one_euro_filter().min_allowed_object_scale(), - options.one_euro_filter().disable_value_scaling()); - } else { - RET_CHECK_FAIL() - << "Landmarks filter is either not specified or not supported"; - } - - return absl::OkStatus(); -} - -absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { - // Check that landmarks are not empty and reset the filter if so. - // Don't emit an empty packet for this timestamp. - if ((cc->Inputs().HasTag(kNormalizedLandmarksTag) && - cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) || - (cc->Inputs().HasTag(kLandmarksTag) && - cc->Inputs().Tag(kLandmarksTag).IsEmpty())) { - MP_RETURN_IF_ERROR(landmarks_filter_->Reset()); - return absl::OkStatus(); - } - - const auto& timestamp = - absl::Microseconds(cc->InputTimestamp().Microseconds()); - - if (cc->Inputs().HasTag(kNormalizedLandmarksTag)) { - const auto& in_norm_landmarks = - cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); - - int image_width; - int image_height; - std::tie(image_width, image_height) = - cc->Inputs().Tag(kImageSizeTag).Get>(); - - absl::optional object_scale; - if (cc->Inputs().HasTag(kObjectScaleRoiTag) && - !cc->Inputs().Tag(kObjectScaleRoiTag).IsEmpty()) { - auto& roi = cc->Inputs().Tag(kObjectScaleRoiTag).Get(); - object_scale = GetObjectScale(roi, image_width, image_height); - } - - auto in_landmarks = absl::make_unique(); - NormalizedLandmarksToLandmarks(in_norm_landmarks, image_width, image_height, - in_landmarks.get()); - - auto out_landmarks = absl::make_unique(); - MP_RETURN_IF_ERROR(landmarks_filter_->Apply( - *in_landmarks, timestamp, object_scale, out_landmarks.get())); - - auto out_norm_landmarks = absl::make_unique(); - LandmarksToNormalizedLandmarks(*out_landmarks, image_width, image_height, - out_norm_landmarks.get()); - - cc->Outputs() - .Tag(kNormalizedFilteredLandmarksTag) - .Add(out_norm_landmarks.release(), cc->InputTimestamp()); - } else { - const auto& in_landmarks = - cc->Inputs().Tag(kLandmarksTag).Get(); - - absl::optional object_scale; - if (cc->Inputs().HasTag(kObjectScaleRoiTag) && - !cc->Inputs().Tag(kObjectScaleRoiTag).IsEmpty()) { - auto& roi = cc->Inputs().Tag(kObjectScaleRoiTag).Get(); - object_scale = GetObjectScale(roi); - } - - auto out_landmarks = absl::make_unique(); - MP_RETURN_IF_ERROR(landmarks_filter_->Apply( - in_landmarks, timestamp, object_scale, out_landmarks.get())); - - cc->Outputs() - .Tag(kFilteredLandmarksTag) - .Add(out_landmarks.release(), cc->InputTimestamp()); - } - - return absl::OkStatus(); -} +MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksSmoothingCalculatorImpl); +} // namespace api2 } // namespace mediapipe diff --git 
diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.h b/mediapipe/calculators/util/landmarks_smoothing_calculator.h
new file mode 100644
index 000000000..a64286c15
--- /dev/null
+++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.h
@@ -0,0 +1,106 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_SMOOTHING_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_SMOOTHING_CALCULATOR_H_
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/port/ret_check.h"
+
+namespace mediapipe {
+namespace api2 {
+
+// A calculator to smooth landmarks over time.
+//
+// Inputs:
+//   NORM_LANDMARKS (optional): A NormalizedLandmarkList of landmarks you want
+//     to smooth.
+//   LANDMARKS (optional): A LandmarkList of landmarks you want to smooth.
+//   IMAGE_SIZE (optional): A std::pair<int, int> representation of image width
+//     and height. Required to perform all computations in absolute coordinates
+//     when smoothing NORM_LANDMARKS to avoid any influence of normalized
+//     values.
+//   OBJECT_SCALE_ROI (optional): A NormRect or Rect (depending on the format
+//     of input landmarks) used to determine the object scale for some of the
+//     filters. If not provided, the object scale will be calculated from the
+//     landmarks.
+//
+// Outputs:
+//   NORM_FILTERED_LANDMARKS (optional): A NormalizedLandmarkList of smoothed
+//     landmarks.
+//   FILTERED_LANDMARKS (optional): A LandmarkList of smoothed landmarks.
+//
+// Example config:
+//   node {
+//     calculator: "LandmarksSmoothingCalculator"
+//     input_stream: "NORM_LANDMARKS:landmarks"
+//     input_stream: "IMAGE_SIZE:image_size"
+//     input_stream: "OBJECT_SCALE_ROI:roi"
+//     output_stream: "NORM_FILTERED_LANDMARKS:landmarks_filtered"
+//     options: {
+//       [mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
+//         velocity_filter: {
+//           window_size: 5
+//           velocity_scale: 10.0
+//         }
+//       }
+//     }
+//   }
+//
+class LandmarksSmoothingCalculator : public NodeIntf {
+ public:
+  static constexpr Input<mediapipe::NormalizedLandmarkList>::Optional
+      kInNormLandmarks{"NORM_LANDMARKS"};
+  static constexpr Input<mediapipe::LandmarkList>::Optional kInLandmarks{
+      "LANDMARKS"};
+  static constexpr Input<std::pair<int, int>>::Optional kImageSize{
+      "IMAGE_SIZE"};
+  static constexpr Input<OneOf<mediapipe::NormalizedRect,
+                               mediapipe::Rect>>::Optional kObjectScaleRoi{
+      "OBJECT_SCALE_ROI"};
+  static constexpr Output<mediapipe::NormalizedLandmarkList>::Optional
+      kOutNormLandmarks{"NORM_FILTERED_LANDMARKS"};
+  static constexpr Output<mediapipe::LandmarkList>::Optional kOutLandmarks{
+      "FILTERED_LANDMARKS"};
+  MEDIAPIPE_NODE_INTERFACE(LandmarksSmoothingCalculator, kInNormLandmarks,
+                           kInLandmarks, kImageSize, kObjectScaleRoi,
+                           kOutNormLandmarks, kOutLandmarks);
+
+  static absl::Status UpdateContract(CalculatorContract* cc) {
+    RET_CHECK(kInNormLandmarks(cc).IsConnected() ^
+              kInLandmarks(cc).IsConnected())
+        << "One and only one of NORM_LANDMARKS and LANDMARKS input is allowed";
+
+    // TODO: Verify scale ROI is of the same type as landmarks
+    // that are being smoothed.
+
+    if (kInNormLandmarks(cc).IsConnected()) {
+      RET_CHECK(kImageSize(cc).IsConnected());
+      RET_CHECK(kOutNormLandmarks(cc).IsConnected());
+      RET_CHECK(!kOutLandmarks(cc).IsConnected());
+    } else {
+      RET_CHECK(!kImageSize(cc).IsConnected());
+      RET_CHECK(kOutLandmarks(cc).IsConnected());
+      RET_CHECK(!kOutNormLandmarks(cc).IsConnected());
+    }
+
+    return absl::OkStatus();
+  }
+};
+
+}  // namespace api2
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_SMOOTHING_CALCULATOR_H_
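As UpdateContract above enforces, the calculator runs in exactly one of two modes: normalized landmarks (NORM_LANDMARKS plus IMAGE_SIZE in, NORM_FILTERED_LANDMARKS out) or world landmarks (LANDMARKS in, FILTERED_LANDMARKS out). The header comment shows the first mode; a sketch of the second, with illustrative stream names and filter values:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// A graph that smooths world landmarks with the One Euro filter.
mediapipe::CalculatorGraphConfig config =
    mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
      input_stream: "world_landmarks"
      output_stream: "world_landmarks_filtered"
      node {
        calculator: "LandmarksSmoothingCalculator"
        input_stream: "LANDMARKS:world_landmarks"
        output_stream: "FILTERED_LANDMARKS:world_landmarks_filtered"
        options {
          [mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
            one_euro_filter { min_cutoff: 0.1 beta: 40.0 derivate_cutoff: 1.0 }
          }
        }
      }
    )pb");

Note that wiring IMAGE_SIZE in this mode would fail the RET_CHECK(!kImageSize(cc).IsConnected()) in UpdateContract above.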
+ +#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h" + +#include + +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/filtering/one_euro_filter.h" +#include "mediapipe/util/filtering/relative_velocity_filter.h" + +namespace mediapipe { +namespace landmarks_smoothing { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::OneEuroFilter; +using ::mediapipe::Rect; +using ::mediapipe::RelativeVelocityFilter; + +// Estimate object scale to use its inverse value as velocity scale for +// RelativeVelocityFilter. If value will be too small (less than +// `options_.min_allowed_object_scale`) smoothing will be disabled and +// landmarks will be returned as is. +// Object scale is calculated as average between bounding box width and height +// with sides parallel to axis. +float GetObjectScale(const LandmarkList& landmarks) { + const auto& lm_minmax_x = absl::c_minmax_element( + landmarks.landmark(), + [](const auto& a, const auto& b) { return a.x() < b.x(); }); + const float x_min = lm_minmax_x.first->x(); + const float x_max = lm_minmax_x.second->x(); + + const auto& lm_minmax_y = absl::c_minmax_element( + landmarks.landmark(), + [](const auto& a, const auto& b) { return a.y() < b.y(); }); + const float y_min = lm_minmax_y.first->y(); + const float y_max = lm_minmax_y.second->y(); + + const float object_width = x_max - x_min; + const float object_height = y_max - y_min; + + return (object_width + object_height) / 2.0f; +} + +// Returns landmarks as is without smoothing. +class NoFilter : public LandmarksFilter { + public: + absl::Status Apply(const LandmarkList& in_landmarks, + const absl::Duration& timestamp, + const absl::optional object_scale_opt, + LandmarkList& out_landmarks) override { + out_landmarks = in_landmarks; + return absl::OkStatus(); + } +}; + +// Please check RelativeVelocityFilter documentation for details. +class VelocityFilter : public LandmarksFilter { + public: + VelocityFilter(int window_size, float velocity_scale, + float min_allowed_object_scale, bool disable_value_scaling) + : window_size_(window_size), + velocity_scale_(velocity_scale), + min_allowed_object_scale_(min_allowed_object_scale), + disable_value_scaling_(disable_value_scaling) {} + + absl::Status Reset() override { + x_filters_.clear(); + y_filters_.clear(); + z_filters_.clear(); + return absl::OkStatus(); + } + + absl::Status Apply(const LandmarkList& in_landmarks, + const absl::Duration& timestamp, + const absl::optional object_scale_opt, + LandmarkList& out_landmarks) override { + // Get value scale as inverse value of the object scale. + // If value is too small smoothing will be disabled and landmarks will be + // returned as is. + float value_scale = 1.0f; + if (!disable_value_scaling_) { + const float object_scale = + object_scale_opt ? *object_scale_opt : GetObjectScale(in_landmarks); + if (object_scale < min_allowed_object_scale_) { + out_landmarks = in_landmarks; + return absl::OkStatus(); + } + value_scale = 1.0f / object_scale; + } + + // Initialize filters once. + MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size())); + + // Filter landmarks. Every axis of every landmark is filtered separately. 
+    for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
+      const auto& in_landmark = in_landmarks.landmark(i);
+
+      auto* out_landmark = out_landmarks.add_landmark();
+      *out_landmark = in_landmark;
+      out_landmark->set_x(
+          x_filters_[i].Apply(timestamp, value_scale, in_landmark.x()));
+      out_landmark->set_y(
+          y_filters_[i].Apply(timestamp, value_scale, in_landmark.y()));
+      out_landmark->set_z(
+          z_filters_[i].Apply(timestamp, value_scale, in_landmark.z()));
+    }
+
+    return absl::OkStatus();
+  }
+
+ private:
+  // Initializes filters for the first time or after Reset. If already
+  // initialized, checks the size.
+  absl::Status InitializeFiltersIfEmpty(const int n_landmarks) {
+    if (!x_filters_.empty()) {
+      RET_CHECK_EQ(x_filters_.size(), n_landmarks);
+      RET_CHECK_EQ(y_filters_.size(), n_landmarks);
+      RET_CHECK_EQ(z_filters_.size(), n_landmarks);
+      return absl::OkStatus();
+    }
+
+    x_filters_.resize(n_landmarks,
+                      RelativeVelocityFilter(window_size_, velocity_scale_));
+    y_filters_.resize(n_landmarks,
+                      RelativeVelocityFilter(window_size_, velocity_scale_));
+    z_filters_.resize(n_landmarks,
+                      RelativeVelocityFilter(window_size_, velocity_scale_));
+
+    return absl::OkStatus();
+  }
+
+  int window_size_;
+  float velocity_scale_;
+  float min_allowed_object_scale_;
+  bool disable_value_scaling_;
+
+  std::vector<RelativeVelocityFilter> x_filters_;
+  std::vector<RelativeVelocityFilter> y_filters_;
+  std::vector<RelativeVelocityFilter> z_filters_;
+};
+
+// Please check OneEuroFilter documentation for details.
+class OneEuroFilterImpl : public LandmarksFilter {
+ public:
+  OneEuroFilterImpl(double frequency, double min_cutoff, double beta,
+                    double derivate_cutoff, float min_allowed_object_scale,
+                    bool disable_value_scaling)
+      : frequency_(frequency),
+        min_cutoff_(min_cutoff),
+        beta_(beta),
+        derivate_cutoff_(derivate_cutoff),
+        min_allowed_object_scale_(min_allowed_object_scale),
+        disable_value_scaling_(disable_value_scaling) {}
+
+  absl::Status Reset() override {
+    x_filters_.clear();
+    y_filters_.clear();
+    z_filters_.clear();
+    return absl::OkStatus();
+  }
+
+  absl::Status Apply(const LandmarkList& in_landmarks,
+                     const absl::Duration& timestamp,
+                     const absl::optional<float> object_scale_opt,
+                     LandmarkList& out_landmarks) override {
+    // Initialize filters once.
+    MP_RETURN_IF_ERROR(InitializeFiltersIfEmpty(in_landmarks.landmark_size()));
+
+    // Get value scale as inverse value of the object scale.
+    // If the scale is too small, smoothing will be disabled and landmarks
+    // will be returned as is.
+    float value_scale = 1.0f;
+    if (!disable_value_scaling_) {
+      const float object_scale =
+          object_scale_opt ? *object_scale_opt : GetObjectScale(in_landmarks);
+      if (object_scale < min_allowed_object_scale_) {
+        out_landmarks = in_landmarks;
+        return absl::OkStatus();
+      }
+      value_scale = 1.0f / object_scale;
+    }
+
+    // Filter landmarks. Every axis of every landmark is filtered separately.
+    for (int i = 0; i < in_landmarks.landmark_size(); ++i) {
+      const auto& in_landmark = in_landmarks.landmark(i);
+
+      auto* out_landmark = out_landmarks.add_landmark();
+      *out_landmark = in_landmark;
+      out_landmark->set_x(
+          x_filters_[i].Apply(timestamp, value_scale, in_landmark.x()));
+      out_landmark->set_y(
+          y_filters_[i].Apply(timestamp, value_scale, in_landmark.y()));
+      out_landmark->set_z(
+          z_filters_[i].Apply(timestamp, value_scale, in_landmark.z()));
+    }
+
+    return absl::OkStatus();
+  }
+
+ private:
+  // Initializes filters for the first time or after Reset. If already
+  // initialized, checks the size.
+  absl::Status InitializeFiltersIfEmpty(const int n_landmarks) {
+    if (!x_filters_.empty()) {
+      RET_CHECK_EQ(x_filters_.size(), n_landmarks);
+      RET_CHECK_EQ(y_filters_.size(), n_landmarks);
+      RET_CHECK_EQ(z_filters_.size(), n_landmarks);
+      return absl::OkStatus();
+    }
+
+    for (int i = 0; i < n_landmarks; ++i) {
+      x_filters_.push_back(
+          OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
+      y_filters_.push_back(
+          OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
+      z_filters_.push_back(
+          OneEuroFilter(frequency_, min_cutoff_, beta_, derivate_cutoff_));
+    }
+
+    return absl::OkStatus();
+  }
+
+  double frequency_;
+  double min_cutoff_;
+  double beta_;
+  double derivate_cutoff_;
+  double min_allowed_object_scale_;
+  bool disable_value_scaling_;
+
+  std::vector<OneEuroFilter> x_filters_;
+  std::vector<OneEuroFilter> y_filters_;
+  std::vector<OneEuroFilter> z_filters_;
+};
+
+}  // namespace
+
+void NormalizedLandmarksToLandmarks(
+    const NormalizedLandmarkList& norm_landmarks, const int image_width,
+    const int image_height, LandmarkList& landmarks) {
+  for (int i = 0; i < norm_landmarks.landmark_size(); ++i) {
+    const auto& norm_landmark = norm_landmarks.landmark(i);
+
+    auto* landmark = landmarks.add_landmark();
+    landmark->set_x(norm_landmark.x() * image_width);
+    landmark->set_y(norm_landmark.y() * image_height);
+    // Scale Z the same way as X (using image width).
+    landmark->set_z(norm_landmark.z() * image_width);
+
+    if (norm_landmark.has_visibility()) {
+      landmark->set_visibility(norm_landmark.visibility());
+    } else {
+      landmark->clear_visibility();
+    }
+
+    if (norm_landmark.has_presence()) {
+      landmark->set_presence(norm_landmark.presence());
+    } else {
+      landmark->clear_presence();
+    }
+  }
+}
+
+void LandmarksToNormalizedLandmarks(const LandmarkList& landmarks,
+                                    const int image_width,
+                                    const int image_height,
+                                    NormalizedLandmarkList& norm_landmarks) {
+  for (int i = 0; i < landmarks.landmark_size(); ++i) {
+    const auto& landmark = landmarks.landmark(i);
+
+    auto* norm_landmark = norm_landmarks.add_landmark();
+    norm_landmark->set_x(landmark.x() / image_width);
+    norm_landmark->set_y(landmark.y() / image_height);
+    // Scale Z the same way as X (using image width).
+    norm_landmark->set_z(landmark.z() / image_width);
+
+    if (landmark.has_visibility()) {
+      norm_landmark->set_visibility(landmark.visibility());
+    } else {
+      norm_landmark->clear_visibility();
+    }
+
+    if (landmark.has_presence()) {
+      norm_landmark->set_presence(landmark.presence());
+    } else {
+      norm_landmark->clear_presence();
+    }
+  }
+}
+
+float GetObjectScale(const NormalizedRect& roi, const int image_width,
+                     const int image_height) {
+  const float object_width = roi.width() * image_width;
+  const float object_height = roi.height() * image_height;
+
+  return (object_width + object_height) / 2.0f;
+}
+
+float GetObjectScale(const Rect& roi) {
+  return (roi.width() + roi.height()) / 2.0f;
+}
+
+absl::StatusOr<std::unique_ptr<LandmarksFilter>> InitializeLandmarksFilter(
+    const LandmarksSmoothingCalculatorOptions& options) {
+  if (options.has_no_filter()) {
+    return absl::make_unique<NoFilter>();
+  } else if (options.has_velocity_filter()) {
+    return absl::make_unique<VelocityFilter>(
+        options.velocity_filter().window_size(),
+        options.velocity_filter().velocity_scale(),
+        options.velocity_filter().min_allowed_object_scale(),
+        options.velocity_filter().disable_value_scaling());
+  } else if (options.has_one_euro_filter()) {
+    return absl::make_unique<OneEuroFilterImpl>(
+        options.one_euro_filter().frequency(),
+        options.one_euro_filter().min_cutoff(),
+        options.one_euro_filter().beta(),
+        options.one_euro_filter().derivate_cutoff(),
+        options.one_euro_filter().min_allowed_object_scale(),
+        options.one_euro_filter().disable_value_scaling());
+  } else {
+    RET_CHECK_FAIL()
+        << "Landmarks filter is either not specified or not supported";
+  }
+}
+
+absl::StatusOr<LandmarksFilter*> MultiLandmarkFilters::GetOrCreate(
+    const int64_t tracking_id,
+    const mediapipe::LandmarksSmoothingCalculatorOptions& options) {
+  const auto it = filters_.find(tracking_id);
+  if (it != filters_.end()) {
+    return it->second.get();
+  }
+
+  ASSIGN_OR_RETURN(auto landmarks_filter, InitializeLandmarksFilter(options));
+  filters_[tracking_id] = std::move(landmarks_filter);
+  return filters_[tracking_id].get();
+}
+
+void MultiLandmarkFilters::ClearUnused(
+    const std::vector<int64_t>& tracking_ids) {
+  std::vector<int64_t> unused_tracking_ids;
+  for (const auto& it : filters_) {
+    bool unused = true;
+    for (int64_t tracking_id : tracking_ids) {
+      if (tracking_id == it.first) unused = false;
+    }
+    if (unused) unused_tracking_ids.push_back(it.first);
+  }
+
+  for (int64_t tracking_id : unused_tracking_ids) {
+    filters_.erase(tracking_id);
+  }
+}
+
+void MultiLandmarkFilters::Clear() { filters_.clear(); }
+
+}  // namespace landmarks_smoothing
+}  // namespace mediapipe
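Nothing in the utils above depends on a CalculatorContext, so a filter can be driven directly as well. A minimal sketch, assuming the same status macros (ASSIGN_OR_RETURN, MP_RETURN_IF_ERROR) used throughout this patch; the function name, landmark source and timestamp are illustrative:

#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h"

// Smooths one frame of landmarks with a velocity filter built from options.
absl::Status SmoothOnce(const mediapipe::LandmarkList& raw_landmarks,
                        int64_t timestamp_us,
                        mediapipe::LandmarkList& smoothed) {
  mediapipe::LandmarksSmoothingCalculatorOptions options;
  options.mutable_velocity_filter()->set_window_size(5);
  options.mutable_velocity_filter()->set_velocity_scale(10.0f);

  ASSIGN_OR_RETURN(
      auto filter,
      mediapipe::landmarks_smoothing::InitializeLandmarksFilter(options));

  // With no ROI given, the object scale is derived from the landmarks'
  // bounding box inside the filter.
  return filter->Apply(raw_landmarks, absl::Microseconds(timestamp_us),
                       /*object_scale_opt=*/absl::nullopt, smoothed);
}

In a real pipeline the filter would be created once and reused across frames, since it carries the smoothing state.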
diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h b/mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h
new file mode 100644
index 000000000..c63db926a
--- /dev/null
+++ b/mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h
@@ -0,0 +1,77 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_SMOOTHING_CALCULATOR_UTILS_H_
+#define MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_SMOOTHING_CALCULATOR_UTILS_H_
+
+#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
+#include "mediapipe/framework/calculator_context.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/util/filtering/one_euro_filter.h"
+#include "mediapipe/util/filtering/relative_velocity_filter.h"
+
+namespace mediapipe {
+namespace landmarks_smoothing {
+
+void NormalizedLandmarksToLandmarks(
+    const mediapipe::NormalizedLandmarkList& norm_landmarks,
+    const int image_width, const int image_height,
+    mediapipe::LandmarkList& landmarks);
+
+void LandmarksToNormalizedLandmarks(
+    const mediapipe::LandmarkList& landmarks, const int image_width,
+    const int image_height, mediapipe::NormalizedLandmarkList& norm_landmarks);
+
+float GetObjectScale(const NormalizedRect& roi, const int image_width,
+                     const int image_height);
+
+float GetObjectScale(const Rect& roi);
+
+// Abstract class for various landmarks filters.
+class LandmarksFilter {
+ public:
+  virtual ~LandmarksFilter() = default;
+
+  virtual absl::Status Reset() { return absl::OkStatus(); }
+
+  virtual absl::Status Apply(const mediapipe::LandmarkList& in_landmarks,
+                             const absl::Duration& timestamp,
+                             const absl::optional<float> object_scale_opt,
+                             mediapipe::LandmarkList& out_landmarks) = 0;
+};
+
+absl::StatusOr<std::unique_ptr<LandmarksFilter>> InitializeLandmarksFilter(
+    const mediapipe::LandmarksSmoothingCalculatorOptions& options);
+
+class MultiLandmarkFilters {
+ public:
+  virtual ~MultiLandmarkFilters() = default;
+
+  virtual absl::StatusOr<LandmarksFilter*> GetOrCreate(
+      const int64_t tracking_id,
+      const mediapipe::LandmarksSmoothingCalculatorOptions& options);
+
+  virtual void ClearUnused(const std::vector<int64_t>& tracking_ids);
+
+  virtual void Clear();
+
+ private:
+  std::map<int64_t, std::unique_ptr<LandmarksFilter>> filters_;
+};
+
+}  // namespace landmarks_smoothing
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_SMOOTHING_CALCULATOR_UTILS_H_
diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator_utils_test.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator_utils_test.cc
new file mode 100644
index 000000000..2404beb1a
--- /dev/null
+++ b/mediapipe/calculators/util/landmarks_smoothing_calculator_utils_test.cc
@@ -0,0 +1,118 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h" + +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe { +namespace landmarks_smoothing { +namespace { + +TEST(LandmarksSmoothingCalculatorUtilsTest, NormalizedLandmarksToLandmarks) { + NormalizedLandmarkList norm_landmarks; + NormalizedLandmark* norm_landmark = norm_landmarks.add_landmark(); + norm_landmark->set_x(0.1); + norm_landmark->set_y(0.2); + norm_landmark->set_z(0.3); + norm_landmark->set_visibility(0.4); + norm_landmark->set_presence(0.5); + + LandmarkList landmarks; + NormalizedLandmarksToLandmarks(norm_landmarks, /*image_width=*/10, + /*image_height=*/10, landmarks); + + EXPECT_EQ(landmarks.landmark_size(), 1); + Landmark landmark = landmarks.landmark(0); + EXPECT_NEAR(landmark.x(), 1.0, 1e-6); + EXPECT_NEAR(landmark.y(), 2.0, 1e-6); + EXPECT_NEAR(landmark.z(), 3.0, 1e-6); + EXPECT_NEAR(landmark.visibility(), 0.4, 1e-6); + EXPECT_NEAR(landmark.presence(), 0.5, 1e-6); +} + +TEST(LandmarksSmoothingCalculatorUtilsTest, + NormalizedLandmarksToLandmarks_EmptyVisibilityAndPresence) { + NormalizedLandmarkList norm_landmarks; + NormalizedLandmark* norm_landmark = norm_landmarks.add_landmark(); + norm_landmark->set_x(0.1); + norm_landmark->set_y(0.2); + norm_landmark->set_z(0.3); + norm_landmark->clear_visibility(); + norm_landmark->clear_presence(); + + LandmarkList landmarks; + NormalizedLandmarksToLandmarks(norm_landmarks, /*image_width=*/10, + /*image_height=*/10, landmarks); + + EXPECT_EQ(landmarks.landmark_size(), 1); + Landmark landmark = landmarks.landmark(0); + EXPECT_NEAR(landmark.x(), 1.0, 1e-6); + EXPECT_NEAR(landmark.y(), 2.0, 1e-6); + EXPECT_NEAR(landmark.z(), 3.0, 1e-6); + EXPECT_FALSE(landmark.has_visibility()); + EXPECT_FALSE(landmark.has_presence()); +} + +TEST(LandmarksSmoothingCalculatorUtilsTest, LandmarksToNormalizedLandmarks) { + LandmarkList landmarks; + Landmark* landmark = landmarks.add_landmark(); + landmark->set_x(1.0); + landmark->set_y(2.0); + landmark->set_z(3.0); + landmark->set_visibility(0.4); + landmark->set_presence(0.5); + + NormalizedLandmarkList norm_landmarks; + LandmarksToNormalizedLandmarks(landmarks, /*image_width=*/10, + /*image_height=*/10, norm_landmarks); + + EXPECT_EQ(norm_landmarks.landmark_size(), 1); + NormalizedLandmark norm_landmark = norm_landmarks.landmark(0); + EXPECT_NEAR(norm_landmark.x(), 0.1, 1e-6); + EXPECT_NEAR(norm_landmark.y(), 0.2, 1e-6); + EXPECT_NEAR(norm_landmark.z(), 0.3, 1e-6); + EXPECT_NEAR(norm_landmark.visibility(), 0.4, 1e-6); + EXPECT_NEAR(norm_landmark.presence(), 0.5, 1e-6); +} + +TEST(LandmarksSmoothingCalculatorUtilsTest, + LandmarksToNormalizedLandmarks_EmptyVisibilityAndPresence) { + LandmarkList landmarks; + Landmark* landmark = landmarks.add_landmark(); + landmark->set_x(1.0); + landmark->set_y(2.0); + landmark->set_z(3.0); + landmark->clear_visibility(); + landmark->clear_presence(); + + NormalizedLandmarkList norm_landmarks; + LandmarksToNormalizedLandmarks(landmarks, /*image_width=*/10, + /*image_height=*/10, norm_landmarks); + + EXPECT_EQ(norm_landmarks.landmark_size(), 1); + NormalizedLandmark norm_landmark = norm_landmarks.landmark(0); + EXPECT_NEAR(norm_landmark.x(), 0.1, 1e-6); + EXPECT_NEAR(norm_landmark.y(), 0.2, 1e-6); + EXPECT_NEAR(norm_landmark.z(), 0.3, 1e-6); + EXPECT_FALSE(norm_landmark.has_visibility()); + 
EXPECT_FALSE(norm_landmark.has_presence()); +} + +} // namespace +} // namespace landmarks_smoothing +} // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index 263ef85c6..b0d4f4175 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -322,27 +322,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { options_.presence_threshold(), options_.connection_color(), thickness, /*normalized=*/false, render_data.get()); } - for (int i = 0; i < landmarks.landmark_size(); ++i) { - const Landmark& landmark = landmarks.landmark(i); + if (options_.render_landmarks()) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const Landmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibleAndPresent( - landmark, options_.utilize_visibility(), - options_.visibility_threshold(), options_.utilize_presence(), - options_.presence_threshold())) { - continue; - } + if (!IsLandmarkVisibleAndPresent( + landmark, options_.utilize_visibility(), + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold())) { + continue; + } - auto* landmark_data_render = AddPointRenderData( - options_.landmark_color(), thickness, render_data.get()); - if (visualize_depth) { - SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render, - options_.min_depth_circle_thickness(), - options_.max_depth_circle_thickness()); + auto* landmark_data_render = AddPointRenderData( + options_.landmark_color(), thickness, render_data.get()); + if (visualize_depth) { + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, + landmark_data_render, + options_.min_depth_circle_thickness(), + options_.max_depth_circle_thickness()); + } + auto* landmark_data = landmark_data_render->mutable_point(); + landmark_data->set_normalized(false); + landmark_data->set_x(landmark.x()); + landmark_data->set_y(landmark.y()); } - auto* landmark_data = landmark_data_render->mutable_point(); - landmark_data->set_normalized(false); - landmark_data->set_x(landmark.x()); - landmark_data->set_y(landmark.y()); } } @@ -368,27 +371,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { options_.presence_threshold(), options_.connection_color(), thickness, /*normalized=*/true, render_data.get()); } - for (int i = 0; i < landmarks.landmark_size(); ++i) { - const NormalizedLandmark& landmark = landmarks.landmark(i); + if (options_.render_landmarks()) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const NormalizedLandmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibleAndPresent( - landmark, options_.utilize_visibility(), - options_.visibility_threshold(), options_.utilize_presence(), - options_.presence_threshold())) { - continue; - } + if (!IsLandmarkVisibleAndPresent( + landmark, options_.utilize_visibility(), + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold())) { + continue; + } - auto* landmark_data_render = AddPointRenderData( - options_.landmark_color(), thickness, render_data.get()); - if (visualize_depth) { - SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render, - options_.min_depth_circle_thickness(), - options_.max_depth_circle_thickness()); + auto* landmark_data_render = AddPointRenderData( + options_.landmark_color(), thickness, render_data.get()); + 
if (visualize_depth) { + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, + landmark_data_render, + options_.min_depth_circle_thickness(), + options_.max_depth_circle_thickness()); + } + auto* landmark_data = landmark_data_render->mutable_point(); + landmark_data->set_normalized(true); + landmark_data->set_x(landmark.x()); + landmark_data->set_y(landmark.y()); } - auto* landmark_data = landmark_data_render->mutable_point(); - landmark_data->set_normalized(true); - landmark_data->set_x(landmark.x()); - landmark_data->set_y(landmark.y()); } } diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto index 990919540..67dca84ad 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto @@ -32,6 +32,10 @@ message LandmarksToRenderDataCalculatorOptions { // Color of the landmarks. optional Color landmark_color = 2; + + // Whether to render landmarks as points. + optional bool render_landmarks = 14 [default = true]; + // Color of the connections. optional Color connection_color = 3; diff --git a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc index a9bc51f66..d83ff67c0 100644 --- a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc @@ -15,6 +15,7 @@ #include #include +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/status.h" @@ -58,7 +59,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override { if (current_output_ < filenames_.size()) { auto contents = absl::make_unique(); - LOG(INFO) << filenames_[current_output_]; + ABSL_LOG(INFO) << filenames_[current_output_]; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( filenames_[current_output_], contents.get())); ++current_output_; diff --git a/mediapipe/calculators/util/logic_calculator.proto b/mediapipe/calculators/util/logic_calculator.proto index fe00a2d9b..8f9f6d856 100644 --- a/mediapipe/calculators/util/logic_calculator.proto +++ b/mediapipe/calculators/util/logic_calculator.proto @@ -18,6 +18,9 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +option java_package = "com.google.mediapipe.calculator.proto"; +option java_outer_classname = "LogicCalculatorOptionsProto"; + message LogicCalculatorOptions { extend CalculatorOptions { optional LogicCalculatorOptions ext = 338731246; diff --git a/mediapipe/calculators/util/multi_landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/multi_landmarks_smoothing_calculator.cc new file mode 100644 index 000000000..40098935c --- /dev/null +++ b/mediapipe/calculators/util/multi_landmarks_smoothing_calculator.cc @@ -0,0 +1,113 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/util/multi_landmarks_smoothing_calculator.h"
+
+#include
+#include
+#include
+#include
+
+#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
+#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/timestamp.h"
+
+namespace mediapipe {
+namespace api2 {
+
+namespace {
+
+using ::mediapipe::NormalizedRect;
+using ::mediapipe::landmarks_smoothing::GetObjectScale;
+using ::mediapipe::landmarks_smoothing::LandmarksToNormalizedLandmarks;
+using ::mediapipe::landmarks_smoothing::MultiLandmarkFilters;
+using ::mediapipe::landmarks_smoothing::NormalizedLandmarksToLandmarks;
+
+}  // namespace
+
+class MultiLandmarksSmoothingCalculatorImpl
+    : public NodeImpl<MultiLandmarksSmoothingCalculator> {
+ public:
+  absl::Status Process(CalculatorContext* cc) override {
+    // Check that landmarks are not empty and reset the filter if so.
+    // Don't emit an empty packet for this timestamp.
+    if (kInNormLandmarks(cc).IsEmpty()) {
+      multi_filters_.Clear();
+      return absl::OkStatus();
+    }
+
+    const auto& timestamp =
+        absl::Microseconds(cc->InputTimestamp().Microseconds());
+
+    const auto& tracking_ids = kTrackingIds(cc).Get();
+    multi_filters_.ClearUnused(tracking_ids);
+
+    const auto& in_norm_landmarks_vec = kInNormLandmarks(cc).Get();
+    RET_CHECK_EQ(in_norm_landmarks_vec.size(), tracking_ids.size());
+
+    int image_width;
+    int image_height;
+    std::tie(image_width, image_height) = kImageSize(cc).Get();
+
+    std::optional<std::vector<NormalizedRect>> object_scale_roi_vec;
+    if (kObjectScaleRoi(cc).IsConnected() && !kObjectScaleRoi(cc).IsEmpty()) {
+      object_scale_roi_vec = kObjectScaleRoi(cc).Get();
+      RET_CHECK_EQ(object_scale_roi_vec.value().size(), tracking_ids.size());
+    }
+
+    std::vector<NormalizedLandmarkList> out_norm_landmarks_vec;
+    for (int i = 0; i < tracking_ids.size(); ++i) {
+      LandmarkList in_landmarks;
+      NormalizedLandmarksToLandmarks(in_norm_landmarks_vec[i], image_width,
+                                     image_height, in_landmarks);
+
+      std::optional<float> object_scale;
+      if (object_scale_roi_vec) {
+        object_scale = GetObjectScale(object_scale_roi_vec.value()[i],
+                                      image_width, image_height);
+      }
+
+      ASSIGN_OR_RETURN(
+          auto* landmarks_filter,
+          multi_filters_.GetOrCreate(
+              tracking_ids[i],
+              cc->Options<mediapipe::LandmarksSmoothingCalculatorOptions>()));
+
+      LandmarkList out_landmarks;
+      MP_RETURN_IF_ERROR(landmarks_filter->Apply(in_landmarks, timestamp,
+                                                 object_scale, out_landmarks));
+
+      NormalizedLandmarkList out_norm_landmarks;
+      LandmarksToNormalizedLandmarks(out_landmarks, image_width, image_height,
+                                     out_norm_landmarks);
+
+      out_norm_landmarks_vec.push_back(std::move(out_norm_landmarks));
+    }
+
+    kOutNormLandmarks(cc).Send(std::move(out_norm_landmarks_vec));
+
+    return absl::OkStatus();
+  }
+
+ private:
+  MultiLandmarkFilters multi_filters_;
+};
+MEDIAPIPE_NODE_IMPLEMENTATION(MultiLandmarksSmoothingCalculatorImpl);
+
+}  // namespace api2
+}  // namespace mediapipe
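The per-track state handled by MultiLandmarkFilters is what distinguishes this multi-instance variant: one LandmarksFilter per tracking ID, created lazily and dropped as soon as the ID disappears. A sketch of that lifecycle in isolation (the IDs and the function name are illustrative; status macros as elsewhere in this patch):

#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h"

absl::Status TrackLifecycleDemo(
    const mediapipe::LandmarksSmoothingCalculatorOptions& options) {
  mediapipe::landmarks_smoothing::MultiLandmarkFilters filters;

  // Frame 1: tracks 7 and 9 are visible; filters are created lazily and
  // would be reused on later frames.
  ASSIGN_OR_RETURN(auto* filter_7, filters.GetOrCreate(7, options));
  ASSIGN_OR_RETURN(auto* filter_9, filters.GetOrCreate(9, options));
  (void)filter_7;
  (void)filter_9;

  // Frame 2: track 9 disappeared, so its smoothing state is forgotten.
  filters.ClearUnused({7});

  // An empty input packet clears everything, as in Process() above.
  filters.Clear();
  return absl::OkStatus();
}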
diff --git a/mediapipe/calculators/util/multi_landmarks_smoothing_calculator.h b/mediapipe/calculators/util/multi_landmarks_smoothing_calculator.h
new file mode 100644
index 000000000..6c834ef56
--- /dev/null
+++ b/mediapipe/calculators/util/multi_landmarks_smoothing_calculator.h
@@ -0,0 +1,81 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_UTIL_MULTI_LANDMARKS_SMOOTHING_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_UTIL_MULTI_LANDMARKS_SMOOTHING_CALCULATOR_H_
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+
+// A calculator to smooth landmarks over time.
+//
+// Inputs:
+//   NORM_LANDMARKS: A std::vector<NormalizedLandmarkList> of landmarks you
+//     want to smooth.
+//   TRACKING_IDS: A std::vector<int64_t> of tracking IDs used to associate
+//     landmarks over time. When a new ID arrives, the calculator initializes
+//     a new filter. When a tracking ID is no longer provided, the calculator
+//     forgets its smoothing state.
+//   IMAGE_SIZE: A std::pair<int, int> representation of image width and
+//     height. Required to perform all computations in absolute coordinates to
+//     avoid any influence of normalized values.
+//   OBJECT_SCALE_ROI (optional): A std::vector<NormalizedRect> used to
+//     determine the object scale for some of the filters. If not provided,
+//     the object scale will be calculated from the landmarks.
+//
+// Outputs:
+//   NORM_FILTERED_LANDMARKS: A std::vector<NormalizedLandmarkList> of
+//     smoothed landmarks.
+//
+// Example config:
+//   node {
+//     calculator: "MultiLandmarksSmoothingCalculator"
+//     input_stream: "NORM_LANDMARKS:pose_landmarks"
+//     input_stream: "TRACKING_IDS:tracking_ids"
+//     input_stream: "IMAGE_SIZE:image_size"
+//     input_stream: "OBJECT_SCALE_ROI:roi"
+//     output_stream: "NORM_FILTERED_LANDMARKS:pose_landmarks_filtered"
+//     options: {
+//       [mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
+//         velocity_filter: {
+//           window_size: 5
+//           velocity_scale: 10.0
+//         }
+//       }
+//     }
+//   }
+//
+class MultiLandmarksSmoothingCalculator : public NodeIntf {
+ public:
+  static constexpr Input<std::vector<NormalizedLandmarkList>>
+      kInNormLandmarks{"NORM_LANDMARKS"};
+  static constexpr Input<std::vector<int64_t>> kTrackingIds{"TRACKING_IDS"};
+  static constexpr Input<std::pair<int, int>> kImageSize{"IMAGE_SIZE"};
+  static constexpr Input<std::vector<NormalizedRect>>::Optional
+      kObjectScaleRoi{"OBJECT_SCALE_ROI"};
+  static constexpr Output<std::vector<NormalizedLandmarkList>>
+      kOutNormLandmarks{"NORM_FILTERED_LANDMARKS"};
+
+  MEDIAPIPE_NODE_INTERFACE(MultiLandmarksSmoothingCalculator, kInNormLandmarks,
+                           kTrackingIds, kImageSize, kObjectScaleRoi,
+                           kOutNormLandmarks);
+};
+
+}  // namespace api2
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_UTIL_MULTI_LANDMARKS_SMOOTHING_CALCULATOR_H_
diff --git a/mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.cc
new file mode 100644
index 000000000..ddc16d296
--- /dev/null
+++ b/mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.cc
@@ -0,0 +1,100 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.h"
+
+#include
+#include
+#include
+#include
+
+#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
+#include "mediapipe/calculators/util/landmarks_smoothing_calculator_utils.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/timestamp.h"
+
+namespace mediapipe {
+namespace api2 {
+
+namespace {
+
+using ::mediapipe::Rect;
+using ::mediapipe::landmarks_smoothing::GetObjectScale;
+using ::mediapipe::landmarks_smoothing::MultiLandmarkFilters;
+
+}  // namespace
+
+class MultiWorldLandmarksSmoothingCalculatorImpl
+    : public NodeImpl<MultiWorldLandmarksSmoothingCalculator> {
+ public:
+  absl::Status Process(CalculatorContext* cc) override {
+    // Check that landmarks are not empty and reset the filter if so.
+    // Don't emit an empty packet for this timestamp.
+    if (kInLandmarks(cc).IsEmpty()) {
+      multi_filters_.Clear();
+      return absl::OkStatus();
+    }
+
+    const auto& timestamp =
+        absl::Microseconds(cc->InputTimestamp().Microseconds());
+
+    const auto& tracking_ids = kTrackingIds(cc).Get();
+    multi_filters_.ClearUnused(tracking_ids);
+
+    const auto& in_landmarks_vec = kInLandmarks(cc).Get();
+    RET_CHECK_EQ(in_landmarks_vec.size(), tracking_ids.size());
+
+    std::optional<std::vector<Rect>> object_scale_roi_vec;
+    if (kObjectScaleRoi(cc).IsConnected() && !kObjectScaleRoi(cc).IsEmpty()) {
+      object_scale_roi_vec = kObjectScaleRoi(cc).Get();
+      RET_CHECK_EQ(object_scale_roi_vec.value().size(), tracking_ids.size());
+    }
+
+    std::vector<LandmarkList> out_landmarks_vec;
+    for (int i = 0; i < tracking_ids.size(); ++i) {
+      const auto& in_landmarks = in_landmarks_vec[i];
+
+      std::optional<float> object_scale;
+      if (object_scale_roi_vec) {
+        object_scale = GetObjectScale(object_scale_roi_vec.value()[i]);
+      }
+
+      ASSIGN_OR_RETURN(
+          auto* landmarks_filter,
+          multi_filters_.GetOrCreate(
+              tracking_ids[i],
+              cc->Options<mediapipe::LandmarksSmoothingCalculatorOptions>()));
+
+      LandmarkList out_landmarks;
+      MP_RETURN_IF_ERROR(landmarks_filter->Apply(in_landmarks, timestamp,
+                                                 object_scale, out_landmarks));
+
+      out_landmarks_vec.push_back(std::move(out_landmarks));
+    }
+
+    kOutLandmarks(cc).Send(std::move(out_landmarks_vec));
+
+    return absl::OkStatus();
+  }
+
+ private:
+  MultiLandmarkFilters multi_filters_;
+};
+MEDIAPIPE_NODE_IMPLEMENTATION(MultiWorldLandmarksSmoothingCalculatorImpl);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.h b/mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.h
new file mode 100644
index 000000000..2c54ae53c
--- /dev/null
+++ b/mediapipe/calculators/util/multi_world_landmarks_smoothing_calculator.h
@@ -0,0 +1,74 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_UTIL_MULTI_WORLD_LANDMARKS_SMOOTHING_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_UTIL_MULTI_WORLD_LANDMARKS_SMOOTHING_CALCULATOR_H_
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+
+// A calculator to smooth landmarks over time.
+//
+// Inputs:
+//   LANDMARKS: A std::vector<LandmarkList> of landmarks you want to smooth.
+//   TRACKING_IDS: A std::vector<int64_t> of tracking IDs used to associate
+//     landmarks over time. When a new ID arrives, the calculator initializes
+//     a new filter. When a tracking ID is no longer provided, the calculator
+//     forgets its smoothing state.
+//   OBJECT_SCALE_ROI (optional): A std::vector<Rect> used to determine the
+//     object scale for some of the filters. If not provided, the object scale
+//     will be calculated from the landmarks.
+//
+// Outputs:
+//   FILTERED_LANDMARKS: A std::vector<LandmarkList> of smoothed landmarks.
+//
+// Example config:
+//   node {
+//     calculator: "MultiWorldLandmarksSmoothingCalculator"
+//     input_stream: "LANDMARKS:landmarks"
+//     input_stream: "TRACKING_IDS:tracking_ids"
+//     input_stream: "OBJECT_SCALE_ROI:roi"
+//     output_stream: "FILTERED_LANDMARKS:landmarks_filtered"
+//     options: {
+//       [mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
+//         velocity_filter: {
+//           window_size: 5
+//           velocity_scale: 10.0
+//         }
+//       }
+//     }
+//   }
+//
+class MultiWorldLandmarksSmoothingCalculator : public NodeIntf {
+ public:
+  static constexpr Input<std::vector<LandmarkList>> kInLandmarks{
+      "LANDMARKS"};
+  static constexpr Input<std::vector<int64_t>> kTrackingIds{"TRACKING_IDS"};
+  static constexpr Input<std::vector<Rect>>::Optional kObjectScaleRoi{
+      "OBJECT_SCALE_ROI"};
+  static constexpr Output<std::vector<LandmarkList>> kOutLandmarks{
+      "FILTERED_LANDMARKS"};
+
+  MEDIAPIPE_NODE_INTERFACE(MultiWorldLandmarksSmoothingCalculator,
+                           kInLandmarks, kTrackingIds, kObjectScaleRoi,
+                           kOutLandmarks);
+};
+
+}  // namespace api2
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_UTIL_MULTI_WORLD_LANDMARKS_SMOOTHING_CALCULATOR_H_
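The world-landmark variant above mirrors MultiLandmarksSmoothingCalculator but operates on metric LandmarkList data, so there is no IMAGE_SIZE input and the optional ROIs are plain Rects. A wiring sketch with illustrative stream names and filter values:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Smooths per-track world landmarks; TRACKING_IDS keeps one filter per object.
mediapipe::CalculatorGraphConfig::Node node =
    mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
        R"pb(
          calculator: "MultiWorldLandmarksSmoothingCalculator"
          input_stream: "LANDMARKS:multi_world_landmarks"
          input_stream: "TRACKING_IDS:tracking_ids"
          output_stream: "FILTERED_LANDMARKS:multi_world_landmarks_filtered"
          options {
            [mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
              one_euro_filter { min_cutoff: 0.1 beta: 40.0 }
            }
          }
        )pb");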
diff --git a/mediapipe/calculators/util/non_max_suppression_calculator.cc b/mediapipe/calculators/util/non_max_suppression_calculator.cc
index 535e2a719..be3a8da73 100644
--- a/mediapipe/calculators/util/non_max_suppression_calculator.cc
+++ b/mediapipe/calculators/util/non_max_suppression_calculator.cc
@@ -18,12 +18,13 @@
 #include
 #include
 
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
 #include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/location.h"
-#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/rectangle.h"
 #include "mediapipe/framework/port/status.h"
 
@@ -47,8 +48,8 @@ bool RetainMaxScoringLabelOnly(Detection* detection) {
   if (detection->label_id_size() == 0 && detection->label_size() == 0) {
     return false;
   }
-  CHECK(detection->label_id_size() == detection->score_size() ||
-        detection->label_size() == detection->score_size())
+  ABSL_CHECK(detection->label_id_size() == detection->score_size() ||
+             detection->label_size() == detection->score_size())
       << "Number of scores must be equal to number of detections.";
 
   std::vector<std::pair<int, float>> indexed_scores;
@@ -92,7 +93,7 @@ float OverlapSimilarity(
       normalization = rect1.Area() + rect2.Area() - intersection_area;
       break;
     default:
-      LOG(FATAL) << "Unrecognized overlap type: " << overlap_type;
+      ABSL_LOG(FATAL) << "Unrecognized overlap type: " << overlap_type;
   }
   return normalization > 0.0f ? intersection_area / normalization : 0.0f;
 }
@@ -171,9 +172,9 @@ class NonMaxSuppressionCalculator : public CalculatorBase {
     cc->SetOffset(TimestampDiff(0));
 
     options_ = cc->Options<NonMaxSuppressionCalculatorOptions>();
-    CHECK_GT(options_.num_detection_streams(), 0)
+    ABSL_CHECK_GT(options_.num_detection_streams(), 0)
         << "At least one detection stream needs to be specified.";
-    CHECK_NE(options_.max_num_detections(), 0)
+    ABSL_CHECK_NE(options_.max_num_detections(), 0)
         << "max_num_detections=0 is not a valid value. Please choose a "
        << "positive number if you want to limit the number of output "
        << "detections, or set -1 if you do not want any limit.";
diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc
index 6509f016f..39c98bdd0 100644
--- a/mediapipe/calculators/util/packet_latency_calculator.cc
+++ b/mediapipe/calculators/util/packet_latency_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "absl/log/absl_log.h"
 #include "absl/strings/str_cat.h"
 #include "absl/time/time.h"
 #include "mediapipe/calculators/util/latency.pb.h"
@@ -20,7 +21,6 @@
 #include "mediapipe/framework/calculator_options.pb.h"
 #include "mediapipe/framework/deps/clock.h"
 #include "mediapipe/framework/deps/monotonic_clock.h"
-#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/timestamp.h"
@@ -237,7 +237,7 @@ absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) {
   }
 
   if (first_process_time_usec_ < 0) {
-    LOG(WARNING) << "No reference packet received.";
+    ABSL_LOG(WARNING) << "No reference packet received.";
     return absl::OkStatus();
   }
 
diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc
index bbc08255e..002471cab 100644
--- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc
+++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "absl/log/absl_check.h" #include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/rect.pb.h" @@ -41,8 +42,8 @@ RenderAnnotation::Rectangle* NewRect( annotation->set_thickness(options.thickness()); if (options.has_top_left_thickness()) { - CHECK(!options.oval()); - CHECK(!options.filled()); + ABSL_CHECK(!options.oval()); + ABSL_CHECK(!options.filled()); annotation->mutable_rectangle()->set_top_left_thickness( options.top_left_thickness()); } diff --git a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc index 59b21d574..30dc11dbe 100644 --- a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc +++ b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc @@ -124,7 +124,7 @@ absl::StatusOr RefineLandmarksFromHeatMap( int center_row = out_lms.landmark(lm_index).y() * hm_height; // Point is outside of the image let's keep it intact. if (center_col < 0 || center_col >= hm_width || center_row < 0 || - center_col >= hm_height) { + center_row >= hm_height) { continue; } diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 7245b13c2..f17747d28 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -130,9 +130,9 @@ cc_library( "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_video", - "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -154,6 +154,7 @@ cc_library( "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -169,6 +170,7 @@ cc_library( "//mediapipe/framework/formats/motion:optical_flow_field", "//mediapipe/framework/port:opencv_video", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -194,6 +196,8 @@ cc_library( "//mediapipe/util/tracking:motion_estimation", "//mediapipe/util/tracking:motion_models", "//mediapipe/util/tracking:region_flow_cc_proto", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -206,10 +210,11 @@ cc_library( ":flow_packager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/util/tracking:camera_motion_cc_proto", "//mediapipe/util/tracking:flow_packager", "//mediapipe/util/tracking:region_flow_cc_proto", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -226,7 +231,6 @@ cc_library( "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:video_stream_header", # fixdeps: keep -- required for exobazel build. 
"//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -237,6 +241,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -252,7 +258,6 @@ cc_library( "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:video_stream_header", # fixdeps: keep -- required for exobazel build. "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_features2d", "//mediapipe/framework/port:ret_check", @@ -264,6 +269,8 @@ cc_library( "//mediapipe/util/tracking:box_tracker_cc_proto", "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util/tracking:tracking_visualization_utilities", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ] + select({ @@ -341,7 +348,6 @@ cc_test( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:test_util", - "@com_google_absl//absl/flags:flag", ], ) @@ -361,13 +367,12 @@ cc_test( "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:test_util", - "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_log", ], ) @@ -451,7 +456,8 @@ cc_test( "//mediapipe/framework/tool:test_util", "//mediapipe/util/tracking:box_tracker_cc_proto", "//mediapipe/util/tracking:tracking_cc_proto", - "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/mediapipe/calculators/video/box_detector_calculator.cc b/mediapipe/calculators/video/box_detector_calculator.cc index 14ac12e5e..51f57b7eb 100644 --- a/mediapipe/calculators/video/box_detector_calculator.cc +++ b/mediapipe/calculators/video/box_detector_calculator.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "mediapipe/calculators/video/box_detector_calculator.pb.h" @@ -25,7 +27,6 @@ #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_features2d_inc.h" #include "mediapipe/framework/port/ret_check.h" @@ -198,7 +199,8 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { if (!predefined_index.ParseFromString(cc->InputSidePackets() .Tag(kIndexProtoStringTag) .Get())) { - LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING"; + ABSL_LOG(FATAL) + << "failed to parse BoxDetectorIndex from 
INDEX_PROTO_STRING"; } box_detector_->AddBoxDetectorIndex(predefined_index); } @@ -210,7 +212,7 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(file::GetContents(string_path, &index_string)); BoxDetectorIndex predefined_index; if (!predefined_index.ParseFromString(index_string)) { - LOG(FATAL) + ABSL_LOG(FATAL) << "failed to parse BoxDetectorIndex from index_proto_filename"; } box_detector_->AddBoxDetectorIndex(predefined_index); @@ -248,7 +250,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { BoxDetectorIndex predefined_index; if (!predefined_index.ParseFromString( add_index_stream->Get())) { - LOG(FATAL) << "failed to parse BoxDetectorIndex from ADD_INDEX"; + ABSL_LOG(FATAL) << "failed to parse BoxDetectorIndex from ADD_INDEX"; } box_detector_->AddBoxDetectorIndex(predefined_index); } @@ -276,8 +278,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { ? &(cc->Inputs().Tag(kDescriptorsTag)) : nullptr; - CHECK(track_stream != nullptr || video_stream != nullptr || - (feature_stream != nullptr && descriptor_stream != nullptr)) + ABSL_CHECK(track_stream != nullptr || video_stream != nullptr || + (feature_stream != nullptr && descriptor_stream != nullptr)) << "One and only one of {tracking_data, input image frame, " "feature/descriptor} need to be valid."; @@ -295,7 +297,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { const TrackingData& tracking_data = track_stream->Get(); - CHECK(tracked_boxes_stream != nullptr) << "tracked_boxes needed."; + ABSL_CHECK(tracked_boxes_stream != nullptr) << "tracked_boxes needed."; const TimedBoxProtoList tracked_boxes = tracked_boxes_stream->Get(); @@ -359,7 +361,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { const auto& descriptors = descriptor_stream->Get>(); const int dims = options_.detector_options().descriptor_dims(); - CHECK_GE(descriptors.size(), feature_size * dims); + ABSL_CHECK_GE(descriptors.size(), feature_size * dims); cv::Mat descriptors_mat(feature_size, dims, CV_32F); for (int j = 0; j < feature_size; ++j) { features_vec[j].Set(features[j].pt.x * inv_scale, diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index b5f3b5b0b..4a8f4543d 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -22,6 +22,8 @@ #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "absl/container/node_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/numbers.h" #include "mediapipe/calculators/video/box_tracker_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -29,7 +31,6 @@ #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/video_stream_header.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -315,16 +316,16 @@ void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom, float in_right, int rotation, float* out_top, float* out_left, float* out_bottom, float* out_right) { - CHECK(out_top != nullptr); - CHECK(out_left != nullptr); - CHECK(out_bottom != nullptr); - CHECK(out_right != nullptr); + ABSL_CHECK(out_top 
!= nullptr); + ABSL_CHECK(out_left != nullptr); + ABSL_CHECK(out_bottom != nullptr); + ABSL_CHECK(out_right != nullptr); const float in_center_x = (in_left + in_right) * 0.5f; const float in_center_y = (in_top + in_bottom) * 0.5f; const float in_width = in_right - in_left; const float in_height = in_bottom - in_top; - CHECK_GT(in_width, 0); - CHECK_GT(in_height, 0); + ABSL_CHECK_GT(in_width, 0); + ABSL_CHECK_GT(in_height, 0); float out_center_x; float out_center_y; float out_width; @@ -358,7 +359,7 @@ void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom, out_height = in_width; break; default: - LOG(ERROR) << "invalid rotation " << rotation; + ABSL_LOG(ERROR) << "invalid rotation " << rotation; out_center_x = in_center_x; out_center_y = in_center_y; out_width = in_width; @@ -373,7 +374,7 @@ void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom, void AddStateToPath(const MotionBoxState& state, int64_t time_msec, PathSegment* path) { - CHECK(path); + ABSL_CHECK(path); TimedBox result; TimedBoxFromMotionBoxState(state, &result); result.time_msec = time_msec; @@ -384,7 +385,8 @@ void AddStateToPath(const MotionBoxState& state, int64_t time_msec, path->insert(insert_pos, InternalTimedBox(result, new MotionBoxState(state))); } else { - LOG(ERROR) << "Box at time " << time_msec << " already present; ignoring"; + ABSL_LOG(ERROR) << "Box at time " << time_msec + << " already present; ignoring"; } } @@ -486,8 +488,9 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { #if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__) if (cc->InputSidePackets().HasTag(kInitialPosTag)) { - LOG(INFO) << "Parsing: " - << cc->InputSidePackets().Tag(kInitialPosTag).Get(); + ABSL_LOG(INFO) + << "Parsing: " + << cc->InputSidePackets().Tag(kInitialPosTag).Get(); initial_pos_ = ParseTextProtoOrDie( cc->InputSidePackets().Tag(kInitialPosTag).Get()); } @@ -624,7 +627,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { const int cancel_object_id = cancel_object_id_stream->Get(); if (streaming_motion_boxes_.erase(cancel_object_id) == 0) { - LOG(WARNING) << "box id " << cancel_object_id << " does not exist."; + ABSL_LOG(WARNING) << "box id " << cancel_object_id << " does not exist."; } } @@ -649,7 +652,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { // present at this frame. TimedBoxProtoList box_track_list; - CHECK(box_tracker_ || track_stream) + ABSL_CHECK(box_tracker_ || track_stream) << "Expected either batch or streaming mode"; // Corresponding list of box states for rendering. For each id present at @@ -944,7 +947,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( const bool forward_track = start.time_msec() < end_time_msec; if (track_timestamps_.empty()) { - LOG(WARNING) << "No tracking data cached yet."; + ABSL_LOG(WARNING) << "No tracking data cached yet."; continue; } @@ -954,27 +957,27 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( const int64_t tracking_end_timestamp_msec = track_timestamps_.back().Microseconds() / 1000; if (start.time_msec() < tracking_start_timestamp_msec) { - LOG(WARNING) << "Request start timestamp " << start.time_msec() - << " too old. First frame in the window: " - << tracking_start_timestamp_msec; + ABSL_LOG(WARNING) << "Request start timestamp " << start.time_msec() + << " too old. 
First frame in the window: " + << tracking_start_timestamp_msec; continue; } if (start.time_msec() > tracking_end_timestamp_msec) { - LOG(WARNING) << "Request start timestamp " << start.time_msec() - << " too new. Last frame in the window: " - << tracking_end_timestamp_msec; + ABSL_LOG(WARNING) << "Request start timestamp " << start.time_msec() + << " too new. Last frame in the window: " + << tracking_end_timestamp_msec; continue; } if (end_time_msec < tracking_start_timestamp_msec) { - LOG(WARNING) << "Request end timestamp " << end_time_msec - << " too old. First frame in the window: " - << tracking_start_timestamp_msec; + ABSL_LOG(WARNING) << "Request end timestamp " << end_time_msec + << " too old. First frame in the window: " + << tracking_start_timestamp_msec; continue; } if (end_time_msec > tracking_end_timestamp_msec) { - LOG(WARNING) << "Request end timestamp " << end_time_msec - << " too new. Last frame in the window: " - << tracking_end_timestamp_msec; + ABSL_LOG(WARNING) << "Request end timestamp " << end_time_msec + << " too new. Last frame in the window: " + << tracking_end_timestamp_msec; continue; } @@ -982,7 +985,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( GetRandomAccessTimestampPos(start, forward_track); if (timestamp_pos == track_timestamps_.end()) { - LOG(ERROR) << "Random access outside cached range"; + ABSL_LOG(ERROR) << "Random access outside cached range"; continue; } @@ -993,13 +996,13 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( // TODO: Interpolate random access tracking start_data instead // of dropping the request in the case of missing processed frame. if (start_data == tracking_data_cache_.end()) { - LOG(ERROR) << "Random access starts at unprocessed frame."; + ABSL_LOG(ERROR) << "Random access starts at unprocessed frame."; continue; } const int init_frame = timestamp_pos - track_timestamps_.begin() + track_timestamps_base_index_; - CHECK_GE(init_frame, 0); + ABSL_CHECK_GE(init_frame, 0); MotionBoxMap single_map = PrepareRandomAccessTrack(start, init_frame, forward_track, start_data); @@ -1010,7 +1013,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack( &single_map, end_time_msec); if (track_error) { - LOG(ERROR) << "Could not track box."; + ABSL_LOG(ERROR) << "Could not track box."; continue; } @@ -1166,8 +1169,8 @@ void BoxTrackerCalculator::StreamTrack(const TrackingData& data, int64_t duration_ms, bool forward, MotionBoxMap* box_map, std::vector* failed_ids) { - CHECK(box_map); - CHECK(failed_ids); + ABSL_CHECK(box_map); + ABSL_CHECK(failed_ids); // Cache the actively discarded tracked ids from the new tracking data. for (const int discarded_id : @@ -1197,7 +1200,7 @@ void BoxTrackerCalculator::StreamTrack(const TrackingData& data, if (!motion_box.second.box.TrackStep(from_frame, // from frame. mvf, forward)) { failed_ids->push_back(motion_box.first); - LOG(INFO) << "lost track. pushed failed id: " << motion_box.first; + ABSL_LOG(INFO) << "lost track. pushed failed id: " << motion_box.first; } else { // Store result. 
PathSegment& path = motion_box.second.path; @@ -1224,8 +1227,8 @@ void BoxTrackerCalculator::FastForwardStartPos( track_timestamps_.end(), timestamp); if (timestamp_pos == track_timestamps_.end()) { - LOG(WARNING) << "Received start pos beyond current timestamp, " - << "Starting to track once frame arrives."; + ABSL_LOG(WARNING) << "Received start pos beyond current timestamp, " + << "Starting to track once frame arrives."; *initial_pos_.add_box() = start_pos; continue; } @@ -1233,7 +1236,7 @@ void BoxTrackerCalculator::FastForwardStartPos( // Start at previous frame. const int init_frame = timestamp_pos - track_timestamps_.begin() + track_timestamps_base_index_; - CHECK_GE(init_frame, 0); + ABSL_CHECK_GE(init_frame, 0); // Locate corresponding tracking data. auto start_data = std::find_if( @@ -1242,8 +1245,9 @@ void BoxTrackerCalculator::FastForwardStartPos( -> bool { return item.first == timestamp_pos[0]; }); if (start_data == tracking_data_cache_.end()) { - LOG(ERROR) << "Box to fast forward outside tracking data cache. Ignoring." - << " To avoid this error consider increasing the cache size."; + ABSL_LOG(ERROR) + << "Box to fast forward outside tracking data cache. Ignoring." + << " To avoid this error consider increasing the cache size."; continue; } @@ -1281,7 +1285,8 @@ void BoxTrackerCalculator::FastForwardStartPos( true, // forward &single_map, &failed_box); if (!failed_box.empty()) { - LOG(WARNING) << "Unable to fast forward box at frame " << curr_frame; + ABSL_LOG(WARNING) << "Unable to fast forward box at frame " + << curr_frame; track_error = true; break; } diff --git a/mediapipe/calculators/video/flow_packager_calculator.cc b/mediapipe/calculators/video/flow_packager_calculator.cc index 2965cd8e6..b04534999 100644 --- a/mediapipe/calculators/video/flow_packager_calculator.cc +++ b/mediapipe/calculators/video/flow_packager_calculator.cc @@ -17,12 +17,13 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "mediapipe/calculators/video/flow_packager_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/util/tracking/camera_motion.pb.h" #include "mediapipe/util/tracking/flow_packager.h" #include "mediapipe/util/tracking/region_flow.pb.h" @@ -160,7 +161,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { timestamp.Value() / 1000 / options_.caching_chunk_size_msec(); tracking_chunk_.set_first_chunk(true); } - CHECK_GE(chunk_idx_, 0); + ABSL_CHECK_GE(chunk_idx_, 0); TrackingDataChunk::Item* item = tracking_chunk_.add_item(); item->set_frame_idx(frame_idx_); @@ -227,10 +228,11 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { void FlowPackagerCalculator::WriteChunk(const TrackingDataChunk& chunk) const { if (chunk.item_size() == 0) { - LOG(ERROR) << "Write chunk called with empty tracking data." - << "This can only occur if the spacing between frames " - << "is larger than the requested chunk size. Try increasing " - << "the chunk size"; + ABSL_LOG(ERROR) + << "Write chunk called with empty tracking data." + << "This can only occur if the spacing between frames " + << "is larger than the requested chunk size. 
Try increasing " + << "the chunk size"; return; } @@ -242,7 +244,7 @@ void FlowPackagerCalculator::WriteChunk(const TrackingDataChunk& chunk) const { chunk_file = cache_dir_ + "/" + absl::StrFormat(*format_runtime, chunk_idx_); } else { - LOG(ERROR) << "chache_file_format wrong. fall back to chunk_%04d."; + ABSL_LOG(ERROR) << "chache_file_format wrong. fall back to chunk_%04d."; chunk_file = cache_dir_ + "/" + absl::StrFormat("chunk_%04d", chunk_idx_); } @@ -252,23 +254,23 @@ void FlowPackagerCalculator::WriteChunk(const TrackingDataChunk& chunk) const { const char* temp_filename = tempnam(cache_dir_.c_str(), nullptr); std::ofstream out_file(temp_filename); if (!out_file) { - LOG(ERROR) << "Could not open " << temp_filename; + ABSL_LOG(ERROR) << "Could not open " << temp_filename; } else { out_file.write(data.data(), data.size()); } if (rename(temp_filename, chunk_file.c_str()) != 0) { - LOG(ERROR) << "Failed to rename to " << chunk_file; + ABSL_LOG(ERROR) << "Failed to rename to " << chunk_file; } - LOG(INFO) << "Wrote chunk : " << chunk_file; + ABSL_LOG(INFO) << "Wrote chunk : " << chunk_file; } void FlowPackagerCalculator::PrepareCurrentForNextChunk( TrackingDataChunk* chunk) { - CHECK(chunk); + ABSL_CHECK(chunk); if (chunk->item_size() == 0) { - LOG(ERROR) << "Called with empty chunk. Unexpected."; + ABSL_LOG(ERROR) << "Called with empty chunk. Unexpected."; return; } diff --git a/mediapipe/calculators/video/motion_analysis_calculator.cc b/mediapipe/calculators/video/motion_analysis_calculator.cc index 544439ae8..601b8b045 100644 --- a/mediapipe/calculators/video/motion_analysis_calculator.cc +++ b/mediapipe/calculators/video/motion_analysis_calculator.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" @@ -348,8 +350,8 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { video_header = &(cc->Inputs().Tag(kSelectionTag).Header().Get()); } else { - LOG(WARNING) << "No input video header found. Downstream calculators " - "expecting video headers are likely to fail."; + ABSL_LOG(WARNING) << "No input video header found. Downstream calculators " + "expecting video headers are likely to fail."; } with_saliency_ = options_.analysis_options().compute_motion_saliency(); @@ -357,9 +359,9 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { if (cc->Outputs().HasTag(kSaliencyTag)) { with_saliency_ = true; if (!options_.analysis_options().compute_motion_saliency()) { - LOG(WARNING) << "Enable saliency computation. Set " - << "compute_motion_saliency to true to silence this " - << "warning."; + ABSL_LOG(WARNING) << "Enable saliency computation. Set " + << "compute_motion_saliency to true to silence this " + << "warning."; options_.mutable_analysis_options()->set_compute_motion_saliency(true); } } @@ -428,7 +430,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr; // Checked on Open. - CHECK(video_stream || selection_stream); + ABSL_CHECK(video_stream || selection_stream); // Lazy init. if (frame_width_ < 0 || frame_height_ < 0) { @@ -472,7 +474,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { // Always use frame if selection is not activated. 
bool use_frame = !selection_input_; if (selection_input_) { - CHECK(selection_stream); + ABSL_CHECK(selection_stream); // Fill in timestamps we process. if (!selection_stream->Value().IsEmpty()) { @@ -603,8 +605,8 @@ absl::Status MotionAnalysisCalculator::Close(CalculatorContext* cc) { } if (csv_file_input_) { if (!meta_motions_.empty()) { - LOG(ERROR) << "More motions than frames. Unexpected! Remainder: " - << meta_motions_.size(); + ABSL_LOG(ERROR) << "More motions than frames. Unexpected! Remainder: " + << meta_motions_.size(); } } return absl::OkStatus(); @@ -620,7 +622,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( const int num_results = motion_analysis_->GetResults( flush, &features, &camera_motions, with_saliency_ ? &saliency : nullptr); - CHECK_LE(num_results, buffer_size); + ABSL_CHECK_LE(num_results, buffer_size); if (num_results == 0) { return; @@ -695,7 +697,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( if (hybrid_meta_analysis_) { hybrid_meta_offset_ -= num_results; - CHECK_GE(hybrid_meta_offset_, 0); + ABSL_CHECK_GE(hybrid_meta_offset_, 0); } timestamp_buffer_.erase(timestamp_buffer_.begin(), @@ -741,8 +743,8 @@ absl::Status MotionAnalysisCalculator::InitOnProcess( } if (region_options->image_format() != image_format && region_options->image_format() != image_format2) { - LOG(WARNING) << "Requested image format in RegionFlowComputation " - << "does not match video stream format. Overriding."; + ABSL_LOG(WARNING) << "Requested image format in RegionFlowComputation " + << "does not match video stream format. Overriding."; region_options->set_image_format(image_format); } @@ -761,12 +763,12 @@ absl::Status MotionAnalysisCalculator::InitOnProcess( frame_width_ = camera_motion.frame_width(); frame_height_ = camera_motion.frame_height(); } else { - LOG(FATAL) << "Either VIDEO or SELECTION stream need to be specified."; + ABSL_LOG(FATAL) << "Either VIDEO or SELECTION stream need to be specified."; } // Filled by CSV file parsing. if (!meta_homographies_.empty()) { - CHECK(csv_file_input_); + ABSL_CHECK(csv_file_input_); AppendCameraMotionsFromHomographies(meta_homographies_, true, // append identity. &meta_motions_, &meta_features_); @@ -800,7 +802,7 @@ bool MotionAnalysisCalculator::ParseModelCSV( for (const auto& value : values) { double value_64f; if (!absl::SimpleAtod(value, &value_64f)) { - LOG(ERROR) << "Not a double, expected!"; + ABSL_LOG(ERROR) << "Not a double, expected!"; return false; } @@ -813,12 +815,12 @@ bool MotionAnalysisCalculator::ParseModelCSV( bool MotionAnalysisCalculator::HomographiesFromValues( const std::vector& homog_values, std::deque* homographies) { - CHECK(homographies); + ABSL_CHECK(homographies); // Obvious constants are obvious :D constexpr int kHomographyValues = 9; if (homog_values.size() % kHomographyValues != 0) { - LOG(ERROR) << "Contents not a multiple of " << kHomographyValues; + ABSL_LOG(ERROR) << "Contents not a multiple of " << kHomographyValues; return false; } @@ -830,7 +832,7 @@ bool MotionAnalysisCalculator::HomographiesFromValues( // Normalize last entry to 1. 
if (h_vals[kHomographyValues - 1] == 0) { - LOG(ERROR) << "Degenerate homography, last entry is zero"; + ABSL_LOG(ERROR) << "Degenerate homography, last entry is zero"; return false; } @@ -844,8 +846,8 @@ bool MotionAnalysisCalculator::HomographiesFromValues( } if (homographies->size() % options_.meta_models_per_frame() != 0) { - LOG(ERROR) << "Total homographies not a multiple of specified models " - << "per frame."; + ABSL_LOG(ERROR) << "Total homographies not a multiple of specified models " + << "per frame."; return false; } @@ -855,7 +857,7 @@ bool MotionAnalysisCalculator::HomographiesFromValues( void MotionAnalysisCalculator::SubtractMetaMotion( const CameraMotion& meta_motion, RegionFlowFeatureList* features) { if (meta_motion.mixture_homography().model_size() > 0) { - CHECK(row_weights_ != nullptr); + ABSL_CHECK(row_weights_ != nullptr); RegionFlowFeatureListViaTransform(meta_motion.mixture_homography(), features, -1.0f, 1.0f, // subtract transformed. @@ -901,7 +903,7 @@ void MotionAnalysisCalculator::AddMetaMotion( const CameraMotion& meta_motion, const RegionFlowFeatureList& meta_features, RegionFlowFeatureList* features, CameraMotion* motion) { // Restore old feature location. - CHECK_EQ(meta_features.feature_size(), features->feature_size()); + ABSL_CHECK_EQ(meta_features.feature_size(), features->feature_size()); for (int k = 0; k < meta_features.feature_size(); ++k) { auto feature = features->mutable_feature(k); const auto& meta_feature = meta_features.feature(k); @@ -922,8 +924,8 @@ void MotionAnalysisCalculator::AppendCameraMotionsFromHomographies( const std::deque& homographies, bool append_identity, std::deque* camera_motions, std::deque* features) { - CHECK(camera_motions); - CHECK(features); + ABSL_CHECK(camera_motions); + ABSL_CHECK(features); CameraMotion identity; identity.set_frame_width(frame_width_); @@ -947,8 +949,9 @@ void MotionAnalysisCalculator::AppendCameraMotionsFromHomographies( } const int models_per_frame = options_.meta_models_per_frame(); - CHECK_GT(models_per_frame, 0) << "At least one model per frame is needed"; - CHECK_EQ(0, homographies.size() % models_per_frame); + ABSL_CHECK_GT(models_per_frame, 0) + << "At least one model per frame is needed"; + ABSL_CHECK_EQ(0, homographies.size() % models_per_frame); const int num_frames = homographies.size() / models_per_frame; // Heuristic sigma, similar to what we use for rolling shutter removal. 
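Note: the calculator hunks above all apply the same mechanical migration from the glog-style macros in "mediapipe/framework/port/logging.h" to the Abseil macros, together with the matching BUILD deps ("@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log"). A minimal standalone sketch of the pattern; the function and variable names here are hypothetical illustrations, not part of this diff:

    #include "absl/log/absl_check.h"
    #include "absl/log/absl_log.h"

    // Hypothetical example of the one-for-one macro swap.
    void ProcessIndex(int index) {
      // ABSL_CHECK_GE replaces CHECK_GE: it aborts with the streamed
      // message when the condition fails.
      ABSL_CHECK_GE(index, 0) << "index must be non-negative";
      // ABSL_LOG replaces LOG for the INFO/WARNING/ERROR/FATAL severities.
      ABSL_LOG(INFO) << "processing index " << index;
    }
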
diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc index 9e04f33cb..cda7085da 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -14,6 +14,7 @@ #include +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" @@ -168,9 +169,10 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { .Tag(kSavedAudioPathTag) .Set(MakePacket(saved_audio_path)); } else { - LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path - << " by executing the following command: " - << ffmpeg_command; + ABSL_LOG(WARNING) << "FFmpeg can't extract audio from " + << input_file_path + << " by executing the following command: " + << ffmpeg_command; cc->OutputSidePackets() .Tag(kSavedAudioPathTag) .Set(MakePacket(std::string())); @@ -227,9 +229,9 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { cap_->release(); } if (decoded_frames_ != frame_count_) { - LOG(WARNING) << "Not all the frames are decoded (total frames: " - << frame_count_ << " vs decoded frames: " << decoded_frames_ - << ")."; + ABSL_LOG(WARNING) << "Not all the frames are decoded (total frames: " + << frame_count_ + << " vs decoded frames: " << decoded_frames_ << ")."; } return absl::OkStatus(); } diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc index 4af8c5955..5979d57b0 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/strings/str_split.h" #include "mediapipe/calculators/video/opencv_video_encoder_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -187,9 +188,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { const std::string& audio_file_path = cc->InputSidePackets().Tag(kAudioFilePathTag).Get(); if (audio_file_path.empty()) { - LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the " - "audio tracks to the generated video because the audio " - "file path is not specified."; + ABSL_LOG(WARNING) + << "OpenCvVideoEncoderCalculator isn't able to attach the " + "audio tracks to the generated video because the audio " + "file path is not specified."; } else { // A temp output file is needed because FFmpeg can't do in-place editing. 
const std::string temp_file_path = std::tmpnam(nullptr); diff --git a/mediapipe/calculators/video/tool/BUILD b/mediapipe/calculators/video/tool/BUILD index 408461d2f..2a32c680c 100644 --- a/mediapipe/calculators/video/tool/BUILD +++ b/mediapipe/calculators/video/tool/BUILD @@ -44,6 +44,7 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) diff --git a/mediapipe/calculators/video/tool/flow_quantizer_model.cc b/mediapipe/calculators/video/tool/flow_quantizer_model.cc index f0b00063f..146dc4a70 100644 --- a/mediapipe/calculators/video/tool/flow_quantizer_model.cc +++ b/mediapipe/calculators/video/tool/flow_quantizer_model.cc @@ -14,6 +14,7 @@ #include "mediapipe/calculators/video/tool/flow_quantizer_model.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/type_map.h" @@ -21,7 +22,7 @@ namespace mediapipe { // Uniform normalization to 0-255. uint8_t FlowQuantizerModel::Apply(const float val, const int channel) const { - CHECK_LT(channel, model_.min_value_size()); + ABSL_CHECK_LT(channel, model_.min_value_size()); const auto& min_value = model_.min_value(channel); const auto& max_value = model_.max_value(channel); QCHECK_GT(max_value, min_value); @@ -51,7 +52,7 @@ const QuantizerModelData& FlowQuantizerModel::GetModelData() const { // TODO: Taking the min and max over all training flow fields might be // sensitive to noise. We should use more robust statistics. void FlowQuantizerModel::AddSampleFlowField(const OpticalFlowField& flow) { - CHECK_EQ(model_.min_value_size(), 2); + ABSL_CHECK_EQ(model_.min_value_size(), 2); const cv::Mat_& flow_mat = flow.flow_data(); for (int i = 0; i != flow.width(); ++i) { for (int j = 0; j != flow.height(); ++j) { diff --git a/mediapipe/calculators/video/tracking_graph_test.cc b/mediapipe/calculators/video/tracking_graph_test.cc index 8fd8806b7..1ccc61214 100644 --- a/mediapipe/calculators/video/tracking_graph_test.cc +++ b/mediapipe/calculators/video/tracking_graph_test.cc @@ -19,6 +19,8 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/calculators/video/box_tracker_calculator.pb.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -52,7 +54,7 @@ bool LoadBinaryTestGraph(const std::string& graph_path, bool success = config->ParseFromZeroCopyStream(&in_stream); ifs.close(); if (!success) { - LOG(ERROR) << "could not parse test graph: " << graph_path; + ABSL_LOG(ERROR) << "could not parse test graph: " << graph_path; } return success; } @@ -297,7 +299,7 @@ std::unique_ptr TrackingGraphTest::CreateRandomAccessTrackingBoxList( const std::vector& start_timestamps, const std::vector& end_timestamps) const { - CHECK_EQ(start_timestamps.size(), end_timestamps.size()); + ABSL_CHECK_EQ(start_timestamps.size(), end_timestamps.size()); auto ra_boxes = absl::make_unique(); for (int i = 0; i < start_timestamps.size(); ++i) { auto start_box_list = @@ -620,7 +622,7 @@ TEST_F(TrackingGraphTest, TestTransitionFramesForReacquisition) { // Add TRACK_TIME stream queries in between 2 frames. 
if (j > 0) { Timestamp track_time = Timestamp((j - 0.5f) * kFrameIntervalUs); - LOG(INFO) << track_time.Value(); + ABSL_LOG(INFO) << track_time.Value(); Packet track_time_packet = Adopt(new Timestamp).At(track_time); MP_EXPECT_OK( graph_.AddPacketToInputStream("track_time", track_time_packet)); diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc index 56f3253e2..e60df0280 100644 --- a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "absl/base/macros.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" @@ -158,7 +159,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { absl::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( const ImageFrame& current_frame, const ImageFrame& next_frame, OpticalFlowField* flow) { - CHECK(flow); + ABSL_CHECK(flow); if (!ImageSizesMatch(current_frame, next_frame)) { return tool::StatusInvalid("Images are different sizes."); } @@ -182,7 +183,7 @@ absl::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( flow->Allocate(first.cols, first.rows); cv::Mat cv_flow(flow->mutable_flow_data()); tvl1_computer->calc(first, second, cv_flow); - CHECK_EQ(flow->mutable_flow_data().data, cv_flow.data); + ABSL_CHECK_EQ(flow->mutable_flow_data().data, cv_flow.data); // Inserts the idle DenseOpticalFlow object back to the cache for reuse. { absl::MutexLock lock(&mutex_); diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar index 943f0cbfa..afba10928 100644 Binary files a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar and b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 508322917..4e86b9270 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip networkTimeout=10000 zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/mediapipe/examples/coral/BUILD b/mediapipe/examples/coral/BUILD index 68244d579..0c7c6b113 100644 --- a/mediapipe/examples/coral/BUILD +++ b/mediapipe/examples/coral/BUILD @@ -35,6 +35,7 @@ cc_library( "//mediapipe/framework/port:status", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/mediapipe/examples/coral/demo_run_graph_main.cc b/mediapipe/examples/coral/demo_run_graph_main.cc index 6f1c56268..692f26008 100644 --- a/mediapipe/examples/coral/demo_run_graph_main.cc +++ b/mediapipe/examples/coral/demo_run_graph_main.cc @@ -17,6 +17,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include 
"mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -45,17 +46,17 @@ absl::Status RunMPPGraph() { MP_RETURN_IF_ERROR(mediapipe::file::GetContents( absl::GetFlag(FLAGS_calculator_graph_config_file), &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config)); - LOG(INFO) << "Initialize the camera or load the video."; + ABSL_LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { @@ -68,7 +69,7 @@ absl::Status RunMPPGraph() { cv::VideoWriter writer; const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (save_video) { - LOG(INFO) << "Prepare video writer."; + ABSL_LOG(INFO) << "Prepare video writer."; cv::Mat test_frame; capture.read(test_frame); // Consume first frame. capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning. @@ -85,12 +86,12 @@ absl::Status RunMPPGraph() { capture.set(cv::CAP_PROP_FPS, 30); } - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller(kOutputStream)); MP_RETURN_IF_ERROR(graph.StartRun({})); - LOG(INFO) << "Start grabbing and processing frames."; + ABSL_LOG(INFO) << "Start grabbing and processing frames."; bool grab_frames = true; while (grab_frames) { // Capture opencv camera or video frame. 
@@ -135,7 +136,7 @@ absl::Status RunMPPGraph() { } } - LOG(INFO) << "Shutting down."; + ABSL_LOG(INFO) << "Shutting down."; if (writer.isOpened()) writer.release(); MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream)); return graph.WaitUntilDone(); @@ -146,10 +147,10 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/desktop/BUILD b/mediapipe/examples/desktop/BUILD index 80cb7ad81..3d59c059a 100644 --- a/mediapipe/examples/desktop/BUILD +++ b/mediapipe/examples/desktop/BUILD @@ -31,6 +31,7 @@ cc_library( "//mediapipe/framework/port:statusor", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) @@ -48,8 +49,10 @@ cc_library( "//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/util:resource_util", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", ], ) @@ -73,7 +76,9 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", + "//mediapipe/util:resource_util", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD index a3b2ace2a..4ae45ac8f 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/BUILD +++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD @@ -306,6 +306,7 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc index 238bcf8be..c3c920bcf 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc @@ -28,11 +28,8 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -using mediapipe::Adopt; -using mediapipe::CalculatorBase; using mediapipe::ImageFrame; using mediapipe::PacketTypeSet; -using mediapipe::autoflip::Border; constexpr char kDetectedBorders[] = "DETECTED_BORDERS"; constexpr int kMinBorderDistance = 5; diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc index e72d54e55..431e3d161 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc @@ -28,16 +28,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" -using mediapipe::Adopt; using mediapipe::CalculatorGraphConfig; using mediapipe::CalculatorRunner; using 
mediapipe::ImageFormat; using mediapipe::ImageFrame; using mediapipe::Packet; using mediapipe::PacketTypeSet; -using mediapipe::ParseTextProtoOrDie; -using mediapipe::Timestamp; -using mediapipe::autoflip::Border; namespace mediapipe { namespace autoflip { diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc index 299f60b10..da655cb65 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_log.h" #include "mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" @@ -112,8 +113,8 @@ void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, is_shot_change = false; } if (is_shot_change) { - LOG(INFO) << "Shot change at: " << cc->InputTimestamp().Seconds() - << " seconds."; + ABSL_LOG(INFO) << "Shot change at: " << cc->InputTimestamp().Seconds() + << " seconds."; cc->Outputs() .Tag(kShotChangeTag) .AddPacket(Adopt(std::make_unique(true).release()) diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc index 9ea79ba44..a45a171a4 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc @@ -31,14 +31,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" -using mediapipe::Adopt; using mediapipe::CalculatorGraphConfig; using mediapipe::CalculatorRunner; using mediapipe::ImageFormat; using mediapipe::ImageFrame; using mediapipe::PacketTypeSet; -using mediapipe::ParseTextProtoOrDie; -using mediapipe::Timestamp; namespace mediapipe { namespace autoflip { diff --git a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc index 85b2d96f8..ba5064a28 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc @@ -28,8 +28,6 @@ using mediapipe::Packet; using mediapipe::PacketTypeSet; using mediapipe::autoflip::DetectionSet; -using mediapipe::autoflip::SalientRegion; -using mediapipe::autoflip::SignalType; constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY"; constexpr char kSignalInputsTag[] = "SIGNAL"; diff --git a/mediapipe/examples/desktop/autoflip/quality/BUILD b/mediapipe/examples/desktop/autoflip/quality/BUILD index 20e286107..0aeeffaa4 100644 --- a/mediapipe/examples/desktop/autoflip/quality/BUILD +++ b/mediapipe/examples/desktop/autoflip/quality/BUILD @@ -53,6 +53,7 @@ cc_library( "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], ) @@ -67,6 +68,7 @@ cc_library( hdrs = ["piecewise_linear_function.h"], deps = [ "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", ], ) @@ -192,6 +194,7 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", 
"//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", ], ) @@ -234,6 +237,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -281,6 +285,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", ], ) @@ -327,6 +332,7 @@ cc_library( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc index 5916d1829..947676cd2 100644 --- a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_log.h" #include "mediapipe/examples/desktop/autoflip/quality/utils.h" #include "mediapipe/framework/port/ret_check.h" @@ -137,7 +138,7 @@ void FrameCropRegionComputer::UpdateCropRegionScore( const float feature_score, const bool is_required, float* crop_region_score) { if (feature_score < 0.0) { - LOG(WARNING) << "Ignoring negative score"; + ABSL_LOG(WARNING) << "Ignoring negative score"; return; } @@ -161,7 +162,8 @@ void FrameCropRegionComputer::UpdateCropRegionScore( break; } default: { - LOG(WARNING) << "Unknown CropRegionScoreType " << score_aggregation_type; + ABSL_LOG(WARNING) << "Unknown CropRegionScoreType " + << score_aggregation_type; break; } } diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc index 84b229d80..4c9e96b88 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc @@ -190,14 +190,16 @@ TEST(PaddingEffectGeneratorTest, ScaleToMultipleOfTwo) { double target_aspect_ratio = 0.5; int expect_width = 14; int expect_height = input_height; - auto test_frame = absl::make_unique(/*format=*/ImageFormat::SRGB, - input_width, input_height); + ImageFrame test_frame(/*format=*/ImageFormat::SRGB, input_width, + input_height); + cv::Mat mat = formats::MatView(&test_frame); + mat = cv::Scalar(0, 0, 0); - PaddingEffectGenerator generator(test_frame->Width(), test_frame->Height(), + PaddingEffectGenerator generator(test_frame.Width(), test_frame.Height(), target_aspect_ratio, /*scale_to_multiple_of_two=*/true); ImageFrame result_frame; - MP_ASSERT_OK(generator.Process(*test_frame, 0.3, 40, 0.0, &result_frame)); + MP_ASSERT_OK(generator.Process(test_frame, 0.3, 40, 0.0, &result_frame)); EXPECT_EQ(result_frame.Width(), expect_width); EXPECT_EQ(result_frame.Height(), expect_height); } diff --git a/mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.cc b/mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.cc index fb8f44f11..6e1fc99e5 100644 --- a/mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.cc +++ b/mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/log/absl_check.h" #include 
"mediapipe/framework/port/status.h" namespace mediapipe { @@ -27,7 +28,7 @@ namespace autoflip { void PiecewiseLinearFunction::AddPoint(double x, double y) { if (!points_.empty()) { - CHECK_GE(x, points_.back().x) + ABSL_CHECK_GE(x, points_.back().x) << "Points must be provided in non-decreasing x order."; } points_.push_back(PiecewiseLinearFunction::Point(x, y)); @@ -45,8 +46,8 @@ PiecewiseLinearFunction::GetIntervalIterator(double input) const { double PiecewiseLinearFunction::Interpolate( const PiecewiseLinearFunction::Point& p1, const PiecewiseLinearFunction::Point& p2, double input) const { - CHECK_LT(p1.x, input); - CHECK_GE(p2.x, input); + ABSL_CHECK_LT(p1.x, input); + ABSL_CHECK_GE(p2.x, input); return p2.y - (p2.x - input) / (p2.x - p1.x) * (p2.y - p1.y); } diff --git a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver_test.cc b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver_test.cc index c21245cde..7870fb434 100644 --- a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver_test.cc @@ -14,6 +14,7 @@ #include "mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h" +#include "absl/log/absl_check.h" #include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -145,8 +146,8 @@ void GenerateDataPointsFromRealVideo( const int prior_focus_point_frames_length, std::vector* focus_point_frames, std::vector* prior_focus_point_frames) { - CHECK(focus_point_frames_length + prior_focus_point_frames_length <= - kNumObservations); + ABSL_CHECK(focus_point_frames_length + prior_focus_point_frames_length <= + kNumObservations); for (int i = 0; i < prior_focus_point_frames_length; i++) { FocusPoint sp; sp.set_norm_point_x(data[i]); diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h index d7f06a021..a1528a7d7 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h @@ -43,7 +43,7 @@ namespace autoflip { // SceneCameraMotionAnalyzer analyzer(options); // SceneKeyFrameCropSummary scene_summary; // std::vector focus_point_frames; -// CHECK_OK(analyzer.AnalyzeScenePopulateFocusPointFrames( +// ABSL_CHECK_OK(analyzer.AnalyzeScenePopulateFocusPointFrames( // key_frame_crop_infos, key_frame_crop_options, key_frame_crop_results, // scene_frame_width, scene_frame_height, scene_frame_timestamps, // &scene_summary, &focus_point_frames)); diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc index aa3ba5c6e..3b286e000 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc @@ -20,6 +20,7 @@ #include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_split.h" #include "mediapipe/examples/desktop/autoflip/autoflip_messages.pb.h" #include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h" @@ -744,7 +745,7 @@ TEST(SceneCameraMotionAnalyzerTest, std::vector r = absl::StrSplit(line, ','); 
records.insert(records.end(), r.begin(), r.end()); } - CHECK_EQ(records.size(), kNumSceneFrames * 3 + 1); + ABSL_CHECK_EQ(records.size(), kNumSceneFrames * 3 + 1); std::vector focus_point_frames; MP_EXPECT_OK(analyzer.PopulateFocusPointFrames( diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h index 0e5c332db..c3c8a35cb 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h @@ -41,7 +41,7 @@ namespace autoflip { // SceneCropperOptions scene_cropper_options; // SceneCropper scene_cropper(scene_cropper_options); // std::vector cropped_frames; -// CHECK_OK(scene_cropper.CropFrames( +// ABSL_CHECK_OK(scene_cropper.CropFrames( // scene_summary, scene_frames, focus_point_frames, // prior_focus_point_frames, &cropped_frames)); class SceneCropper { diff --git a/mediapipe/examples/desktop/autoflip/quality/utils.cc b/mediapipe/examples/desktop/autoflip/quality/utils.cc index 919459263..0695ff759 100644 --- a/mediapipe/examples/desktop/autoflip/quality/utils.cc +++ b/mediapipe/examples/desktop/autoflip/quality/utils.cc @@ -19,6 +19,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "mediapipe/examples/desktop/autoflip/quality/math_utils.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" @@ -121,12 +122,12 @@ absl::Status PackKeyFrameInfo(const int64_t frame_timestamp_ms, ScaleRect(original_detection.location(), scale_x, scale_y, &location); } else { has_valid_location = false; - LOG(ERROR) << "Detection missing a bounding box, skipped."; + ABSL_LOG(ERROR) << "Detection missing a bounding box, skipped."; } if (has_valid_location) { if (!ClampRect(original_frame_width, original_frame_height, &location) .ok()) { - LOG(ERROR) << "Invalid detection bounding box, skipped."; + ABSL_LOG(ERROR) << "Invalid detection bounding box, skipped."; continue; } auto* detection = processed_detections->add_detections(); diff --git a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc index 9ae612004..661922fd9 100644 --- a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/log/absl_log.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/ret_check.h" @@ -106,7 +107,7 @@ absl::Status VisualScorer::CalculateScore(const cv::Mat& image, *score = (area_score + sharpness_score + colorfulness_score) / weight_sum; if (*score > 1.0f || *score < 0.0f) { - LOG(WARNING) << "Score of region outside expected range: " << *score; + ABSL_LOG(WARNING) << "Score of region outside expected range: " << *score; } return absl::OkStatus(); } diff --git a/mediapipe/examples/desktop/demo_run_graph_main.cc b/mediapipe/examples/desktop/demo_run_graph_main.cc index 0d26aa0d3..ba36ba6c9 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main.cc @@ -17,6 +17,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -26,6 +27,7 @@ #include 
"mediapipe/framework/port/opencv_video_inc.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/util/resource_util.h" constexpr char kInputStream[] = "input_video"; constexpr char kOutputStream[] = "output_video"; @@ -45,17 +47,17 @@ absl::Status RunMPPGraph() { MP_RETURN_IF_ERROR(mediapipe::file::GetContents( absl::GetFlag(FLAGS_calculator_graph_config_file), &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config)); - LOG(INFO) << "Initialize the camera or load the video."; + ABSL_LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { @@ -76,12 +78,12 @@ absl::Status RunMPPGraph() { #endif } - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller(kOutputStream)); MP_RETURN_IF_ERROR(graph.StartRun({})); - LOG(INFO) << "Start grabbing and processing frames."; + ABSL_LOG(INFO) << "Start grabbing and processing frames."; bool grab_frames = true; while (grab_frames) { // Capture opencv camera or video frame. @@ -89,10 +91,10 @@ absl::Status RunMPPGraph() { capture >> camera_frame_raw; if (camera_frame_raw.empty()) { if (!load_video) { - LOG(INFO) << "Ignore empty frames from camera."; + ABSL_LOG(INFO) << "Ignore empty frames from camera."; continue; } - LOG(INFO) << "Empty frame, end of video reached."; + ABSL_LOG(INFO) << "Empty frame, end of video reached."; break; } cv::Mat camera_frame; @@ -125,7 +127,7 @@ absl::Status RunMPPGraph() { cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); if (save_video) { if (!writer.isOpened()) { - LOG(INFO) << "Prepare video writer."; + ABSL_LOG(INFO) << "Prepare video writer."; writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), output_frame_mat.size()); @@ -140,7 +142,7 @@ absl::Status RunMPPGraph() { } } - LOG(INFO) << "Shutting down."; + ABSL_LOG(INFO) << "Shutting down."; if (writer.isOpened()) writer.release(); MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream)); return graph.WaitUntilDone(); @@ -151,10 +153,10 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc index 586565db4..5702bca72 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc @@ -18,6 +18,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include 
"absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -30,6 +31,7 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" +#include "mediapipe/util/resource_util.h" constexpr char kInputStream[] = "input_video"; constexpr char kOutputStream[] = "output_video"; @@ -49,23 +51,23 @@ absl::Status RunMPPGraph() { MP_RETURN_IF_ERROR(mediapipe::file::GetContents( absl::GetFlag(FLAGS_calculator_graph_config_file), &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config)); - LOG(INFO) << "Initialize the GPU."; + ABSL_LOG(INFO) << "Initialize the GPU."; ASSIGN_OR_RETURN(auto gpu_resources, mediapipe::GpuResources::Create()); MP_RETURN_IF_ERROR(graph.SetGpuResources(std::move(gpu_resources))); mediapipe::GlCalculatorHelper gpu_helper; gpu_helper.InitializeForTest(graph.GetGpuResources().get()); - LOG(INFO) << "Initialize the camera or load the video."; + ABSL_LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { @@ -86,12 +88,12 @@ absl::Status RunMPPGraph() { #endif } - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller(kOutputStream)); MP_RETURN_IF_ERROR(graph.StartRun({})); - LOG(INFO) << "Start grabbing and processing frames."; + ABSL_LOG(INFO) << "Start grabbing and processing frames."; bool grab_frames = true; while (grab_frames) { // Capture opencv camera or video frame. 
@@ -99,10 +101,10 @@ absl::Status RunMPPGraph() { capture >> camera_frame_raw; if (camera_frame_raw.empty()) { if (!load_video) { - LOG(INFO) << "Ignore empty frames from camera."; + ABSL_LOG(INFO) << "Ignore empty frames from camera."; continue; } - LOG(INFO) << "Empty frame, end of video reached."; + ABSL_LOG(INFO) << "Empty frame, end of video reached."; break; } cv::Mat camera_frame; @@ -168,7 +170,7 @@ absl::Status RunMPPGraph() { cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); if (save_video) { if (!writer.isOpened()) { - LOG(INFO) << "Prepare video writer."; + ABSL_LOG(INFO) << "Prepare video writer."; writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), output_frame_mat.size()); @@ -183,7 +185,7 @@ absl::Status RunMPPGraph() { } } - LOG(INFO) << "Shutting down."; + ABSL_LOG(INFO) << "Shutting down."; if (writer.isOpened()) writer.release(); MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream)); return graph.WaitUntilDone(); @@ -194,10 +196,10 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD index 27aa088e7..14eff2dbd 100644 --- a/mediapipe/examples/desktop/hello_world/BUILD +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -22,8 +22,9 @@ cc_binary( deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/mediapipe/examples/desktop/hello_world/hello_world.cc b/mediapipe/examples/desktop/hello_world/hello_world.cc index fde821b51..85cf6c32a 100644 --- a/mediapipe/examples/desktop/hello_world/hello_world.cc +++ b/mediapipe/examples/desktop/hello_world/hello_world.cc @@ -14,8 +14,9 @@ // // A simple example to print out "Hello World!" from a MediaPipe graph. +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_graph.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" @@ -54,7 +55,7 @@ absl::Status PrintHelloWorld() { mediapipe::Packet packet; // Get the output packets string. 
while (poller.Next(&packet)) { - LOG(INFO) << packet.Get(); + ABSL_LOG(INFO) << packet.Get(); } return graph.WaitUntilDone(); } @@ -62,6 +63,6 @@ absl::Status PrintHelloWorld() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); - CHECK(mediapipe::PrintHelloWorld().ok()); + ABSL_CHECK(mediapipe::PrintHelloWorld().ok()); return 0; } diff --git a/mediapipe/examples/desktop/iris_tracking/BUILD b/mediapipe/examples/desktop/iris_tracking/BUILD index b9f3f6f4e..147a0ac25 100644 --- a/mediapipe/examples/desktop/iris_tracking/BUILD +++ b/mediapipe/examples/desktop/iris_tracking/BUILD @@ -33,6 +33,7 @@ cc_binary( "//mediapipe/graphs/iris_tracking:iris_depth_cpu_deps", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", ], ) diff --git a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc index 928ebb207..37476b2b3 100644 --- a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc +++ b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc @@ -19,6 +19,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" @@ -55,11 +56,11 @@ absl::StatusOr ReadFileToString(const std::string& file_path) { } absl::Status ProcessImage(std::unique_ptr graph) { - LOG(INFO) << "Load the image."; + ABSL_LOG(INFO) << "Load the image."; ASSIGN_OR_RETURN(const std::string raw_image, ReadFileToString(absl::GetFlag(FLAGS_input_image_path))); - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller output_image_poller, graph->AddOutputStreamPoller(kOutputImageStream)); ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller left_iris_depth_poller, @@ -108,7 +109,7 @@ absl::Status ProcessImage(std::unique_ptr graph) { cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); const bool save_image = !absl::GetFlag(FLAGS_output_image_path).empty(); if (save_image) { - LOG(INFO) << "Saving image to file..."; + ABSL_LOG(INFO) << "Saving image to file..."; cv::imwrite(absl::GetFlag(FLAGS_output_image_path), output_frame_mat); } else { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); @@ -117,7 +118,7 @@ absl::Status ProcessImage(std::unique_ptr graph) { cv::waitKey(0); } - LOG(INFO) << "Shutting down."; + ABSL_LOG(INFO) << "Shutting down."; MP_RETURN_IF_ERROR(graph->CloseInputStream(kInputStream)); return graph->WaitUntilDone(); } @@ -126,13 +127,13 @@ absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( kCalculatorGraphConfigFile, &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; std::unique_ptr graph = absl::make_unique(); MP_RETURN_IF_ERROR(graph->Initialize(config)); @@ -152,10 +153,10 @@ int main(int argc, char** argv) { 
absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/desktop/media_sequence/BUILD b/mediapipe/examples/desktop/media_sequence/BUILD index 1a88aa109..53f932948 100644 --- a/mediapipe/examples/desktop/media_sequence/BUILD +++ b/mediapipe/examples/desktop/media_sequence/BUILD @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Placeholder: load py_library +# Placeholder: load py_binary + licenses(["notice"]) package(default_visibility = ["//mediapipe/examples:__subpackages__"]) @@ -27,6 +30,7 @@ cc_library( "//mediapipe/framework/port:status", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc index 06212b013..a14c7734d 100644 --- a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc +++ b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc @@ -19,6 +19,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_split.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/file_helpers.h" @@ -43,8 +44,8 @@ absl::Status RunMPPGraph() { MP_RETURN_IF_ERROR(mediapipe::file::GetContents( absl::GetFlag(FLAGS_calculator_graph_config_file), &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); @@ -61,12 +62,12 @@ absl::Status RunMPPGraph() { input_side_packets[name_and_value[0]] = mediapipe::MakePacket(input_side_packet_contents); } - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.Run()); - LOG(INFO) << "Gathering output side packets."; + ABSL_LOG(INFO) << "Gathering output side packets."; kv_pairs = absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); @@ -88,10 +89,10 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/desktop/simple_run_graph_main.cc b/mediapipe/examples/desktop/simple_run_graph_main.cc index 96d9839a8..e794902d8 100644 --- 
a/mediapipe/examples/desktop/simple_run_graph_main.cc +++ b/mediapipe/examples/desktop/simple_run_graph_main.cc @@ -22,6 +22,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" @@ -102,8 +103,8 @@ absl::Status RunMPPGraph() { MP_RETURN_IF_ERROR(mediapipe::file::GetContents( absl::GetFlag(FLAGS_calculator_graph_config_file), &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); @@ -119,14 +120,14 @@ absl::Status RunMPPGraph() { mediapipe::MakePacket(name_and_value[1]); } } - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); if (!absl::GetFlag(FLAGS_output_stream).empty() && !absl::GetFlag(FLAGS_output_stream_file).empty()) { ASSIGN_OR_RETURN(auto poller, graph.AddOutputStreamPoller( absl::GetFlag(FLAGS_output_stream))); - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.StartRun({})); MP_RETURN_IF_ERROR(OutputStreamToLocalFile(poller)); } else { @@ -134,7 +135,7 @@ absl::Status RunMPPGraph() { absl::GetFlag(FLAGS_output_stream_file).empty()) << "--output_stream and --output_stream_file should be specified in " "pair."; - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.StartRun({})); } MP_RETURN_IF_ERROR(graph.WaitUntilDone()); @@ -146,10 +147,10 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/desktop/youtube8m/BUILD b/mediapipe/examples/desktop/youtube8m/BUILD index e0e44c4d9..783c7a9dd 100644 --- a/mediapipe/examples/desktop/youtube8m/BUILD +++ b/mediapipe/examples/desktop/youtube8m/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Placeholder: load py_binary + licenses(["notice"]) cc_binary( @@ -20,6 +22,7 @@ cc_binary( deps = [ "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", diff --git a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc index 9030e9255..dbabf84b1 100644 --- a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc +++ b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc @@ -19,6 +19,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_split.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" @@ -44,8 +45,8 @@ absl::Status RunMPPGraph() { MP_RETURN_IF_ERROR(mediapipe::file::GetContents( absl::GetFlag(FLAGS_calculator_graph_config_file), &calculator_graph_config_contents)); - LOG(INFO) << "Get calculator graph config contents: " - << calculator_graph_config_contents; + ABSL_LOG(INFO) << "Get calculator graph config contents: " + << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); @@ -102,12 +103,12 @@ absl::Status RunMPPGraph() { input_side_packets["vggish_pca_projection_matrix"] = mediapipe::MakePacket(vggish_pca_projection_matrix); - LOG(INFO) << "Initialize the calculator graph."; + ABSL_LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); - LOG(INFO) << "Start running the calculator graph."; + ABSL_LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.Run()); - LOG(INFO) << "Gathering output side packets."; + ABSL_LOG(INFO) << "Gathering output side packets."; kv_pairs = absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); @@ -129,10 +130,10 @@ int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { - LOG(ERROR) << "Failed to run the graph: " << run_status.message(); + ABSL_LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; } else { - LOG(INFO) << "Success!"; + ABSL_LOG(INFO) << "Success!"; } return EXIT_SUCCESS; } diff --git a/mediapipe/examples/ios/BUILD b/mediapipe/examples/ios/BUILD index fd611a615..1aed02282 100644 --- a/mediapipe/examples/ios/BUILD +++ b/mediapipe/examples/ios/BUILD @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Placeholder: load py_test + licenses(["notice"]) package(default_visibility = ["//visibility:public"]) diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD index 9424fddea..300901909 100644 --- a/mediapipe/examples/ios/facedetectioncpu/BUILD +++ b/mediapipe/examples/ios/facedetectioncpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "facedetectioncpu", diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD index 8ed689b4f..d3725aa33 100644 --- a/mediapipe/examples/ios/facedetectiongpu/BUILD +++ b/mediapipe/examples/ios/facedetectiongpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "facedetectiongpu", diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 1152bed33..c9415068b 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "faceeffect", diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 6caf8c09c..250a8bca1 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "facemeshgpu", diff --git a/mediapipe/examples/ios/handdetectiongpu/BUILD b/mediapipe/examples/ios/handdetectiongpu/BUILD index 9b9255374..6deb1be1d 100644 --- a/mediapipe/examples/ios/handdetectiongpu/BUILD +++ b/mediapipe/examples/ios/handdetectiongpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "handdetectiongpu", diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index c5b8e7b58..b8f1442fe 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "handtrackinggpu", diff --git a/mediapipe/examples/ios/helloworld/BUILD b/mediapipe/examples/ios/helloworld/BUILD index 6bfcfaaef..3bed74843 100644 --- a/mediapipe/examples/ios/helloworld/BUILD +++ b/mediapipe/examples/ios/helloworld/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "helloworld", diff --git a/mediapipe/examples/ios/holistictrackinggpu/BUILD b/mediapipe/examples/ios/holistictrackinggpu/BUILD index cd10877de..56c74148c 100644 --- a/mediapipe/examples/ios/holistictrackinggpu/BUILD +++ b/mediapipe/examples/ios/holistictrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "holistictrackinggpu", diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 646d2e5a2..78d4bbd1e 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "iristrackinggpu", diff --git a/mediapipe/examples/ios/link_local_profiles.py b/mediapipe/examples/ios/link_local_profiles.py index bc4a06c97..d353c337e 100755 --- 
a/mediapipe/examples/ios/link_local_profiles.py +++ b/mediapipe/examples/ios/link_local_profiles.py @@ -147,12 +147,18 @@ def main(): f"Looking for profiles for app ids with prefix '{bundle_id_prefix}' in '{profile_dir}'" ) + profiles_found = False for name in os.listdir(profile_dir): if not name.endswith(".mobileprovision"): continue + profiles_found = True profile_path = os.path.join(profile_dir, name) process_profile(profile_path, our_app_id_re) + if not profiles_found: + print("Error: Unable to find any provisioning profiles " + + f"(*.mobileprovision files) in '{profile_dir}'") + if __name__ == "__main__": main() diff --git a/mediapipe/examples/ios/objectdetectioncpu/BUILD b/mediapipe/examples/ios/objectdetectioncpu/BUILD index 7638c7413..47bde166e 100644 --- a/mediapipe/examples/ios/objectdetectioncpu/BUILD +++ b/mediapipe/examples/ios/objectdetectioncpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "objectdetectioncpu", diff --git a/mediapipe/examples/ios/objectdetectiongpu/BUILD b/mediapipe/examples/ios/objectdetectiongpu/BUILD index 3b925c078..174db7582 100644 --- a/mediapipe/examples/ios/objectdetectiongpu/BUILD +++ b/mediapipe/examples/ios/objectdetectiongpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "objectdetectiongpu", diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD index 2236c5257..cb8626cc3 100644 --- a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD +++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "objectdetectiontrackinggpu", diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 4fbc2280c..855d32954 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "posetrackinggpu", diff --git a/mediapipe/examples/ios/selfiesegmentationgpu/BUILD b/mediapipe/examples/ios/selfiesegmentationgpu/BUILD index 1ba7997ed..2abf05617 100644 --- a/mediapipe/examples/ios/selfiesegmentationgpu/BUILD +++ b/mediapipe/examples/ios/selfiesegmentationgpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "selfiesegmentationgpu", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index ae788ed58..b289fc582 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -33,7 +33,9 @@ bzl_library( srcs = [ "transitive_protos.bzl", ], - visibility = ["//mediapipe/framework:__subpackages__"], + visibility = [ + "//mediapipe/framework:__subpackages__", + ], ) bzl_library( @@ -42,6 +44,9 @@ bzl_library( "encode_binary_proto.bzl", ], visibility = ["//visibility:public"], + deps = [ + "@bazel_skylib//lib:paths", + ], ) alias( @@ -177,8 +182,8 @@ cc_library( ":timestamp", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:logging", - "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", ], ) @@ -199,6 +204,7 @@ cc_library( ":timestamp", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", ], ) @@ -215,6 +221,7 @@ 
cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ], @@ -267,6 +274,8 @@ cc_library( ], deps = [ ":calculator_base", + ":calculator_context", + ":calculator_contract", ":calculator_graph", ":calculator_registry", ":counter_factory", @@ -316,7 +325,6 @@ cc_library( ":input_stream_manager", ":mediapipe_profiling", ":output_side_packet_impl", - ":output_stream", ":output_stream_manager", ":output_stream_poller", ":output_stream_shard", @@ -330,6 +338,7 @@ cc_library( ":scheduler_queue", ":status_handler", ":status_handler_cc_proto", + ":subgraph", ":thread_pool_executor", ":thread_pool_executor_cc_proto", ":timestamp", @@ -337,6 +346,7 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", @@ -350,11 +360,13 @@ cc_library( "//mediapipe/gpu:graph_support", "//mediapipe/util:cpu_util", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -423,6 +435,8 @@ cc_library( "//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -449,11 +463,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":calculator_framework", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:sink", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -479,6 +494,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:options_map", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -496,11 +512,12 @@ cc_library( deps = [ ":collection_item_id", ":type_map", - "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -534,6 +551,8 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:map_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", 
"@com_google_absl//absl/time", @@ -604,10 +623,11 @@ cc_library( ":packet_set", ":packet_type", ":timestamp", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -622,6 +642,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -640,6 +661,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:fill_packet_set", + "@com_google_absl//absl/log:absl_check", ], ) @@ -678,6 +700,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -698,6 +721,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -717,6 +741,7 @@ cc_library( "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -756,6 +781,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", ], ) @@ -794,6 +820,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/tool:tag_map", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/synchronization", ], ) @@ -812,6 +839,7 @@ cc_library( ":timestamp", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/synchronization", ], ) @@ -822,6 +850,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_output_stream", + "@com_google_absl//absl/log:absl_check", ], ) @@ -838,6 +867,7 @@ cc_library( ":timestamp", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -862,6 +892,8 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/framework/tool:type_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -942,6 +974,8 @@ cc_library( "//mediapipe/framework/tool:type_util", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1018,6 +1052,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", 
"@com_google_absl//absl/synchronization", ], ) @@ -1077,6 +1112,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", ], @@ -1097,6 +1133,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", ], + alwayslink = True, # Defines TestServiceCalculator ) cc_library( @@ -1126,6 +1163,8 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) @@ -1146,6 +1185,8 @@ cc_library( "//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:type_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -1175,7 +1216,6 @@ cc_library( ":calculator_contract", ":graph_service_manager", ":legacy_calculator_support", - ":packet", ":packet_generator", ":packet_generator_cc_proto", ":packet_set", @@ -1186,7 +1226,6 @@ cc_library( ":stream_handler_cc_proto", ":subgraph", ":thread_pool_executor_cc_proto", - ":timestamp", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", @@ -1200,10 +1239,12 @@ cc_library( "//mediapipe/framework/tool:subgraph_expansion", "//mediapipe/framework/tool:validate", "//mediapipe/framework/tool:validate_name", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -1285,10 +1326,11 @@ cc_test( ":calculator_node", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:source", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", ], ) @@ -1352,6 +1394,23 @@ cc_test( ], ) +cc_test( + name = "calculator_graph_summary_packet_test", + srcs = ["calculator_graph_summary_packet_test.cc"], + deps = [ + ":calculator_framework", + ":packet", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/status", + ], +) + cc_test( name = "calculator_runner_test", size = "medium", @@ -1365,8 +1424,8 @@ cc_test( ":packet_type", ":timestamp", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) @@ -1403,6 +1462,7 @@ cc_test( "calculator_graph_test.cc", ], linkstatic = 1, + tags = ["not_run:arm"], visibility = ["//visibility:public"], deps = [ ":calculator_framework", @@ -1427,7 +1487,6 @@ cc_test( "//mediapipe/calculators/core:mux_calculator", 
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1442,7 +1501,10 @@ cc_test( "//mediapipe/framework/tool:status_util", "//mediapipe/gpu:gpu_service", "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -1497,11 +1559,11 @@ cc_test( "//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:sink", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/time", ], ) @@ -1628,6 +1690,7 @@ cc_test( ":packet", ":packet_test_cc_proto", ":type_map", + "//mediapipe/framework/api2:builder", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:gtest_main", "@com_google_absl//absl/strings", @@ -1672,6 +1735,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:template_parser", + "@com_google_absl//absl/log:absl_check", ], ) diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index 8a3946899..5c5ec04ea 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -1,3 +1,5 @@ +# Placeholder: load py_test + package( default_visibility = ["//visibility:public"], features = ["-use_header_modules"], @@ -20,6 +22,7 @@ cc_library( "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -112,6 +115,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], ) @@ -123,6 +127,7 @@ cc_library( ":tuple", "//mediapipe/framework:packet", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/meta:type_traits", ], ) @@ -152,6 +157,7 @@ cc_library( "//mediapipe/framework:output_side_packet", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:type_util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/api2/README.md b/mediapipe/framework/api2/README.md index eb53dd67e..849a5f4c4 100644 --- a/mediapipe/framework/api2/README.md +++ b/mediapipe/framework/api2/README.md @@ -52,7 +52,7 @@ int select = cc->Inputs().Tag(kSelectTag).Get(); write ``` -int select = kSelectTag(cc).Get(); // alternative: *kSelectTag(cc) +int select = kSelect(cc).Get(); // alternative: *kSelect(cc) ``` Sets of multiple ports can be declared with `::Multiple`. 
Note, also, that a tag diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index adc2c9ffa..fde281121 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -11,6 +11,7 @@ #include #include "absl/container/btree_map.h" +#include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "google/protobuf/message_lite.h" #include "mediapipe/framework/api2/port.h" @@ -32,7 +33,7 @@ template struct dependent_false : std::false_type {}; template -T& GetWithAutoGrow(std::vector>* vecp, int index) { +T& GetWithAutoGrow(std::vector>* vecp, size_t index) { auto& vec = *vecp; if (vec.size() <= index) { vec.resize(index + 1); @@ -109,7 +110,7 @@ class MultiPort : public Single { : Single(vec), vec_(*vec) {} Single operator[](int index) { - CHECK_GE(index, 0); + ABSL_CHECK_GE(index, 0); return Single{&GetWithAutoGrow(&vec_, index)}; } @@ -193,7 +194,7 @@ class SourceImpl { template {}, int>::type = 0> Src& ConnectTo(const Dst& dest) { - CHECK(dest.base_.source == nullptr); + ABSL_CHECK(dest.base_.source == nullptr); dest.base_.source = base_; base_->dests_.emplace_back(&dest.base_); return *this; @@ -223,6 +224,16 @@ class SourceImpl { return !(*this == other); } + Src& SetName(const char* name) { + base_->name_ = std::string(name); + return *this; + } + + Src& SetName(absl::string_view name) { + base_->name_ = std::string(name); + return *this; + } + Src& SetName(std::string name) { base_->name_ = std::move(name); return *this; @@ -711,14 +722,14 @@ class Graph { config.set_type(type_); } FixUnnamedConnections(); - CHECK_OK(UpdateBoundaryConfig(&config)); + ABSL_CHECK_OK(UpdateBoundaryConfig(&config)); for (const std::unique_ptr& node : nodes_) { auto* out_node = config.add_node(); - CHECK_OK(UpdateNodeConfig(*node, out_node)); + ABSL_CHECK_OK(UpdateNodeConfig(*node, out_node)); } for (const std::unique_ptr& node : packet_gens_) { auto* out_node = config.add_packet_generator(); - CHECK_OK(UpdateNodeConfig(*node, out_node)); + ABSL_CHECK_OK(UpdateNodeConfig(*node, out_node)); } return config; } @@ -772,7 +783,7 @@ class Graph { config->set_calculator(node.type_); node.in_streams_.Visit( [&](const TagIndexLocation& loc, const DestinationBase& endpoint) { - CHECK(endpoint.source != nullptr); + ABSL_CHECK(endpoint.source != nullptr); config->add_input_stream(TaggedName(loc, endpoint.source->name_)); }); node.out_streams_.Visit( @@ -781,7 +792,7 @@ class Graph { }); node.in_sides_.Visit([&](const TagIndexLocation& loc, const DestinationBase& endpoint) { - CHECK(endpoint.source != nullptr); + ABSL_CHECK(endpoint.source != nullptr); config->add_input_side_packet(TaggedName(loc, endpoint.source->name_)); }); node.out_sides_.Visit( @@ -802,7 +813,7 @@ class Graph { config->set_packet_generator(node.type_); node.in_sides_.Visit([&](const TagIndexLocation& loc, const DestinationBase& endpoint) { - CHECK(endpoint.source != nullptr); + ABSL_CHECK(endpoint.source != nullptr); config->add_input_side_packet(TaggedName(loc, endpoint.source->name_)); }); node.out_sides_.Visit( @@ -819,7 +830,7 @@ class Graph { absl::Status UpdateBoundaryConfig(CalculatorGraphConfig* config) { graph_boundary_.in_streams_.Visit( [&](const TagIndexLocation& loc, const DestinationBase& endpoint) { - CHECK(endpoint.source != nullptr); + ABSL_CHECK(endpoint.source != nullptr); config->add_output_stream(TaggedName(loc, endpoint.source->name_)); }); graph_boundary_.out_streams_.Visit( @@ -828,7 +839,7 @@ class Graph { }); 
graph_boundary_.in_sides_.Visit([&](const TagIndexLocation& loc, const DestinationBase& endpoint) { - CHECK(endpoint.source != nullptr); + ABSL_CHECK(endpoint.source != nullptr); config->add_output_side_packet(TaggedName(loc, endpoint.source->name_)); }); graph_boundary_.out_sides_.Visit( diff --git a/mediapipe/framework/api2/node.h b/mediapipe/framework/api2/node.h index 14c098246..58cebf1ea 100644 --- a/mediapipe/framework/api2/node.h +++ b/mediapipe/framework/api2/node.h @@ -64,58 +64,13 @@ class CalculatorBaseFactoryFor< namespace api2 { namespace internal { -// Defining a member of this type causes P to be ODR-used, which forces its -// instantiation if it's a static member of a template. -// Previously we depended on the pointer's value to determine whether the size -// of a character array is 0 or 1, forcing it to be instantiated so the -// compiler can determine the object's layout. But using it as a template -// argument is more compact. -template -struct ForceStaticInstantiation { -#ifdef _MSC_VER - // Just having it as the template argument does not count as a use for - // MSVC. - static constexpr bool Use() { return P != nullptr; } - char force_static[Use()]; -#endif // _MSC_VER -}; +MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE( + NodeRegistrator, mediapipe::CalculatorBaseRegistry, T::kCalculatorName, + absl::make_unique>) -// Helper template for forcing the definition of a static registration token. -template -struct NodeRegistrationStatic { - static NoDestructor registration; - - static mediapipe::RegistrationToken Make() { - return mediapipe::CalculatorBaseRegistry::Register( - T::kCalculatorName, - absl::make_unique>, - __FILE__, __LINE__); - } - - using RequireStatics = ForceStaticInstantiation<®istration>; -}; - -// Static members of template classes can be defined in the header. -template -NoDestructor - NodeRegistrationStatic::registration(NodeRegistrationStatic::Make()); - -template -struct SubgraphRegistrationImpl { - static NoDestructor registration; - - static mediapipe::RegistrationToken Make() { - return mediapipe::SubgraphRegistry::Register( - T::kCalculatorName, absl::make_unique, __FILE__, __LINE__); - } - - using RequireStatics = ForceStaticInstantiation<®istration>; -}; - -template -NoDestructor - SubgraphRegistrationImpl::registration( - SubgraphRegistrationImpl::Make()); +MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(SubgraphRegistrator, + mediapipe::SubgraphRegistry, + T::kCalculatorName, absl::make_unique) } // namespace internal @@ -128,14 +83,7 @@ template class RegisteredNode; template -class RegisteredNode : public Node { - private: - // The member below triggers instantiation of the registration static. - // Note that the constructor of calculator subclasses is only invoked through - // the registration token, and so we cannot simply use the static in the - // constructor. - typename internal::NodeRegistrationStatic::RequireStatics register_; -}; +class RegisteredNode : public Node, private internal::NodeRegistrator {}; // No-op version for backwards compatibility. template <> @@ -217,31 +165,27 @@ class NodeImpl : public RegisteredNode, public Intf { // TODO: verify that the subgraph config fully implements the // declared interface. 
template -class SubgraphImpl : public Subgraph, public Intf { - private: - typename internal::SubgraphRegistrationImpl::RequireStatics register_; -}; +class SubgraphImpl : public Subgraph, + public Intf, + private internal::SubgraphRegistrator {}; // This macro is used to register a calculator that does not use automatic // registration. Deprecated. -#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ - static mediapipe::NoDestructor \ - REGISTRY_STATIC_VAR(calculator_registration, \ - __LINE__)(mediapipe::CalculatorBaseRegistry::Register( \ - Impl::kCalculatorName, \ - absl::make_unique>, \ - __FILE__, __LINE__)) +#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ + MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \ + mediapipe::CalculatorBaseRegistry, calculator_registration, \ + Impl::kCalculatorName, \ + absl::make_unique>) // This macro is used to register a non-split-contract calculator. Deprecated. #define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name) // This macro is used to define a subgraph that does not use automatic // registration. Deprecated. -#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ - static mediapipe::NoDestructor \ - REGISTRY_STATIC_VAR(subgraph_registration, \ - __LINE__)(mediapipe::SubgraphRegistry::Register( \ - Impl::kCalculatorName, absl::make_unique, __FILE__, __LINE__)) +#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ + MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \ + mediapipe::SubgraphRegistry, subgraph_registration, \ + Impl::kCalculatorName, absl::make_unique) } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/node_test.cc b/mediapipe/framework/api2/node_test.cc index a6c1ef7c6..ac1ca6015 100644 --- a/mediapipe/framework/api2/node_test.cc +++ b/mediapipe/framework/api2/node_test.cc @@ -3,6 +3,7 @@ #include #include +#include "absl/log/absl_log.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/test_contracts.h" @@ -19,8 +20,6 @@ namespace mediapipe { namespace api2 { namespace test { -using testing::ElementsAre; - // Returns the packet values for a vector of Packets. 
template std::vector PacketValues(const std::vector& packets) { @@ -572,7 +571,7 @@ struct LogSinkNode : public Node { MEDIAPIPE_NODE_CONTRACT(kIn); absl::Status Process(CalculatorContext* cc) override { - LOG(INFO) << "LogSinkNode received: " << kIn(cc).Get(); + ABSL_LOG(INFO) << "LogSinkNode received: " << kIn(cc).Get(); return {}; } }; diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index b1ebb0410..f231f4c80 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -13,6 +13,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/meta/type_traits.h" #include "mediapipe/framework/api2/tuple.h" #include "mediapipe/framework/packet.h" @@ -102,9 +103,9 @@ mediapipe::Packet ToOldPacket(PacketBase&& p); template inline const T& PacketBase::Get() const { - CHECK(payload_); + ABSL_CHECK(payload_); packet_internal::Holder* typed_payload = payload_->As(); - CHECK(typed_payload) << absl::StrCat( + ABSL_CHECK(typed_payload) << absl::StrCat( "The Packet stores \"", payload_->DebugTypeName(), "\", but \"", MediaPipeTypeStringOrDemangled(), "\" was requested."); return typed_payload->data(); @@ -134,17 +135,17 @@ namespace internal { template inline void CheckCompatibleType(const HolderBase& holder, internal::Wrap) { const packet_internal::Holder* typed_payload = holder.As(); - CHECK(typed_payload) << absl::StrCat( + ABSL_CHECK(typed_payload) << absl::StrCat( "The Packet stores \"", holder.DebugTypeName(), "\", but \"", MediaPipeTypeStringOrDemangled(), "\" was requested."); - // CHECK(payload_->has_type()); + // ABSL_CHECK(payload_->has_type()); } template inline void CheckCompatibleType(const HolderBase& holder, internal::Wrap>) { bool compatible = (holder.As() || ...); - CHECK(compatible) + ABSL_CHECK(compatible) << "The Packet stores \"" << holder.DebugTypeName() << "\", but one of " << absl::StrJoin( {absl::StrCat("\"", MediaPipeTypeStringOrDemangled(), "\"")...}, @@ -165,7 +166,7 @@ template struct IsCompatibleType> : std::integral_constant || ...)> {}; -}; // namespace internal +} // namespace internal template inline Packet PacketBase::As() const { @@ -211,9 +212,9 @@ class Packet : public Packet { Packet At(Timestamp timestamp) &&; const T& Get() const { - CHECK(payload_); + ABSL_CHECK(payload_); packet_internal::Holder* typed_payload = payload_->As(); - CHECK(typed_payload); + ABSL_CHECK(typed_payload); return typed_payload->data(); } const T& operator*() const { return Get(); } @@ -259,19 +260,19 @@ struct First { template struct AddStatus { - using type = StatusOr; + using type = absl::StatusOr; }; template -struct AddStatus> { - using type = StatusOr; +struct AddStatus> { + using type = absl::StatusOr; }; template <> -struct AddStatus { - using type = Status; +struct AddStatus { + using type = absl::Status; }; template <> struct AddStatus { - using type = Status; + using type = absl::Status; }; template @@ -282,7 +283,7 @@ struct CallAndAddStatusImpl { }; template struct CallAndAddStatusImpl { - Status operator()(const F& f, A&&... a) { + absl::Status operator()(const F& f, A&&... a) { f(std::forward(a)...); return {}; } @@ -330,9 +331,9 @@ class Packet> : public PacketBase { template > const U& Get() const { - CHECK(payload_); + ABSL_CHECK(payload_); packet_internal::Holder* typed_payload = payload_->As(); - CHECK(typed_payload); + ABSL_CHECK(typed_payload); return typed_payload->data(); } @@ -343,7 +344,7 @@ class Packet> : public PacketBase { template auto Visit(const F&... 
args) const { - CHECK(payload_); + ABSL_CHECK(payload_); auto f = internal::Overload{args...}; using FirstT = typename internal::First::type; using ResultType = absl::result_of_t; @@ -364,7 +365,7 @@ class Packet> : public PacketBase { template auto ConsumeAndVisit(const F&... args) { - CHECK(payload_); + ABSL_CHECK(payload_); auto f = internal::Overload{args...}; using FirstT = typename internal::First::type; using VisitorResultType = diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index f6abe75ed..075e88437 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -20,6 +20,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/api2/const_str.h" @@ -243,8 +244,8 @@ class MultiplePortAccess { // container? int Count() { return count_; } AccessT operator[](int pos) { - CHECK_GE(pos, 0); - CHECK_LT(pos, count_); + ABSL_CHECK_GE(pos, 0); + ABSL_CHECK_LT(pos, count_); return SinglePortAccess(cc_, &first_[pos]); } @@ -467,6 +468,11 @@ class SideFallbackT : public Base { // CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to // OutputStreamShard. Like that class, this class will not be usually named in // calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)). +// +// If not connected (!IsConnected()) SetNextTimestampBound is safe to call and +// does nothing. +// All the sub-classes that define Send should implement it to be safe to to +// call if not connected and do nothing in such case. class OutputShardAccessBase { public: OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output) diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD new file mode 100644 index 000000000..f59a65d95 --- /dev/null +++ b/mediapipe/framework/api2/stream/BUILD @@ -0,0 +1,139 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "landmarks_to_detection", + srcs = ["landmarks_to_detection.cc"], + hdrs = ["landmarks_to_detection.h"], + deps = [ + "//mediapipe/calculators/util:landmarks_to_detection_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + ], +) + +cc_test( + name = "landmarks_to_detection_test", + srcs = ["landmarks_to_detection_test.cc"], + deps = [ + ":landmarks_to_detection", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:status_matchers", + ], +) + +cc_library( + name = "landmarks_projection", + srcs = ["landmarks_projection.cc"], + hdrs = ["landmarks_projection.h"], + deps = [ + "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + ], +) + +cc_test( + name = "landmarks_projection_test", + srcs = ["landmarks_projection_test.cc"], + deps = [ + ":landmarks_projection", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + 
"//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:status_matchers", + ], +) + +cc_library( + name = "loopback", + hdrs = ["loopback.h"], + deps = [ + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + ], +) + +cc_test( + name = "loopback_test", + srcs = ["loopback_test.cc"], + deps = [ + ":loopback", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) + +cc_library( + name = "image_size", + hdrs = ["image_size.h"], + deps = [ + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/gpu:gpu_buffer", + ], +) + +cc_test( + name = "image_size_test", + srcs = ["image_size_test.cc"], + deps = [ + ":image_size", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + "//mediapipe/gpu:gpu_buffer", + ], +) + +cc_library( + name = "rect_transformation", + srcs = ["rect_transformation.cc"], + hdrs = ["rect_transformation.h"], + deps = [ + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:rect_cc_proto", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "rect_transformation_test", + srcs = ["rect_transformation_test.cc"], + deps = [ + ":rect_transformation", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) diff --git a/mediapipe/framework/api2/stream/image_size.h b/mediapipe/framework/api2/stream/image_size.h new file mode 100644 index 000000000..b726f07a9 --- /dev/null +++ b/mediapipe/framework/api2/stream/image_size.h @@ -0,0 +1,34 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_ + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gpu_buffer.h" + +namespace mediapipe::api2::builder { + +// Updates graph to calculate image size and returns corresponding stream. +// +// @image image represented as ImageFrame/Image/GpuBuffer. +// @graph graph to update. 
+template +Stream> GetImageSize( + Stream image, mediapipe::api2::builder::Graph& graph) { + auto& img_props_node = graph.AddNode("ImagePropertiesCalculator"); + if constexpr (std::is_same_v || + std::is_same_v) { + image.ConnectTo(img_props_node.In("IMAGE")); + } else if constexpr (std::is_same_v) { + image.ConnectTo(img_props_node.In("IMAGE_GPU")); + } else { + static_assert(dependent_false::value, "Type not supported."); + } + return img_props_node.Out("SIZE").Cast>(); +} + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_IMAGE_SIZE_H_ diff --git a/mediapipe/framework/api2/stream/image_size_test.cc b/mediapipe/framework/api2/stream/image_size_test.cc new file mode 100644 index 000000000..3b080ba02 --- /dev/null +++ b/mediapipe/framework/api2/stream/image_size_test.cc @@ -0,0 +1,57 @@ +#include "mediapipe/framework/api2/stream/image_size.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/gpu/gpu_buffer.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(GetImageSize, VerifyConfig) { + Graph graph; + + Stream image_frame = graph.In("IMAGE_FRAME").Cast(); + image_frame.SetName("image_frame"); + Stream gpu_buffer = graph.In("GPU_BUFFER").Cast(); + gpu_buffer.SetName("gpu_buffer"); + Stream image = graph.In("IMAGE").Cast(); + image.SetName("image"); + + GetImageSize(image_frame, graph).SetName("image_frame_size"); + GetImageSize(gpu_buffer, graph).SetName("gpu_buffer_size"); + GetImageSize(image, graph).SetName("image_size"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:image_frame" + output_stream: "SIZE:image_frame_size" + } + node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:gpu_buffer" + output_stream: "SIZE:gpu_buffer_size" + } + node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:image" + output_stream: "SIZE:image_size" + } + input_stream: "GPU_BUFFER:gpu_buffer" + input_stream: "IMAGE:image" + input_stream: "IMAGE_FRAME:image_frame" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} +} // namespace +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/landmarks_projection.cc b/mediapipe/framework/api2/stream/landmarks_projection.cc new file mode 100644 index 000000000..1735dc1d6 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_projection.cc @@ -0,0 +1,20 @@ +#include "mediapipe/framework/api2/stream/landmarks_projection.h" + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::api2::builder { + +Stream ProjectLandmarks( + Stream landmarks, + Stream> projection_matrix, Graph& graph) { + auto& projector = graph.AddNode("LandmarkProjectionCalculator"); + landmarks.ConnectTo(projector.In("NORM_LANDMARKS")); + projection_matrix.ConnectTo(projector.In("PROJECTION_MATRIX")); + return projector.Out("NORM_LANDMARKS") + .Cast(); +} + +} // namespace mediapipe::api2::builder diff --git 
a/mediapipe/framework/api2/stream/landmarks_projection.h b/mediapipe/framework/api2/stream/landmarks_projection.h new file mode 100644 index 000000000..3a9508a45 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_projection.h @@ -0,0 +1,23 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_PROJECTION_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_PROJECTION_H_ + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::api2::builder { + +// Updates @graph to project predicted @landmarks back to the original @image +// based on @projection_matrix +// +// @landmarks - landmarks (NormalizedLandmarkList) stream, output from the model +// @projection_matrix - matrix that stores the preprocessing information +// @graph - mediapipe graph to update. +Stream ProjectLandmarks( + Stream landmarks, + Stream> projection_matrix, Graph& graph); + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_PROJECTION_H_ diff --git a/mediapipe/framework/api2/stream/landmarks_projection_test.cc b/mediapipe/framework/api2/stream/landmarks_projection_test.cc new file mode 100644 index 000000000..2f743d808 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_projection_test.cc @@ -0,0 +1,45 @@ +#include "mediapipe/framework/api2/stream/landmarks_projection.h" + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(LandmarksProjection, ProjectLandmarks) { + mediapipe::api2::builder::Graph graph; + + Stream landmarks = + graph.In("NORM_LANDMARKS").Cast(); + Stream> projection_matrix = + graph.In("PROJECTION_MATRIX").Cast>(); + Stream result = + ProjectLandmarks(landmarks, projection_matrix, graph); + result.SetName("landmarks_value"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:__stream_0" + input_stream: "PROJECTION_MATRIX:__stream_1" + output_stream: "NORM_LANDMARKS:landmarks_value" + } + input_stream: "NORM_LANDMARKS:__stream_0" + input_stream: "PROJECTION_MATRIX:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/landmarks_to_detection.cc b/mediapipe/framework/api2/stream/landmarks_to_detection.cc new file mode 100644 index 000000000..99e576ba4 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_to_detection.cc @@ -0,0 +1,17 @@ +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::api2::builder { + +Stream ConvertLandmarksToDetection( + Stream landmarks, Graph& graph) { + auto& landmarks_to_detection = + graph.AddNode("LandmarksToDetectionCalculator"); + landmarks.ConnectTo(landmarks_to_detection.In("NORM_LANDMARKS")); + return landmarks_to_detection.Out("DETECTION").Cast(); +} + +} // 
diff --git a/mediapipe/framework/api2/stream/landmarks_to_detection.h b/mediapipe/framework/api2/stream/landmarks_to_detection.h
new file mode 100644
index 000000000..0f0004b16
--- /dev/null
+++ b/mediapipe/framework/api2/stream/landmarks_to_detection.h
@@ -0,0 +1,16 @@
+#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_TO_DETECTION_H_
+#define MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_TO_DETECTION_H_
+
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe::api2::builder {
+
+// Updates @graph to convert @landmarks to a detection.
+Stream<mediapipe::Detection> ConvertLandmarksToDetection(
+    Stream<mediapipe::NormalizedLandmarkList> landmarks, Graph& graph);
+
+}  // namespace mediapipe::api2::builder
+
+#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_TO_DETECTION_H_
diff --git a/mediapipe/framework/api2/stream/landmarks_to_detection_test.cc b/mediapipe/framework/api2/stream/landmarks_to_detection_test.cc
new file mode 100644
index 000000000..8bd545306
--- /dev/null
+++ b/mediapipe/framework/api2/stream/landmarks_to_detection_test.cc
@@ -0,0 +1,35 @@
+#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
+
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe::api2::builder {
+namespace {
+
+TEST(LandmarksToDetection, VerifyConfig) {
+  mediapipe::api2::builder::Graph graph;
+
+  Stream<NormalizedLandmarkList> landmarks =
+      graph.In("LANDMARKS").Cast<NormalizedLandmarkList>();
+  Stream<Detection> detection = ConvertLandmarksToDetection(landmarks, graph);
+  detection.SetName("detection");
+
+  EXPECT_THAT(
+      graph.GetConfig(),
+      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        node {
+          calculator: "LandmarksToDetectionCalculator"
+          input_stream: "NORM_LANDMARKS:__stream_0"
+          output_stream: "DETECTION:detection"
+        }
+        input_stream: "LANDMARKS:__stream_0"
+      )pb")));
+}
+
+}  // namespace
+}  // namespace mediapipe::api2::builder
diff --git a/mediapipe/framework/api2/stream/loopback.h b/mediapipe/framework/api2/stream/loopback.h
new file mode 100644
index 000000000..3ad2f0a2d
--- /dev/null
+++ b/mediapipe/framework/api2/stream/loopback.h
@@ -0,0 +1,55 @@
+#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_LOOPBACK_H_
+#define MEDIAPIPE_FRAMEWORK_API2_STREAM_LOOPBACK_H_
+
+#include <functional>
+#include <utility>
+
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+
+namespace mediapipe::api2::builder {
+
+// Returns a pair of two values:
+// - A stream with loopback data. Such a stream, for each new packet in @tick
+//   stream, provides a packet previously calculated within the graph.
+// - A function to define/set the loopback-data-producing stream.
+// NOTE:
+// * The function must be called once and only once, otherwise graph
+//   validation will fail.
+// * Calling the function after the graph is destroyed results in undefined
+//   behavior.
+//
+// The function wraps `PreviousLoopbackCalculator` into a convenience function
+// and allows graph input to be processed together with some previous output.
+//
+// -------
+//
+// Example:
+//
+// ```
+//
+// Graph graph;
+// Stream<...> tick = ...;  // E.g. main input can serve as a tick.
+// auto [prev_data, set_loopback_fn] = GetLoopbackData<DataT>(tick, graph);
+// ...
+// Stream<DataT> data = ...;
+// set_loopback_fn(data);
+//
+// ```
+template <class DataT, class TickT>
+std::pair<Stream<DataT>, std::function<void(Stream<DataT>)>> GetLoopbackData(
+    Stream<TickT> tick, mediapipe::api2::builder::Graph& graph) {
+  auto& prev = graph.AddNode("PreviousLoopbackCalculator");
+  tick.ConnectTo(prev.In("MAIN"));
+  return {prev.Out("PREV_LOOP").template Cast<DataT>(),
+          [prev_ptr = &prev](Stream<DataT> data) {
+            // TODO: input stream info must be specified, but
+            // builder api doesn't support it at the moment. As a workaround,
+            // input stream info is added by GraphBuilder as a graph building
+            // post processing step.
+            data.ConnectTo(prev_ptr->In("LOOP"));
+          }};
+}
+
+}  // namespace mediapipe::api2::builder
+
+#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_LOOPBACK_H_
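A sketch of GetLoopbackData in use: a running-sum graph where each input packet is added to the sum produced on the previous tick. The RunningSumCalculator name and its VALUE/PREV_SUM/SUM tags are hypothetical; any node producing the loopback stream is wired the same way.

// Sketch only: the calculator below is assumed, not part of this diff; the
// wiring pattern is what GetLoopbackData enables.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/loopback.h"

namespace mediapipe::api2::builder {

CalculatorGraphConfig BuildRunningSumGraph() {
  Graph graph;
  Stream<int> input = graph.In("INT").Cast<int>();

  // For each input packet, `prev_sum` delivers the sum computed on the
  // previous invocation (empty on the very first tick).
  auto [prev_sum, set_loopback_fn] = GetLoopbackData<int>(input, graph);

  // Hypothetical node that emits input + prev_sum.
  auto& adder = graph.AddNode("RunningSumCalculator");
  input.ConnectTo(adder.In("VALUE"));
  prev_sum.ConnectTo(adder.In("PREV_SUM"));
  Stream<int> sum = adder.Out("SUM").Cast<int>();

  // Close the loop: must be called exactly once.
  set_loopback_fn(sum);

  sum.SetName("running_sum");
  return graph.GetConfig();
}

}  // namespace mediapipe::api2::builder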
diff --git a/mediapipe/framework/api2/stream/loopback_test.cc b/mediapipe/framework/api2/stream/loopback_test.cc
new file mode 100644
index 000000000..50c3041e2
--- /dev/null
+++ b/mediapipe/framework/api2/stream/loopback_test.cc
@@ -0,0 +1,56 @@
+#include "mediapipe/framework/api2/stream/loopback.h"
+
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe::api2::builder {
+namespace {
+
+class TestDataProducer : public NodeIntf {
+ public:
+  static constexpr Input<float> kLoopbackData{"LOOPBACK_DATA"};
+  static constexpr Output<float> kProducedData{"PRODUCED_DATA"};
+  MEDIAPIPE_NODE_INTERFACE(TestDataProducer, kLoopbackData, kProducedData);
+};
+
+TEST(LoopbackTest, GetLoopbackData) {
+  Graph graph;
+
+  Stream<int> tick = graph.In("TICK").Cast<int>();
+
+  auto [data, set_loopback_data_fn] = GetLoopbackData<float>(tick, graph);
+
+  auto& producer = graph.AddNode<TestDataProducer>();
+  data.ConnectTo(producer[TestDataProducer::kLoopbackData]);
+  Stream<float> data_to_loopback(producer[TestDataProducer::kProducedData]);
+
+  set_loopback_data_fn(data_to_loopback);
+
+  // PreviousLoopbackCalculator configuration is incorrect here and should be
+  // updated when corresponding b/175887687 is fixed.
+  // Use mediapipe::aimatter::GraphBuilder to fix back edges in the graph.
+  EXPECT_THAT(
+      graph.GetConfig(),
+      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        node {
+          calculator: "PreviousLoopbackCalculator"
+          input_stream: "LOOP:__stream_2"
+          input_stream: "MAIN:__stream_0"
+          output_stream: "PREV_LOOP:__stream_1"
+        }
+        node {
+          calculator: "TestDataProducer"
+          input_stream: "LOOPBACK_DATA:__stream_1"
+          output_stream: "PRODUCED_DATA:__stream_2"
+        }
+        input_stream: "TICK:__stream_0"
+      )pb")));
+}
+
+}  // namespace
+}  // namespace mediapipe::api2::builder
diff --git a/mediapipe/framework/api2/stream/rect_transformation.cc b/mediapipe/framework/api2/stream/rect_transformation.cc
new file mode 100644
index 000000000..3e63375fc
--- /dev/null
+++ b/mediapipe/framework/api2/stream/rect_transformation.cc
@@ -0,0 +1,108 @@
+#include "mediapipe/framework/api2/stream/rect_transformation.h"
+
+#include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+
+namespace mediapipe::api2::builder {
+
+namespace {
+
+using ::mediapipe::NormalizedRect;
+using ::mediapipe::api2::builder::GenericNode;
+using ::mediapipe::api2::builder::Graph;
+
+template <typename TransformeeT>
+Stream<TransformeeT> InternalScaleAndShift(
+    Stream<TransformeeT> transformee, Stream<std::pair<int, int>> image_size,
+    float scale_x_factor, float scale_y_factor, std::optional<float> shift_x,
+    std::optional<float> shift_y, bool square_long, Graph& graph) {
+  auto& node = graph.AddNode("RectTransformationCalculator");
+  auto& node_opts =
+      node.GetOptions<mediapipe::RectTransformationCalculatorOptions>();
+  node_opts.set_scale_x(scale_x_factor);
+  node_opts.set_scale_y(scale_y_factor);
+  if (shift_x) {
+    node_opts.set_shift_x(shift_x.value());
+  }
+  if (shift_y) {
+    node_opts.set_shift_y(shift_y.value());
+  }
+  if (square_long) {
+    node_opts.set_square_long(square_long);
+  }
+  image_size.ConnectTo(node.In("IMAGE_SIZE"));
+  if constexpr (std::is_same_v<TransformeeT, std::vector<NormalizedRect>>) {
+    transformee.ConnectTo(node.In("NORM_RECTS"));
+  } else if constexpr (std::is_same_v<TransformeeT, NormalizedRect>) {
+    transformee.ConnectTo(node.In("NORM_RECT"));
+  } else {
+    static_assert(dependent_false<TransformeeT>::value, "Unsupported type.");
+  }
+  return node.Out("").template Cast<TransformeeT>();
+}
+
+}  // namespace
+
+Stream<NormalizedRect> ScaleAndMakeSquare(
+    Stream<NormalizedRect> rect, Stream<std::pair<int, int>> image_size,
+    float scale_x_factor, float scale_y_factor, Graph& graph) {
+  return InternalScaleAndShift(rect, image_size, scale_x_factor,
+                               scale_y_factor,
+                               /*shift_x=*/std::nullopt,
+                               /*shift_y=*/std::nullopt,
+                               /*square_long=*/true, graph);
+}
+
+Stream<NormalizedRect> Scale(Stream<NormalizedRect> rect,
+                             Stream<std::pair<int, int>> image_size,
+                             float scale_x_factor, float scale_y_factor,
+                             Graph& graph) {
+  return InternalScaleAndShift(rect, image_size, scale_x_factor,
+                               scale_y_factor,
+                               /*shift_x=*/std::nullopt,
+                               /*shift_y=*/std::nullopt,
+                               /*square_long=*/false, graph);
+}
+
+Stream<std::vector<NormalizedRect>> ScaleAndShiftAndMakeSquareLong(
+    Stream<std::vector<NormalizedRect>> rects,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, float shift_x, float shift_y, Graph& graph) {
+  return InternalScaleAndShift(rects, image_size, scale_x_factor,
+                               scale_y_factor, shift_x, shift_y,
+                               /*square_long=*/true, graph);
+}
+
+Stream<std::vector<NormalizedRect>> ScaleAndShift(
+    Stream<std::vector<NormalizedRect>> rects,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, float shift_x, float shift_y, Graph& graph) {
+  return InternalScaleAndShift(rects, image_size, scale_x_factor,
+                               scale_y_factor, shift_x, shift_y,
+                               /*square_long=*/false, graph);
+}
+
+Stream<NormalizedRect> ScaleAndShiftAndMakeSquareLong(
+    Stream<NormalizedRect> rect, Stream<std::pair<int, int>> image_size,
+    float scale_x_factor, float scale_y_factor, float shift_x, float shift_y,
+    Graph& graph) {
+  return InternalScaleAndShift(rect, image_size, scale_x_factor,
+                               scale_y_factor, shift_x, shift_y,
+                               /*square_long=*/true, graph);
+}
+
+Stream<NormalizedRect> ScaleAndShift(Stream<NormalizedRect> rect,
+                                     Stream<std::pair<int, int>> image_size,
+                                     float scale_x_factor,
+                                     float scale_y_factor, float shift_x,
+                                     float shift_y, Graph& graph) {
+  return InternalScaleAndShift(rect, image_size, scale_x_factor,
+                               scale_y_factor, shift_x, shift_y,
+                               /*square_long=*/false, graph);
+}
+
+}  // namespace mediapipe::api2::builder
diff --git a/mediapipe/framework/api2/stream/rect_transformation.h b/mediapipe/framework/api2/stream/rect_transformation.h
new file mode 100644
index 000000000..9f6a98980
--- /dev/null
+++ b/mediapipe/framework/api2/stream/rect_transformation.h
@@ -0,0 +1,67 @@
+#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_
+#define MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_
+
+#include <utility>
+#include <vector>
+
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+
+namespace mediapipe::api2::builder {
+
+// Updates @graph to scale @rect according to passed parameters.
+Stream<NormalizedRect> Scale(Stream<NormalizedRect> rect,
+                             Stream<std::pair<int, int>> image_size,
+                             float scale_x_factor,
+                             float scale_y_factor,
+                             mediapipe::api2::builder::Graph& graph);
+
+// Updates @graph to scale @rect according to passed parameters and make it a
+// square that has the same center and rotation, and with the side of the
+// square equal to the long side of the rect.
+//
+// TODO: consider removing after migrating to `Scale`.
+Stream<NormalizedRect> ScaleAndMakeSquare(
+    Stream<NormalizedRect> rect,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, mediapipe::api2::builder::Graph& graph);
+
+// Updates @graph to scale and shift vector of @rects according to parameters.
+Stream<std::vector<NormalizedRect>> ScaleAndShift(
+    Stream<std::vector<NormalizedRect>> rects,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, float shift_x, float shift_y,
+    mediapipe::api2::builder::Graph& graph);
+
+// Updates @graph to scale and shift vector of @rects according to passed
+// parameters and make each a square that has the same center and rotation, and
+// with the side of the square equal to the long side of a particular rect.
+//
+// TODO: consider removing after migrating to `ScaleAndShift`.
+Stream<std::vector<NormalizedRect>> ScaleAndShiftAndMakeSquareLong(
+    Stream<std::vector<NormalizedRect>> rects,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, float shift_x, float shift_y,
+    mediapipe::api2::builder::Graph& graph);
+
+// Updates @graph to scale and shift @rect according to passed parameters.
+Stream<NormalizedRect> ScaleAndShift(
+    Stream<NormalizedRect> rect,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, float shift_x, float shift_y,
+    mediapipe::api2::builder::Graph& graph);
+
+// Updates @graph to scale and shift @rect according to passed parameters and
+// make it a square that has the same center and rotation, and with the side of
+// the square equal to the long side of the rect.
+//
+// TODO: consider removing after migrating to `ScaleAndShift`.
+Stream<NormalizedRect> ScaleAndShiftAndMakeSquareLong(
+    Stream<NormalizedRect> rect,
+    Stream<std::pair<int, int>> image_size, float scale_x_factor,
+    float scale_y_factor, float shift_x, float shift_y,
+    mediapipe::api2::builder::Graph& graph);
+
+}  // namespace mediapipe::api2::builder
+
+#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_
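A sketch of the typical call site for these helpers, e.g. expanding a landmarks-derived ROI before cropping. The graph input tags, the 1.25 expansion factor, and the wrapping function are illustrative only; the API calls are the ones declared above.

// Sketch only: expands a tracking rect by 25% and squares it, a common
// pattern before image cropping.
#include <utility>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"

namespace mediapipe::api2::builder {

CalculatorGraphConfig BuildRectExpansionGraph() {
  Graph graph;
  Stream<mediapipe::NormalizedRect> rect =
      graph.In("RECT").Cast<mediapipe::NormalizedRect>();
  Stream<std::pair<int, int>> image_size =
      graph.In("SIZE").Cast<std::pair<int, int>>();

  // Scale by 1.25x in both dimensions, then turn the result into a square
  // whose side matches the long side of the scaled rect.
  Stream<mediapipe::NormalizedRect> roi = ScaleAndMakeSquare(
      rect, image_size, /*scale_x_factor=*/1.25f, /*scale_y_factor=*/1.25f,
      graph);
  roi.SetName("roi");
  return graph.GetConfig();
}

}  // namespace mediapipe::api2::builder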
graph.In("RECT").Cast(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream transformed_rect = ScaleAndShiftAndMakeSquareLong( + rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rect.SetName("transformed_rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECT:__stream_0" + output_stream: "transformed_rect" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + square_long: true + } + } + } + input_stream: "RECT:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, ScaleAndShiftMultipleRects) { + mediapipe::api2::builder::Graph graph; + + Stream> rects = + graph.In("RECTS").Cast>(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream> transformed_rects = + ScaleAndShift(rects, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rects.SetName("transformed_rects"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECTS:__stream_0" + output_stream: "transformed_rects" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + } + } + } + input_stream: "RECTS:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, ScaleAndShiftAndMakeSquareLongMultipleRects) { + mediapipe::api2::builder::Graph graph; + + Stream> rects = + graph.In("RECTS").Cast>(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream> transformed_rects = + ScaleAndShiftAndMakeSquareLong(rects, size, /*scale_x_factor=*/2, + /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rects.SetName("transformed_rects"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECTS:__stream_0" + output_stream: "transformed_rects" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + square_long: true + } + } + } + input_stream: "RECTS:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +} // namespace +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h index 19f37f9de..1f4c82160 100644 --- a/mediapipe/framework/calculator_base.h +++ b/mediapipe/framework/calculator_base.h @@ -17,14 +17,16 @@ #ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_ #define MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_ +#include +#include #include #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/deps/registration.h" #include "mediapipe/framework/port.h" -#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/timestamp.h" namespace mediapipe { @@ -150,8 +152,9 @@ class CalculatorBase { // Packets may be output during a call to Close(). However, output packets // are silently discarded if Close() is called after a graph run has ended. 
// - // NOTE: If Close() needs to perform an action only when processing is - // complete, Close() must check if cc->GraphStatus() is OK. + // NOTE: Do not call cc->GraphStatus() in Close() if you need to check if the + // processing is complete. Please, see CalculatorContext::GraphStatus + // documentation for the suggested solution. virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); } // Returns a value according to which the framework selects diff --git a/mediapipe/framework/calculator_base_test.cc b/mediapipe/framework/calculator_base_test.cc index c26006e0f..42c03696c 100644 --- a/mediapipe/framework/calculator_base_test.cc +++ b/mediapipe/framework/calculator_base_test.cc @@ -183,8 +183,7 @@ TEST(CalculatorTest, CreateByNameWhitelisted) { CalculatorBaseRegistry::Register( "::mediapipe::test_ns::whitelisted_ns::DeadCalculator", absl::make_unique>, - __FILE__, __LINE__); + mediapipe::test_ns::whitelisted_ns::DeadCalculator>>); // A whitelisted calculator can be found in its own namespace. MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // diff --git a/mediapipe/framework/calculator_context.cc b/mediapipe/framework/calculator_context.cc index 4452f45e3..25f29222c 100644 --- a/mediapipe/framework/calculator_context.cc +++ b/mediapipe/framework/calculator_context.cc @@ -14,35 +14,37 @@ #include "mediapipe/framework/calculator_context.h" +#include "absl/log/absl_check.h" + namespace mediapipe { const std::string& CalculatorContext::CalculatorType() const { - CHECK(calculator_state_); + ABSL_CHECK(calculator_state_); return calculator_state_->CalculatorType(); } const CalculatorOptions& CalculatorContext::Options() const { - CHECK(calculator_state_); + ABSL_CHECK(calculator_state_); return calculator_state_->Options(); } const std::string& CalculatorContext::NodeName() const { - CHECK(calculator_state_); + ABSL_CHECK(calculator_state_); return calculator_state_->NodeName(); } int CalculatorContext::NodeId() const { - CHECK(calculator_state_); + ABSL_CHECK(calculator_state_); return calculator_state_->NodeId(); } Counter* CalculatorContext::GetCounter(const std::string& name) { - CHECK(calculator_state_); + ABSL_CHECK(calculator_state_); return calculator_state_->GetCounter(name); } CounterFactory* CalculatorContext::GetCounterFactory() { - CHECK(calculator_state_); + ABSL_CHECK(calculator_state_); return calculator_state_->GetCounterFactory(); } diff --git a/mediapipe/framework/calculator_context.h b/mediapipe/framework/calculator_context.h index 284226d92..315d26511 100644 --- a/mediapipe/framework/calculator_context.h +++ b/mediapipe/framework/calculator_context.h @@ -20,6 +20,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/calculator_state.h" #include "mediapipe/framework/counter.h" #include "mediapipe/framework/graph_service.h" @@ -109,9 +110,20 @@ class CalculatorContext { // use OutputStream::SetOffset() directly. void SetOffset(TimestampDiff offset); - // Returns the status of the graph run. + // DEPRECATED: This was intended to get graph run status during + // `CalculatorBase::Close` call. However, `Close` can run simultaneously with + // other calculators `CalculatorBase::Process`, hence the actual graph + // status may change any time and returned graph status here does not + // necessarily reflect the actual graph status. // - // NOTE: This method should only be called during CalculatorBase::Close(). 
+ // As an alternative, instead of checking graph status in `Close` and doing + // work for "done" state, you can enable timestamp bound processing for your + // calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger + // `Process` on timestamp bound updates and handle "done" state there. + // Check examples in: + // mediapipe/framework/calculator_graph_summary_packet_test.cc. + // + ABSL_DEPRECATED("Does not reflect the actual graph status.") absl::Status GraphStatus() const { return graph_status_; } ProfilingContext* GetProfilingContext() const { @@ -136,7 +148,7 @@ class CalculatorContext { } void PopInputTimestamp() { - CHECK(!input_timestamps_.empty()); + ABSL_CHECK(!input_timestamps_.empty()); input_timestamps_.pop(); } diff --git a/mediapipe/framework/calculator_context_manager.cc b/mediapipe/framework/calculator_context_manager.cc index acd70dd94..7da3d2778 100644 --- a/mediapipe/framework/calculator_context_manager.cc +++ b/mediapipe/framework/calculator_context_manager.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" @@ -27,7 +28,7 @@ void CalculatorContextManager::Initialize( std::shared_ptr input_tag_map, std::shared_ptr output_tag_map, bool calculator_run_in_parallel) { - CHECK(calculator_state); + ABSL_CHECK(calculator_state); calculator_state_ = calculator_state; input_tag_map_ = std::move(input_tag_map); output_tag_map_ = std::move(output_tag_map); @@ -51,15 +52,15 @@ void CalculatorContextManager::CleanupAfterRun() { CalculatorContext* CalculatorContextManager::GetDefaultCalculatorContext() const { - CHECK(default_context_.get()); + ABSL_CHECK(default_context_.get()); return default_context_.get(); } CalculatorContext* CalculatorContextManager::GetFrontCalculatorContext( Timestamp* context_input_timestamp) { - CHECK(calculator_run_in_parallel_); + ABSL_CHECK(calculator_run_in_parallel_); absl::MutexLock lock(&contexts_mutex_); - CHECK(!active_contexts_.empty()); + ABSL_CHECK(!active_contexts_.empty()); *context_input_timestamp = active_contexts_.begin()->first; return active_contexts_.begin()->second.get(); } @@ -70,7 +71,7 @@ CalculatorContext* CalculatorContextManager::PrepareCalculatorContext( return GetDefaultCalculatorContext(); } absl::MutexLock lock(&contexts_mutex_); - CHECK(!mediapipe::ContainsKey(active_contexts_, input_timestamp)) + ABSL_CHECK(!mediapipe::ContainsKey(active_contexts_, input_timestamp)) << "Multiple invocations with the same timestamps are not allowed with " "parallel execution, input_timestamp = " << input_timestamp; diff --git a/mediapipe/framework/calculator_context_manager.h b/mediapipe/framework/calculator_context_manager.h index 6b988b03d..ae697e12f 100644 --- a/mediapipe/framework/calculator_context_manager.h +++ b/mediapipe/framework/calculator_context_manager.h @@ -21,6 +21,7 @@ #include #include "absl/base/thread_annotations.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_state.h" @@ -97,18 +98,18 @@ class CalculatorContextManager { void PushInputTimestampToContext(CalculatorContext* calculator_context, Timestamp input_timestamp) { - CHECK(calculator_context); + ABSL_CHECK(calculator_context); calculator_context->PushInputTimestamp(input_timestamp); } void PopInputTimestampFromContext(CalculatorContext* calculator_context) { - CHECK(calculator_context); + 
ABSL_CHECK(calculator_context); calculator_context->PopInputTimestamp(); } void SetGraphStatusInContext(CalculatorContext* calculator_context, const absl::Status& status) { - CHECK(calculator_context); + ABSL_CHECK(calculator_context); calculator_context->SetGraphStatus(status); } diff --git a/mediapipe/framework/calculator_contract.h b/mediapipe/framework/calculator_contract.h index 7726065a7..1bc5c5aed 100644 --- a/mediapipe/framework/calculator_contract.h +++ b/mediapipe/framework/calculator_contract.h @@ -169,12 +169,12 @@ class CalculatorContract { // For services which allow default initialization: // - `CalculatorGraph` will try to create corresponding service object by // default even if request is made optional - // (`GraphServiceRequest::Optional()`) + // (`GraphServiceRequest::Optional()`). // // For services which disallow default initialization: // - `CalculatorGraph` requires client to set corresponding service object and - // otherwise fails, unles request is mad optional - // (`GraphServiceRequest::Optional()`) + // otherwise fails, unless request is made optional + // (`GraphServiceRequest::Optional()`). GraphServiceRequest& UseService(const GraphServiceBase& service) { auto it = service_requests_.emplace(service.key, service).first; return it->second; diff --git a/mediapipe/framework/calculator_framework.h b/mediapipe/framework/calculator_framework.h index afb73fb30..8f193fde8 100644 --- a/mediapipe/framework/calculator_framework.h +++ b/mediapipe/framework/calculator_framework.h @@ -52,6 +52,8 @@ #define MEDIAPIPE_FRAMEWORK_CALCULATOR_FRAMEWORK_H_ #include "mediapipe/framework/calculator_base.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_graph.h" #include "mediapipe/framework/calculator_registry.h" #include "mediapipe/framework/counter_factory.h" diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index b9fc4c965..03c5d2296 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -17,14 +17,17 @@ #include #include +#include +#include #include #include -#include +#include #include #include -#include "absl/container/fixed_array.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -37,9 +40,15 @@ #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/counter_factory.h" #include "mediapipe/framework/delegating_executor.h" +#include "mediapipe/framework/executor.h" +#include "mediapipe/framework/graph_output_stream.h" #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/input_stream_manager.h" #include "mediapipe/framework/mediapipe_profiling.h" +#include "mediapipe/framework/output_side_packet_impl.h" +#include "mediapipe/framework/output_stream_manager.h" +#include "mediapipe/framework/output_stream_poller.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_set.h" @@ -48,14 +57,18 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/ret_check.h" #include 
"mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/scheduler.h" #include "mediapipe/framework/status_handler.h" #include "mediapipe/framework/status_handler.pb.h" #include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/thread_pool_executor.pb.h" +#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/tool/fill_packet_set.h" #include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/tool/tag_map.h" @@ -75,6 +88,11 @@ namespace { constexpr int kMaxNumAccumulatedErrors = 1000; constexpr char kApplicationThreadExecutorType[] = "ApplicationThreadExecutor"; +// Do not log status payloads, but do include stack traces. +constexpr absl::StatusToStringMode kStatusLogFlags = + absl::StatusToStringMode::kWithEverything & + (~absl::StatusToStringMode::kWithPayload); + } // namespace void CalculatorGraph::ScheduleAllOpenableNodes() { @@ -127,10 +145,10 @@ CalculatorGraph::CalculatorGraph(CalculatorGraphConfig config) // they only need to be fully visible here, where their destructor is // instantiated. CalculatorGraph::~CalculatorGraph() { - // Stop periodic profiler output to ublock Executor destructors. + // Stop periodic profiler output to unblock Executor destructors. absl::Status status = profiler()->Stop(); if (!status.ok()) { - LOG(ERROR) << "During graph destruction: " << status; + ABSL_LOG(ERROR) << "During graph destruction: " << status; } } @@ -155,7 +173,7 @@ absl::Status CalculatorGraph::InitializePacketGeneratorGraph( Executor* default_executor = nullptr; if (!use_application_thread_) { default_executor = executors_[""].get(); - CHECK(default_executor); + ABSL_CHECK(default_executor); } // If default_executor is nullptr, then packet_generator_graph_ will create // its own DelegatingExecutor to use the application thread. @@ -174,6 +192,7 @@ absl::Status CalculatorGraph::InitializeStreams() { const EdgeInfo& edge_info = validated_graph_->InputStreamInfos()[index]; MP_RETURN_IF_ERROR(input_stream_managers_[index].Initialize( edge_info.name, edge_info.packet_type, edge_info.back_edge)); + input_stream_to_index_[&input_stream_managers_[index]] = index; } // Create and initialize the output streams. 
@@ -382,6 +401,7 @@ absl::Status CalculatorGraph::InitializeDefaultExecutor( "", std::make_shared( std::bind(&internal::Scheduler::AddApplicationThreadTask, &scheduler_, std::placeholders::_1)))); + VLOG(1) << "Using default executor and application thread."; return absl::OkStatus(); } @@ -401,6 +421,8 @@ absl::Status CalculatorGraph::InitializeDefaultExecutor( } MP_RETURN_IF_ERROR( CreateDefaultThreadPool(default_executor_options, num_threads)); + VLOG(1) << absl::StrCat("Using default executor with num_threads: ", + num_threads); return absl::OkStatus(); } @@ -579,7 +601,7 @@ absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket( if (legacy_sp.IsEmpty()) return absl::OkStatus(); auto gpu_resources = service_manager_.GetServiceObject(kGpuService); if (gpu_resources) { - LOG(WARNING) + ABSL_LOG(WARNING) << "::mediapipe::GpuSharedData provided as a side packet while the " << "graph already had one; ignoring side packet"; return absl::OkStatus(); @@ -707,7 +729,7 @@ absl::Status CalculatorGraph::PrepareForRun( absl::Status error_status; if (has_error_) { GetCombinedErrors(&error_status); - LOG(ERROR) << error_status; + ABSL_LOG(ERROR) << error_status.ToString(kStatusLogFlags); return error_status; } @@ -786,7 +808,7 @@ absl::Status CalculatorGraph::PrepareForRun( } if (GetCombinedErrors(&error_status)) { - LOG(ERROR) << error_status; + ABSL_LOG(ERROR) << error_status.ToString(kStatusLogFlags); CleanupAfterRun(&error_status); return error_status; } @@ -839,11 +861,18 @@ absl::Status CalculatorGraph::PrepareForRun( } absl::Status CalculatorGraph::WaitUntilIdle() { + if (has_sources_) { + ABSL_LOG_FIRST_N(WARNING, 1) + << "WaitUntilIdle called on a graph with source nodes, which " + "is not fully supported at the moment. Source nodes: " + << ListSourceNodes(); + } + MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle()); VLOG(2) << "Scheduler idle."; absl::Status status = absl::OkStatus(); if (GetCombinedErrors(&status)) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status.ToString(kStatusLogFlags); } return status; } @@ -897,7 +926,7 @@ absl::Status CalculatorGraph::AddPacketToInputStreamInternal( "graph input stream.", stream_name); int node_id = mediapipe::FindOrDie(graph_input_stream_node_ids_, stream_name); - CHECK_GE(node_id, validated_graph_->CalculatorInfos().size()); + ABSL_CHECK_GE(node_id, validated_graph_->CalculatorInfos().size()); { absl::MutexLock lock(&full_input_streams_mutex_); if (full_input_streams_.empty()) { @@ -1036,17 +1065,17 @@ void CalculatorGraph::RecordError(const absl::Status& error) { } if (errors_.size() > kMaxNumAccumulatedErrors) { for (const absl::Status& error : errors_) { - LOG(ERROR) << error; + ABSL_LOG(ERROR) << error; } - LOG(FATAL) << "Forcefully aborting to prevent the framework running out " - "of memory."; + ABSL_LOG(FATAL) + << "Forcefully aborting to prevent the framework running out " + "of memory."; } } } bool CalculatorGraph::GetCombinedErrors(absl::Status* error_status) { - return GetCombinedErrors("CalculatorGraph::Run() failed in Run: ", - error_status); + return GetCombinedErrors("CalculatorGraph::Run() failed: ", error_status); } bool CalculatorGraph::GetCombinedErrors(const std::string& error_prefix, @@ -1085,7 +1114,8 @@ void CalculatorGraph::CallStatusHandlers(GraphRunState graph_run_state, absl::StatusOr> static_access_statusor = internal::StaticAccessToStatusHandlerRegistry:: CreateByNameInNamespace(validated_graph_->Package(), handler_type); - CHECK(static_access_statusor.ok()) << handler_type << " is not registered."; + 
ABSL_CHECK(static_access_statusor.ok()) + << handler_type << " is not registered."; auto static_access = std::move(static_access_statusor).value(); absl::Status handler_result; if (graph_run_state == GraphRunState::PRE_RUN) { @@ -1126,7 +1156,7 @@ void CalculatorGraph::UpdateThrottledNodes(InputStreamManager* stream, upstream_nodes = &validated_graph_->CalculatorInfos()[node_index].AncestorSources(); } - CHECK(upstream_nodes); + ABSL_CHECK(upstream_nodes); std::vector nodes_to_schedule; { @@ -1148,10 +1178,10 @@ void CalculatorGraph::UpdateThrottledNodes(InputStreamManager* stream, .set_stream_id(&stream->Name())); bool was_throttled = !full_input_streams_[node_id].empty(); if (stream_is_full) { - DCHECK_EQ(full_input_streams_[node_id].count(stream), 0); + ABSL_DCHECK_EQ(full_input_streams_[node_id].count(stream), 0); full_input_streams_[node_id].insert(stream); } else { - DCHECK_EQ(full_input_streams_[node_id].count(stream), 1); + ABSL_DCHECK_EQ(full_input_streams_[node_id].count(stream), 1); full_input_streams_[node_id].erase(stream); } @@ -1208,7 +1238,7 @@ bool CalculatorGraph::UnthrottleSources() { // NOTE: We can be sure that this function will grow input streams enough // to unthrottle at least one source node. The current stream queue sizes // will remain unchanged until at least one source node becomes unthrottled. - // This is a sufficient because succesfully growing at least one full input + // This is a sufficient because successfully growing at least one full input // stream during each call to UnthrottleSources will eventually resolve // each deadlock. absl::flat_hash_set full_streams; @@ -1228,7 +1258,8 @@ bool CalculatorGraph::UnthrottleSources() { for (InputStreamManager* stream : full_streams) { if (Config().report_deadlock()) { RecordError(absl::UnavailableError(absl::StrCat( - "Detected a deadlock due to input throttling for: \"", stream->Name(), + "Detected a deadlock due to input throttling for input stream: \"", + stream->Name(), "\" of a node \"", GetParentNodeDebugName(stream), "\". All calculators are idle while packet sources remain active " "and throttled. Consider adjusting \"max_queue_size\" or " "\"report_deadlock\"."))); @@ -1236,10 +1267,11 @@ bool CalculatorGraph::UnthrottleSources() { } int new_size = stream->QueueSize() + 1; stream->SetMaxQueueSize(new_size); - LOG_EVERY_N(WARNING, 100) - << "Resolved a deadlock by increasing max_queue_size of input stream: " - << stream->Name() << " to: " << new_size - << ". Consider increasing max_queue_size for better performance."; + ABSL_LOG_EVERY_N(WARNING, 100) << absl::StrCat( + "Resolved a deadlock by increasing max_queue_size of input stream: \"", + stream->Name(), "\" of a node \"", GetParentNodeDebugName(stream), + "\" to ", new_size, + ". Consider increasing max_queue_size for better performance."); } return !full_streams.empty(); } @@ -1333,7 +1365,7 @@ void CalculatorGraph::CleanupAfterRun(absl::Status* status) { // Obtain the combined status again, so that it includes the new errors // added by CallStatusHandlers. 
GetCombinedErrors(status); - CHECK(!status->ok()); + ABSL_CHECK(!status->ok()); } else { MEDIAPIPE_CHECK_OK(*status); } @@ -1368,6 +1400,37 @@ const OutputStreamManager* CalculatorGraph::FindOutputStreamManager( .get()[validated_graph_->OutputStreamIndex(name)]; } +std::string CalculatorGraph::ListSourceNodes() const { + std::vector sources; + for (auto& node : nodes_) { + if (node->IsSource()) { + sources.push_back(node->DebugName()); + } + } + return absl::StrJoin(sources, ", "); +} + +std::string CalculatorGraph::GetParentNodeDebugName( + InputStreamManager* stream) const { + auto iter = input_stream_to_index_.find(stream); + if (iter == input_stream_to_index_.end()) { + return absl::StrCat("Unknown (node with input stream: ", stream->Name(), + ")"); + } + + const int input_stream_index = iter->second; + const EdgeInfo& edge_info = + validated_graph_->InputStreamInfos()[input_stream_index]; + const int node_index = edge_info.parent_node.index; + const CalculatorGraphConfig& config = validated_graph_->Config(); + if (node_index < 0 || node_index >= config.node_size()) { + return absl::StrCat("Unknown (node index: ", node_index, + ", with input stream: ", stream->Name(), ")"); + } + + return DebugName(config.node(node_index)); +} + namespace { void PrintTimingToInfo(const std::string& label, int64_t timer_value) { const int64_t total_seconds = timer_value / 1000000ll; @@ -1376,12 +1439,13 @@ void PrintTimingToInfo(const std::string& label, int64_t timer_value) { const int64_t minutes = (total_seconds / 60ll) % 60ll; const int64_t seconds = total_seconds % 60ll; const int64_t milliseconds = (timer_value / 1000ll) % 1000ll; - LOG(INFO) << label << " took " - << absl::StrFormat( - "%02lld days, %02lld:%02lld:%02lld.%03lld (total seconds: " - "%lld.%06lld)", - days, hours, minutes, seconds, milliseconds, total_seconds, - timer_value % int64_t{1000000}); + ABSL_LOG(INFO) + << label << " took " + << absl::StrFormat( + "%02lld days, %02lld:%02lld:%02lld.%03lld (total seconds: " + "%lld.%06lld)", + days, hours, minutes, seconds, milliseconds, total_seconds, + timer_value % int64_t{1000000}); } bool MetricElementComparator(const std::pair& e1, diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 354694e39..4284beb7c 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -26,10 +26,12 @@ #include #include -#include "absl/base/macros.h" -#include "absl/container/fixed_array.h" +#include "absl/base/attributes.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_base.h" @@ -41,18 +43,17 @@ #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/output_side_packet_impl.h" -#include "mediapipe/framework/output_stream.h" #include "mediapipe/framework/output_stream_manager.h" #include "mediapipe/framework/output_stream_poller.h" #include "mediapipe/framework/output_stream_shard.h" #include "mediapipe/framework/packet.h" -#include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_generator_graph.h" -#include "mediapipe/framework/port.h" -#include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/status.h" 
#include "mediapipe/framework/scheduler.h" +#include "mediapipe/framework/scheduler_shared.h" +#include "mediapipe/framework/subgraph.h" #include "mediapipe/framework/thread_pool_executor.pb.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/framework/validated_graph_config.h" namespace mediapipe { @@ -229,8 +230,11 @@ class CalculatorGraph { // Wait until the running graph is in the idle mode, which is when nothing can // be scheduled and nothing is running in the worker threads. This function // can be called only after StartRun(). + // // NOTE: The graph must not have any source nodes because source nodes prevent // the running graph from becoming idle until the source nodes are done. + // Currently, `WaitUntilIdle` cannot be used reliably on graphs with any + // source nodes. absl::Status WaitUntilIdle(); // Wait until a packet is emitted on one of the observed output streams. @@ -594,6 +598,12 @@ class CalculatorGraph { // status before taking any action. void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full); + // Returns a comma-separated list of source nodes. + std::string ListSourceNodes() const; + + // Returns a parent node name for the given input stream. + std::string GetParentNodeDebugName(InputStreamManager* stream) const; + #if !MEDIAPIPE_DISABLE_GPU // Owns the legacy GpuSharedData if we need to create one for backwards // compatibility. @@ -649,6 +659,9 @@ class CalculatorGraph { std::vector> full_input_streams_ ABSL_GUARDED_BY(full_input_streams_mutex_); + // Input stream to index within `input_stream_managers_` mapping. + absl::flat_hash_map input_stream_to_index_; + // Maps stream names to graph input stream objects. absl::flat_hash_map> graph_input_streams_; diff --git a/mediapipe/framework/calculator_graph_side_packet_test.cc b/mediapipe/framework/calculator_graph_side_packet_test.cc index a9567c805..6f42f585e 100644 --- a/mediapipe/framework/calculator_graph_side_packet_test.cc +++ b/mediapipe/framework/calculator_graph_side_packet_test.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "mediapipe/framework/calculator.pb.h" @@ -24,7 +25,6 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -128,7 +128,7 @@ class IntegerOutputSidePacketCalculator : public CalculatorBase { } absl::Status Process(CalculatorContext* cc) final { - LOG(FATAL) << "Not reached."; + ABSL_LOG(FATAL) << "Not reached."; return absl::OkStatus(); } }; @@ -153,7 +153,7 @@ class SidePacketAdderCalculator : public CalculatorBase { } absl::Status Process(CalculatorContext* cc) final { - LOG(FATAL) << "Not reached."; + ABSL_LOG(FATAL) << "Not reached."; return absl::OkStatus(); } }; @@ -778,7 +778,7 @@ class OutputSidePacketCachedCalculator : public CalculatorBase { } absl::Status Process(CalculatorContext* cc) final { - LOG(FATAL) << "Not reached."; + ABSL_LOG(FATAL) << "Not reached."; return absl::OkStatus(); } }; diff --git a/mediapipe/framework/calculator_graph_summary_packet_test.cc b/mediapipe/framework/calculator_graph_summary_packet_test.cc new file mode 100644 index 000000000..e6a04e060 --- /dev/null +++ b/mediapipe/framework/calculator_graph_summary_packet_test.cc @@ -0,0 +1,430 @@ 
+#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Node; +using ::mediapipe::api2::Output; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Value; + +namespace { + +MATCHER_P2(IntPacket, value, timestamp, "") { + *result_listener << "where object is (value: " << arg.template Get() + << ", timestamp: " << arg.Timestamp() << ")"; + return Value(arg.template Get(), Eq(value)) && + Value(arg.Timestamp(), Eq(timestamp)); +} + +// Calculates and produces sum of all passed inputs when no more packets can be +// expected on the input stream. +class SummaryPacketCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"SUMMARY"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + // Makes sure there are no automatic timestamp bound updates when Process + // is called. + cc->SetTimestampOffset(TimestampDiff::Unset()); + // Currently, only ImmediateInputStreamHandler supports "done" timestamp + // bound update. (ImmediateInputStreamhandler handles multiple input + // streams differently, so, in that case, calculator adjustments may be + // required.) + // TODO: update all input stream handlers to support "done" + // timestamp bound update. + cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + // Enables processing timestamp bound updates. For this use case we are + // specifically interested in "done" timestamp bound update. (E.g. when + // all input packet sources are closed.) + cc->SetProcessTimestampBounds(true); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + if (!kIn(cc).IsEmpty()) { + value_ += kIn(cc).Get(); + value_set_ = true; + } + + if (kOut(cc).IsClosed()) { + // This can happen: + // 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g. + // source calculator finished generating packets sent to kIn) and + // HasNextAllowedInStream() == true (which is an often case). + // 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still + // invoke Process() with Timestamp::Max to indicate "Done" timestamp + // bound update. + return absl::OkStatus(); + } + + // TODO: input stream holding a packet with timestamp that has + // no next timestamp allowed in stream should always result in + // InputStream::IsDone() == true. + if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) { + // `Process` may or may not be invoked for "done" timestamp bound when + // upstream calculator fails in `Close`. Hence, extra care is needed to + // identify whether the calculator needs to send output. + // TODO: remove when "done" timestamp bound flakiness fixed. + if (value_set_) { + // kOut(cc).Send(value_) can be used here as well, however in the case + // of source calculator sending inputs into kIn the resulting timestamp + // is not well defined (e.g. 
it can be the last packet timestamp or + // Timestamp::Max()) + // TODO: last packet from source should always result in + // InputStream::IsDone() == true. + kOut(cc).Send(value_, Timestamp::Max()); + } + kOut(cc).Close(); + } + return absl::OkStatus(); + } + + private: + int value_ = 0; + bool value_set_ = false; +}; +MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnClosingAllPacketSources) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp(10)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + send_packet(20, Timestamp(11)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); +} + +TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp(10)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + send_packet(20, Timestamp::Max()); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnPreStreamTimestamp) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp::PreStream()); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, 
ElementsAre(IntPacket(10, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnPostStreamTimestamp) { + std::vector output_packets; + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp::PostStream()); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +class IntGeneratorCalculator : public Node { + public: + static constexpr Output kOut{"INT"}; + + MEDIAPIPE_NODE_CONTRACT(kOut); + + absl::Status Process(CalculatorContext* cc) final { + kOut(cc).Send(20, Timestamp(0)); + kOut(cc).Send(10, Timestamp(1000)); + return tool::StatusStop(); + } +}; +MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnSourceCalculatorCompletion) { + std::vector output_packets; + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"pb( + node { + calculator: "IntGeneratorCalculator" + output_stream: "INT:int_value" + } + node { + calculator: "SummaryPacketCalculator" + input_stream: "IN:int_value" + output_stream: "SUMMARY:output" + } + )pb"); + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_EXPECT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); +} + +class EmitOnCloseCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"INT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } + + absl::Status Close(CalculatorContext* cc) final { + kOut(cc).Send(20, Timestamp(0)); + kOut(cc).Send(10, Timestamp(1000)); + return absl::OkStatus(); + } +}; +MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnAnotherCalculatorClosure) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "EmitOnCloseCalculator" + input_stream: "IN:input" + output_stream: "INT:int_value" + } + node { + calculator: "SummaryPacketCalculator" + input_stream: "IN:int_value" + output_stream: "SUMMARY:output" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + 
MP_ASSERT_OK(graph.CloseInputStream("input")); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +class FailureInCloseCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"INT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } + + absl::Status Close(CalculatorContext* cc) final { + return absl::InternalError("error"); + } +}; +MEDIAPIPE_REGISTER_NODE(FailureInCloseCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInClose) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "FailureInCloseCalculator" + input_stream: "IN:input" + output_stream: "INT:int_value" + } + node { + calculator: "SummaryPacketCalculator" + input_stream: "IN:int_value" + output_stream: "SUMMARY:output" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + MP_ASSERT_OK(graph.CloseInputStream("input")); + EXPECT_THAT(graph.WaitUntilIdle(), + StatusIs(absl::StatusCode::kInternal, HasSubstr("error"))); + EXPECT_THAT(output_packets, IsEmpty()); +} + +class FailureInProcessCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"INT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Process(CalculatorContext* cc) final { + return absl::InternalError("error"); + } +}; +MEDIAPIPE_REGISTER_NODE(FailureInProcessCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInProcess) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "FailureInProcessCalculator" + input_stream: "IN:input" + output_stream: "INT:int_value" + } + node { + calculator: "SummaryPacketCalculator" + input_stream: "IN:int_value" + output_stream: "SUMMARY:output" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp::PostStream()); + EXPECT_THAT(graph.WaitUntilIdle(), + StatusIs(absl::StatusCode::kInternal, HasSubstr("error"))); + EXPECT_THAT(output_packets, IsEmpty()); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 2e7d99ef6..91bf72e31 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -17,16 +17,22 @@ #include #include +#include #include #include +#include #include #include +#include #include #include #include #include "absl/container/fixed_array.h" +#include "absl/log/absl_check.h" +#include 
"absl/log/absl_log.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -47,7 +53,6 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -725,13 +730,13 @@ class SlowCountingSinkCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) override { absl::SleepFor(absl::Milliseconds(10)); int value = cc->Inputs().Index(0).Get(); - CHECK_EQ(value, counter_); + ABSL_CHECK_EQ(value, counter_); ++counter_; return absl::OkStatus(); } absl::Status Close(CalculatorContext* cc) override { - CHECK_EQ(10, counter_); + ABSL_CHECK_EQ(10, counter_); return absl::OkStatus(); } @@ -1014,7 +1019,7 @@ class CheckInputTimestampSourceCalculator : public CalculatorBase { absl::Status Close(CalculatorContext* cc) final { // Must use CHECK instead of RET_CHECK in Close(), because the framework // may call the Close() method of a source node with .IgnoreError(). - CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); + ABSL_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); return absl::OkStatus(); } @@ -1092,7 +1097,7 @@ class CheckInputTimestamp2SourceCalculator : public CalculatorBase { absl::Status Close(CalculatorContext* cc) final { // Must use CHECK instead of RET_CHECK in Close(), because the framework // may call the Close() method of a source node with .IgnoreError(). - CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); + ABSL_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); return absl::OkStatus(); } @@ -1242,8 +1247,8 @@ REGISTER_STATUS_HANDLER(IncrementingStatusHandler); class CurrentThreadExecutor : public Executor { public: ~CurrentThreadExecutor() override { - CHECK(!executing_); - CHECK(tasks_.empty()); + ABSL_CHECK(!executing_); + ABSL_CHECK(tasks_.empty()); } void Schedule(std::function task) override { @@ -1254,7 +1259,7 @@ class CurrentThreadExecutor : public Executor { // running) to avoid an indefinitely-deep call stack. tasks_.emplace_back(std::move(task)); } else { - CHECK(tasks_.empty()); + ABSL_CHECK(tasks_.empty()); executing_ = true; task(); while (!tasks_.empty()) { @@ -1406,7 +1411,7 @@ void RunComprehensiveTest(CalculatorGraph* graph, // Call graph->Run() several times, to make sure that the appropriate // cleanup happens between iterations. for (int iteration = 0; iteration < 2; ++iteration) { - LOG(INFO) << "Loop iteration " << iteration; + ABSL_LOG(INFO) << "Loop iteration " << iteration; dumped_final_sum_packet = Packet(); dumped_final_stddev_packet = Packet(); dumped_final_packet = Packet(); @@ -1448,7 +1453,7 @@ void RunComprehensiveTest(CalculatorGraph* graph, ->GetCounter("copy_range5-PassThrough") ->Get()); } - LOG(INFO) << "After Loop Runs."; + ABSL_LOG(INFO) << "After Loop Runs."; // Verify that the graph can still run (but not successfully) when // one of the nodes is caused to fail. 
extra_side_packets.clear(); @@ -1459,9 +1464,9 @@ void RunComprehensiveTest(CalculatorGraph* graph, dumped_final_sum_packet = Packet(); dumped_final_stddev_packet = Packet(); dumped_final_packet = Packet(); - LOG(INFO) << "Expect an error to be logged here."; + ABSL_LOG(INFO) << "Expect an error to be logged here."; ASSERT_FALSE(graph->Run(extra_side_packets).ok()); - LOG(INFO) << "Error should have been logged."; + ABSL_LOG(INFO) << "Error should have been logged."; } TEST(CalculatorGraph, BadInitialization) { @@ -2549,6 +2554,129 @@ TEST(CalculatorGraph, OutputPacketInOpen2) { EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); } +TEST(CalculatorGraph, DeadlockIsReportedAndSufficientInfoProvided) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + report_deadlock: true + max_queue_size: 1 + input_stream: 'input1' + input_stream: 'input2' + node { + calculator: 'PassThroughCalculator' + input_stream: 'input1' + input_stream: 'input2' + output_stream: 'output1' + output_stream: 'output2' + } + )pb"); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + + Packet packet = MakePacket<int>(1); + MP_EXPECT_OK(graph.AddPacketToInputStream("input1", packet.At(Timestamp(0)))); + absl::Status status = + graph.AddPacketToInputStream("input1", packet.At(Timestamp(1))); + + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("deadlock"), + testing::HasSubstr("input1"), + testing::HasSubstr("PassThroughCalculator"))); + graph.Cancel(); +} + +TEST(CalculatorGraph, + DeadlockIsReportedAndSufficientInfoProvidedMultipleCalculators) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + report_deadlock: true + max_queue_size: 1 + input_stream: 'input1' + input_stream: 'input2' + node { + calculator: 'PassThroughCalculator' + input_stream: 'input1' + input_stream: 'input2' + output_stream: 'output1' + output_stream: 'output2' + } + node { + calculator: 'MergeCalculator' + input_stream: 'output1' + input_stream: 'output2' + output_stream: 'output3' + } + )pb"); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + + Packet packet = MakePacket<int>(1); + MP_EXPECT_OK(graph.AddPacketToInputStream("input1", packet.At(Timestamp(0)))); + absl::Status status = + graph.AddPacketToInputStream("input1", packet.At(Timestamp(1))); + + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("deadlock"), + testing::HasSubstr("input1"), + testing::HasSubstr("PassThroughCalculator"))); + graph.Cancel(); +} + +TEST(CalculatorGraph, TwoDeadlocksAreReportedAndSufficientInfoProvided) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + report_deadlock: true + max_queue_size: 1 + input_stream: 'input1' + input_stream: 'input2' + node { + calculator: 'PassThroughCalculator' + input_stream: 'input1' + input_stream: 'input2' + output_stream: 'output1' + output_stream: 'output2' + } + node { + calculator: 'PassThroughCalculator' + input_stream: 'output1' + input_stream: 'output2' + output_stream: 'output3' + output_stream: 'output4' + } + node { + calculator: 'MergeCalculator' + input_stream: 'input1' + input_stream: 'output1' + input_stream: 'output2' + input_stream: 'output3' + input_stream: 'output4' + output_stream: 'output5' + } + )pb"); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); +
MP_ASSERT_OK(graph.StartRun({})); + + Packet packet = MakePacket<int>(1); + MP_EXPECT_OK(graph.AddPacketToInputStream("input1", packet.At(Timestamp(0)))); + absl::Status status = + graph.AddPacketToInputStream("input1", packet.At(Timestamp(1))); + + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); + EXPECT_THAT(status.message(), + testing::AllOf(testing::HasSubstr("deadlock"), + testing::HasSubstr("input1"), + testing::HasSubstr("PassThroughCalculator"), + testing::HasSubstr("MergeCalculator"))); + graph.Cancel(); +} + // Tests that no packets are available on input streams in Open(), even if the // upstream calculator outputs a packet in Open(). TEST(CalculatorGraph, EmptyInputInOpen) { @@ -2619,7 +2747,7 @@ TEST(CalculatorGraph, UnthrottleRespectsLayers) { std::map<std::string, Packet> input_side_packets; input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); // TODO: Set this value to true. When the calculator outputs a - // packet in Open, it will trigget b/33568859, and the test will fail. Use + // packet in Open, it will trigger b/33568859, and the test will fail. Use // this test to verify that b/33568859 is fixed. constexpr bool kOutputInOpen = true; input_side_packets["output_in_open"] = MakePacket<bool>(kOutputInOpen); @@ -3339,7 +3467,7 @@ TEST(CalculatorGraph, SetInputStreamMaxQueueSizeWorksSlowCalculator) { // Verify the scheduler unthrottles the graph input stream to avoid a deadlock, // and won't enter a busy loop. TEST(CalculatorGraph, AddPacketNoBusyLoop) { - // The DecimatorCalculator ouputs 1 out of every 101 input packets and drops + // The DecimatorCalculator outputs 1 out of every 101 input packets and drops // the rest, without setting the next timestamp bound on its output. As a // result, the MergeCalculator is not runnable in between and packets on its // "in" input stream will be queued and exceed the max queue size.
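For readers following the deadlock tests above, here is a minimal sketch of how calling code might consume the kUnavailable status these tests assert on. It is not part of this change: the helper name is hypothetical, and it assumes only the CalculatorGraph API exercised by the tests (report_deadlock set in the graph config, AddPacketToInputStream returning the deadlock report).

#include <string>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"

// Hypothetical helper: forwards a packet and surfaces the deadlock report
// attached by the framework when report_deadlock is enabled.
absl::Status AddPacketReportingDeadlock(mediapipe::CalculatorGraph& graph,
                                        const std::string& stream_name,
                                        const mediapipe::Packet& packet) {
  absl::Status status = graph.AddPacketToInputStream(stream_name, packet);
  if (absl::IsUnavailable(status)) {
    // The message names the throttled streams and calculators, as the
    // HasSubstr expectations in the tests above verify.
    ABSL_LOG(WARNING) << "Deadlock reported: " << status.message();
  }
  return status;
}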
@@ -3467,7 +3595,7 @@ REGISTER_CALCULATOR(::mediapipe::nested_ns::ProcessCallbackCalculator); TEST(CalculatorGraph, CalculatorInNamepsace) { CalculatorGraphConfig config; - CHECK(proto_ns::TextFormat::ParseFromString(R"( + ABSL_CHECK(proto_ns::TextFormat::ParseFromString(R"( input_stream: 'in_a' node { calculator: 'mediapipe.nested_ns.ProcessCallbackCalculator' @@ -3476,7 +3604,7 @@ TEST(CalculatorGraph, CalculatorInNamepsace) { input_side_packet: 'callback_1' } )", - &config)); + &config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); nested_ns::ProcessFunction callback_1; diff --git a/mediapipe/framework/calculator_node.cc b/mediapipe/framework/calculator_node.cc index f6a1c7dbf..c0aff3b13 100644 --- a/mediapipe/framework/calculator_node.cc +++ b/mediapipe/framework/calculator_node.cc @@ -19,6 +19,8 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -59,7 +61,7 @@ const PacketType* GetPacketType(const PacketTypeSet& packet_type_set, } else { id = packet_type_set.GetId(tag, 0); } - CHECK(id.IsValid()) << "Internal mediapipe error."; + ABSL_CHECK(id.IsValid()) << "Internal mediapipe error."; return &packet_type_set.Get(id); } @@ -341,7 +343,7 @@ absl::Status CalculatorNode::ConnectShardsToStreams( void CalculatorNode::SetExecutor(const std::string& executor) { absl::MutexLock status_lock(&status_mutex_); - CHECK_LT(status_, kStateOpened); + ABSL_CHECK_LT(status_, kStateOpened); executor_ = executor; } @@ -366,7 +368,7 @@ bool CalculatorNode::Closed() const { } void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) { - CHECK(input_stream_handler_); + ABSL_CHECK(input_stream_handler_); input_stream_handler_->SetMaxQueueSize(max_queue_size); } @@ -506,7 +508,7 @@ absl::Status CalculatorNode::OpenNode() { Timestamp(0)); } - LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute( + ABSL_LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute( "Open() on node \"$0\" returned tool::StatusStop() which should only be " "used to signal that a source node is done producing data.", DebugName()); @@ -519,7 +521,7 @@ absl::Status CalculatorNode::OpenNode() { offset_enabled = offset_enabled || stream->Spec()->offset_enabled; } if (offset_enabled && input_stream_handler_->SyncSetCount() > 1) { - LOG(WARNING) << absl::Substitute( + ABSL_LOG(WARNING) << absl::Substitute( "Calculator node \"$0\" is configured with multiple input sync-sets " "and an output timestamp-offset, which will often conflict due to " "the order of packet arrival. 
With multiple input sync-sets, use " @@ -539,7 +541,7 @@ absl::Status CalculatorNode::OpenNode() { void CalculatorNode::ActivateNode() { absl::MutexLock status_lock(&status_mutex_); - CHECK_EQ(status_, kStateOpened) << DebugName(); + ABSL_CHECK_EQ(status_, kStateOpened) << DebugName(); status_ = kStateActive; } @@ -601,7 +603,7 @@ absl::Status CalculatorNode::CloseNode(const absl::Status& graph_status, } needs_to_close_ = false; - LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute( + ABSL_LOG_IF(FATAL, result == tool::StatusStop()) << absl::Substitute( "Close() on node \"$0\" returned tool::StatusStop() which should only be " "used to signal that a source node is done producing data.", DebugName()); @@ -694,8 +696,8 @@ void CalculatorNode::InputStreamHeadersReady() { bool ready_for_open = false; { absl::MutexLock lock(&status_mutex_); - CHECK_EQ(status_, kStatePrepared) << DebugName(); - CHECK(!input_stream_headers_ready_called_); + ABSL_CHECK_EQ(status_, kStatePrepared) << DebugName(); + ABSL_CHECK(!input_stream_headers_ready_called_); input_stream_headers_ready_called_ = true; input_stream_headers_ready_ = true; ready_for_open = input_side_packets_ready_; @@ -709,8 +711,8 @@ void CalculatorNode::InputSidePacketsReady() { bool ready_for_open = false; { absl::MutexLock lock(&status_mutex_); - CHECK_EQ(status_, kStatePrepared) << DebugName(); - CHECK(!input_side_packets_ready_called_); + ABSL_CHECK_EQ(status_, kStatePrepared) << DebugName(); + ABSL_CHECK(!input_side_packets_ready_called_); input_side_packets_ready_called_ = true; input_side_packets_ready_ = true; ready_for_open = input_stream_headers_ready_; @@ -760,7 +762,7 @@ void CalculatorNode::EndScheduling() { return; } --current_in_flight_; - CHECK_GE(current_in_flight_, 0); + ABSL_CHECK_GE(current_in_flight_, 0); if (scheduling_state_ == kScheduling) { // Changes the state to scheduling pending if another thread is doing the @@ -790,7 +792,7 @@ std::string CalculatorNode::DebugInputStreamNames() const { } std::string CalculatorNode::DebugName() const { - DCHECK(calculator_state_); + ABSL_DCHECK(calculator_state_); return calculator_state_->NodeName(); } @@ -893,9 +895,9 @@ absl::Status CalculatorNode::ProcessNode( // open input streams for Process(). So this node needs to be closed // too. // If the streams are closed, there shouldn't be more input. 
- CHECK_EQ(calculator_context_manager_.NumberOfContextTimestamps( - *calculator_context), - 1); + ABSL_CHECK_EQ(calculator_context_manager_.NumberOfContextTimestamps( + *calculator_context), + 1); return CloseNode(absl::OkStatus(), /*graph_run_ended=*/false); } else { RET_CHECK_FAIL() @@ -910,7 +912,7 @@ absl::Status CalculatorNode::ProcessNode( void CalculatorNode::SetQueueSizeCallbacks( InputStreamManager::QueueSizeCallback becomes_full_callback, InputStreamManager::QueueSizeCallback becomes_not_full_callback) { - CHECK(input_stream_handler_); + ABSL_CHECK(input_stream_handler_); input_stream_handler_->SetQueueSizeCallbacks( std::move(becomes_full_callback), std::move(becomes_not_full_callback)); } diff --git a/mediapipe/framework/calculator_node_test.cc b/mediapipe/framework/calculator_node_test.cc index 1c62a7141..deac61f13 100644 --- a/mediapipe/framework/calculator_node_test.cc +++ b/mediapipe/framework/calculator_node_test.cc @@ -18,11 +18,12 @@ #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_macros.h" @@ -95,7 +96,8 @@ int CountCalculator::num_destroyed_ = 0; void SourceNodeOpenedNoOp() {} void CheckFail(const absl::Status& status) { - LOG(FATAL) << "The test triggered the error callback with status: " << status; + ABSL_LOG(FATAL) << "The test triggered the error callback with status: " + << status; } class CalculatorNodeTest : public ::testing::Test { @@ -103,7 +105,7 @@ class CalculatorNodeTest : public ::testing::Test { void ReadyForOpen(int* count) { ++(*count); } void Notification(CalculatorContext* cc, int* count) { - CHECK(cc); + ABSL_CHECK(cc); cc_ = cc; ++(*count); } diff --git a/mediapipe/framework/calculator_options.proto b/mediapipe/framework/calculator_options.proto index 747e9c4af..3bc9f6615 100644 --- a/mediapipe/framework/calculator_options.proto +++ b/mediapipe/framework/calculator_options.proto @@ -23,15 +23,13 @@ package mediapipe; option java_package = "com.google.mediapipe.proto"; option java_outer_classname = "CalculatorOptionsProto"; -// Options for Calculators. Each Calculator implementation should -// have its own options proto, which should look like this: +// Options for Calculators, DEPRECATED. 
New calculators are encouraged to use +// proto3 syntax options: // // message MyCalculatorOptions { // -// extend CalculatorOptions { -// optional MyCalculatorOptions ext = <unique id>; -// } -// optional string field_needed_by_my_calculator = 1; -// optional int32 another_field = 2; +// // proto3 does not expect "optional" +// string field_needed_by_my_calculator = 1; +// int32 another_field = 2; // // etc // } message CalculatorOptions { diff --git a/mediapipe/framework/calculator_runner.cc b/mediapipe/framework/calculator_runner.cc index 1bd3211ed..800f041cc 100644 --- a/mediapipe/framework/calculator_runner.cc +++ b/mediapipe/framework/calculator_runner.cc @@ -16,10 +16,11 @@ #include "mediapipe/framework/calculator_runner.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -139,7 +140,7 @@ CalculatorRunner::CalculatorRunner(const std::string& calculator_type, #if !defined(MEDIAPIPE_PROTO_LITE) CalculatorRunner::CalculatorRunner(const std::string& node_config_string) { CalculatorGraphConfig::Node node_config; - CHECK( + ABSL_CHECK( proto_ns::TextFormat::ParseFromString(node_config_string, &node_config)); MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config)); } @@ -149,8 +150,8 @@ CalculatorRunner::CalculatorRunner(const std::string& calculator_type, int num_inputs, int num_outputs, int num_side_packets) { node_config_.set_calculator(calculator_type); - CHECK(proto_ns::TextFormat::ParseFromString(options_string, - node_config_.mutable_options())); + ABSL_CHECK(proto_ns::TextFormat::ParseFromString( + options_string, node_config_.mutable_options())); SetNumInputs(num_inputs); SetNumOutputs(num_outputs); SetNumInputSidePackets(num_side_packets); @@ -188,7 +189,7 @@ void CalculatorRunner::SetNumInputSidePackets(int n) { } void CalculatorRunner::InitializeInputs(const tool::TagAndNameInfo& info) { - CHECK(graph_ == nullptr); + ABSL_CHECK(graph_ == nullptr); MEDIAPIPE_CHECK_OK( tool::SetFromTagAndNameInfo(info, node_config_.mutable_input_stream())); inputs_.reset(new StreamContentsSet(info)); @@ -196,7 +197,7 @@ void CalculatorRunner::InitializeInputs(const tool::TagAndNameInfo& info) { } void CalculatorRunner::InitializeOutputs(const tool::TagAndNameInfo& info) { - CHECK(graph_ == nullptr); + ABSL_CHECK(graph_ == nullptr); MEDIAPIPE_CHECK_OK( tool::SetFromTagAndNameInfo(info, node_config_.mutable_output_stream())); outputs_.reset(new StreamContentsSet(info)); @@ -205,7 +206,7 @@ void CalculatorRunner::InitializeOutputs(const tool::TagAndNameInfo& info) { void CalculatorRunner::InitializeInputSidePackets( const tool::TagAndNameInfo& info) { - CHECK(graph_ == nullptr); + ABSL_CHECK(graph_ == nullptr); MEDIAPIPE_CHECK_OK(tool::SetFromTagAndNameInfo( info, node_config_.mutable_input_side_packet())); input_side_packets_.reset(new PacketSet(info)); @@ -262,16 +263,18 @@ absl::Status CalculatorRunner::BuildGraph() { if (log_calculator_proto_) { #if defined(MEDIAPIPE_PROTO_LITE) - LOG(INFO) << "Please initialize CalculatorRunner using the recommended " - "constructor:\n CalculatorRunner runner(node_config);"; + ABSL_LOG(INFO) + << "Please initialize CalculatorRunner using the recommended " + "constructor:\n CalculatorRunner runner(node_config);"; #else std::string config_string; proto_ns::TextFormat::Printer printer;
printer.SetInitialIndentLevel(4); printer.PrintToString(node_config_, &config_string); - LOG(INFO) << "Please initialize CalculatorRunner using the recommended " - "constructor:\n CalculatorRunner runner(R\"(\n" - << config_string << "\n )\");"; + ABSL_LOG(INFO) + << "Please initialize CalculatorRunner using the recommended " + "constructor:\n CalculatorRunner runner(R\"(\n" + << config_string << "\n )\");"; #endif } diff --git a/mediapipe/framework/calculator_runner_test.cc b/mediapipe/framework/calculator_runner_test.cc index a7890badd..7fd118cc6 100644 --- a/mediapipe/framework/calculator_runner_test.cc +++ b/mediapipe/framework/calculator_runner_test.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/calculator_runner.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_registry.h" @@ -24,7 +25,6 @@ #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" @@ -136,7 +136,7 @@ TEST(CalculatorRunner, RunsCalculator) { // Run CalculatorRunner::Run() several times, with different inputs. This // tests that a CalculatorRunner instance can be reused. for (int iter = 0; iter < 3; ++iter) { - LOG(INFO) << "iter: " << iter; + ABSL_LOG(INFO) << "iter: " << iter; const int length = iter; // Generate the inputs at timestamps 0 ... length-1, at timestamp t having // values t and t*2 for the two streams, respectively. diff --git a/mediapipe/framework/calculator_state.cc b/mediapipe/framework/calculator_state.cc index 3b0264e97..9ff478688 100644 --- a/mediapipe/framework/calculator_state.cc +++ b/mediapipe/framework/calculator_state.cc @@ -18,6 +18,7 @@ #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/port/logging.h" @@ -46,23 +47,23 @@ void CalculatorState::ResetBetweenRuns() { } void CalculatorState::SetInputSidePackets(const PacketSet* input_side_packets) { - CHECK(input_side_packets); + ABSL_CHECK(input_side_packets); input_side_packets_ = input_side_packets; } void CalculatorState::SetOutputSidePackets( OutputSidePacketSet* output_side_packets) { - CHECK(output_side_packets); + ABSL_CHECK(output_side_packets); output_side_packets_ = output_side_packets; } Counter* CalculatorState::GetCounter(const std::string& name) { - CHECK(counter_factory_); + ABSL_CHECK(counter_factory_); return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name)); } CounterFactory* CalculatorState::GetCounterFactory() { - CHECK(counter_factory_); + ABSL_CHECK(counter_factory_); return counter_factory_; } diff --git a/mediapipe/framework/collection.h b/mediapipe/framework/collection.h index c7b6fb0de..d955c9cbe 100644 --- a/mediapipe/framework/collection.h +++ b/mediapipe/framework/collection.h @@ -24,11 +24,12 @@ #include #include "absl/base/macros.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/collection_item_id.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/tool/tag_map.h" #include "mediapipe/framework/tool/tag_map_helper.h" #include "mediapipe/framework/tool/validate_name.h" @@ -52,7 +53,7 @@ struct 
CollectionErrorHandlerFatal { // get away with only one version of this function (which is const // but returns a non-const reference). T& GetFallback(const absl::string_view tag, int index) const { - LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index; + ABSL_LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index; std::abort(); } }; @@ -365,7 +366,7 @@ class Collection { std::unique_ptr<value_type[]> data_; // A class which allows errors to be reported flexibly. The default - // instantiation performs a LOG(FATAL) and does not have any member + // instantiation performs a ABSL_LOG(FATAL) and does not have any member // variables (zero size). ErrorHandler error_handler_; }; @@ -413,16 +414,16 @@ bool Collection<T, storage, ErrorHandler>::UsesTags() const { template <typename T, CollectionStorage storage, typename ErrorHandler> typename Collection<T, storage, ErrorHandler>::value_type& Collection<T, storage, ErrorHandler>::Get(CollectionItemId id) { - CHECK_LE(BeginId(), id); - CHECK_LT(id, EndId()); + ABSL_CHECK_LE(BeginId(), id); + ABSL_CHECK_LT(id, EndId()); return begin()[id.value()]; } template <typename T, CollectionStorage storage, typename ErrorHandler> const typename Collection<T, storage, ErrorHandler>::value_type& Collection<T, storage, ErrorHandler>::Get(CollectionItemId id) const { - CHECK_LE(BeginId(), id); - CHECK_LT(id, EndId()); + ABSL_CHECK_LE(BeginId(), id); + ABSL_CHECK_LT(id, EndId()); return begin()[id.value()]; } @@ -433,8 +434,8 @@ Collection<T, storage, ErrorHandler>::GetPtr(CollectionItemId id) { "mediapipe::internal::Collection::GetPtr() is only " "available for collections that were defined with template " "argument storage == CollectionStorage::kStorePointer."); - CHECK_LE(BeginId(), id); - CHECK_LT(id, EndId()); + ABSL_CHECK_LE(BeginId(), id); + ABSL_CHECK_LT(id, EndId()); return data_[id.value()]; } @@ -445,8 +446,8 @@ Collection<T, storage, ErrorHandler>::GetPtr(CollectionItemId id) const { "mediapipe::internal::Collection::GetPtr() is only " "available for collections that were defined with template " "argument storage == CollectionStorage::kStorePointer."); - CHECK_LE(BeginId(), id); - CHECK_LT(id, EndId()); + ABSL_CHECK_LE(BeginId(), id); + ABSL_CHECK_LT(id, EndId()); return data_[id.value()]; } diff --git a/mediapipe/framework/counter_factory.cc b/mediapipe/framework/counter_factory.cc index 895b44ea6..b4da1043e 100644 --- a/mediapipe/framework/counter_factory.cc +++ b/mediapipe/framework/counter_factory.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_log.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -59,9 +60,9 @@ void CounterSet::PublishCounters() ABSL_LOCKS_EXCLUDED(mu_) {} void CounterSet::PrintCounters() ABSL_LOCKS_EXCLUDED(mu_) { absl::ReaderMutexLock lock(&mu_); - LOG_IF(INFO, !counters_.empty()) << "MediaPipe Counters:"; + ABSL_LOG_IF(INFO, !counters_.empty()) << "MediaPipe Counters:"; for (const auto& counter : counters_) { - LOG(INFO) << counter.first << ": " << counter.second->Get(); + ABSL_LOG(INFO) << counter.first << ": " << counter.second->Get(); } } diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index 7fe37bae6..6b6709526 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -77,8 +77,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/port:logging", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], @@ -130,8 +131,9 @@ cc_library( deps = [ "//mediapipe/framework/port", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "@com_google_absl//absl/base:core_headers", +
"@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) @@ -149,7 +151,10 @@ cc_library( # Use this library through "mediapipe/framework/port:map_util". visibility = ["//mediapipe/framework/port:__pkg__"], - deps = ["//mediapipe/framework/port:logging"], + deps = [ + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", + ], ) cc_library( @@ -160,7 +165,7 @@ cc_library( ], deps = [ "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", ], ) @@ -228,12 +233,13 @@ cc_library( ], deps = [ ":registration_token", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -276,8 +282,8 @@ cc_library( visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ ":source_location", - "//mediapipe/framework/port:logging", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -344,6 +350,8 @@ cc_library( deps = [ ":thread_options", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -358,6 +366,8 @@ cc_library( visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) @@ -369,7 +379,7 @@ cc_library( visibility = ["//mediapipe/framework/port:__pkg__"], deps = [ "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/utility", ], ) @@ -415,10 +425,10 @@ cc_test( ":clock", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:threadpool", "//mediapipe/framework/tool:simulation_clock", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/mediapipe/framework/deps/cleanup.h b/mediapipe/framework/deps/cleanup.h index 125cc7400..0541e314f 100644 --- a/mediapipe/framework/deps/cleanup.h +++ b/mediapipe/framework/deps/cleanup.h @@ -26,7 +26,7 @@ // DataObject d; // while (ReadDataObject(fp, &d)) { // if (d.IsBad()) { -// LOG(ERROR) << "Bad Data"; +// ABSL_LOG(ERROR) << "Bad Data"; // return; // } // PushGoodData(d); diff --git a/mediapipe/framework/deps/clock.cc b/mediapipe/framework/deps/clock.cc index f68143862..418d82814 100644 --- a/mediapipe/framework/deps/clock.cc +++ b/mediapipe/framework/deps/clock.cc @@ -14,8 +14,8 @@ #include "mediapipe/framework/deps/clock.h" +#include "absl/log/absl_log.h" #include "absl/time/clock.h" -#include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -28,7 +28,7 @@ namespace { class RealTimeClock : public Clock { public: virtual ~RealTimeClock() { - LOG(FATAL) << "RealTimeClock 
should never be destroyed"; + ABSL_LOG(FATAL) << "RealTimeClock should never be destroyed"; } absl::Time TimeNow() override { return absl::Now(); } diff --git a/mediapipe/framework/deps/file_helpers.cc b/mediapipe/framework/deps/file_helpers.cc index 23b50310f..84faefc0a 100644 --- a/mediapipe/framework/deps/file_helpers.cc +++ b/mediapipe/framework/deps/file_helpers.cc @@ -17,6 +17,9 @@ #ifdef _WIN32 #include #include + +#include <codecvt> +#include <locale> #else #include <dirent.h> #endif // _WIN32 @@ -86,11 +89,31 @@ class DirectoryListing { struct dirent* next_entry_ = nullptr; }; #else +#if defined(UNICODE) +using PathString = std::wstring; + +PathString Utf8ToNative(const std::string& string) { + std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>, wchar_t> converter; + return converter.from_bytes(string.data(), string.data() + string.size()); +} +std::string NativeToUtf8(const PathString& string) { + std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>, wchar_t> converter; + return converter.to_bytes(string.data(), string.data() + string.size()); +} +#define FILE_PATH_LITERAL_INTERNAL(x) L##x +#define FILE_PATH_LITERAL(x) FILE_PATH_LITERAL_INTERNAL(x) +#else +using PathString = std::string; +PathString Utf8ToNative(const std::string& string) { return string; } +std::string NativeToUtf8(const PathString& string) { return string; } +#define FILE_PATH_LITERAL(x) x +#endif + class DirectoryListing { public: - explicit DirectoryListing(const std::string& directory) { - directory_ = directory; - std::string search_string = directory + "\\*.*"; + explicit DirectoryListing(const std::string& directory) + : directory_(Utf8ToNative(directory)) { + PathString search_string = directory_ + Utf8ToNative("\\*.*"); + find_handle_ = FindFirstFile(search_string.c_str(), &find_data_); } @@ -107,10 +130,10 @@ // after the one that is returned, if it exists. std::string NextEntry() { if (HasNextEntry()) { - std::string result = - std::string(directory_ + "\\" + find_data_.cFileName); + PathString result = + directory_ + Utf8ToNative("\\") + PathString(find_data_.cFileName); ReadNextEntry(); - return result; + return NativeToUtf8(result); } else { return std::string(); } @@ -119,8 +142,9 @@ private: void ReadNextEntry() { int find_result = FindNextFile(find_handle_, &find_data_); while (find_result != 0 && (std::string(find_data_.cFileName) == "."
|| - std::string(find_data_.cFileName) == "..")) { + while (find_result != 0 && + (PathString(find_data_.cFileName) == FILE_PATH_LITERAL(".") || + PathString(find_data_.cFileName) == FILE_PATH_LITERAL(".."))) { find_result = FindNextFile(find_handle_, &find_data_); } @@ -130,7 +154,7 @@ class DirectoryListing { } } - std::string directory_; + const PathString directory_; HANDLE find_handle_ = INVALID_HANDLE_VALUE; WIN32_FIND_DATA find_data_; }; @@ -162,7 +186,7 @@ absl::Status GetContents(absl::string_view file_name, std::string* output, absl::Status SetContents(absl::string_view file_name, absl::string_view content) { - FILE* fp = fopen(file_name.data(), "w"); + FILE* fp = fopen(file_name.data(), "wb"); if (fp == NULL) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Can't open file: " << file_name; diff --git a/mediapipe/framework/deps/map_util.h b/mediapipe/framework/deps/map_util.h index 05d47b7e7..940ff03f8 100644 --- a/mediapipe/framework/deps/map_util.h +++ b/mediapipe/framework/deps/map_util.h @@ -27,6 +27,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -53,7 +54,7 @@ template <typename M> const typename M::value_type::second_type& FindOrDie( const M& m, const typename M::value_type::first_type& key) { auto it = m.find(key); - CHECK(it != m.end()) << "Map key not found: " << key; + ABSL_CHECK(it != m.end()) << "Map key not found: " << key; return it->second; } @@ -63,7 +64,7 @@ typename M::value_type::second_type& FindOrDie( M& m, // NOLINT const typename M::value_type::first_type& key) { auto it = m.find(key); - CHECK(it != m.end()) << "Map key not found: " << key; + ABSL_CHECK(it != m.end()) << "Map key not found: " << key; return it->second; } @@ -138,7 +139,7 @@ bool InsertIfNotPresent(M* m, const typename M::value_type::first_type& key, // inserted. template <typename M, typename ReverseM> bool ReverseMap(const M& m, ReverseM* reverse) { - CHECK(reverse != nullptr); + ABSL_CHECK(reverse != nullptr); for (const auto& kv : m) { if (!InsertIfNotPresent(reverse, kv.second, kv.first)) { return false; diff --git a/mediapipe/framework/deps/mathutil.h b/mediapipe/framework/deps/mathutil.h index 315b78c42..a3d8b6e80 100644 --- a/mediapipe/framework/deps/mathutil.h +++ b/mediapipe/framework/deps/mathutil.h @@ -23,8 +23,8 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -354,7 +354,7 @@ class MathUtil { template <typename T> // T models LessThanComparable. static const T& Clamp(const T& low, const T& high, const T& value) { // Prevents errors in ordering the arguments. - DCHECK(!(high < low)); + ABSL_DCHECK(!(high < low)); if (high < value) return high; if (value < low) return low; return value; } @@ -364,7 +364,7 @@ class MathUtil { // absolute margin of error.
template <typename T> static bool WithinMargin(const T x, const T y, const T margin) { - DCHECK_GE(margin, 0); + ABSL_DCHECK_GE(margin, 0); return (std::abs(x) <= std::abs(y) + margin) && (std::abs(x) >= std::abs(y) - margin); } diff --git a/mediapipe/framework/deps/monotonic_clock.cc b/mediapipe/framework/deps/monotonic_clock.cc index 503ef5cfd..17542b6f6 100644 --- a/mediapipe/framework/deps/monotonic_clock.cc +++ b/mediapipe/framework/deps/monotonic_clock.cc @@ -16,9 +16,10 @@ #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -60,7 +61,7 @@ class MonotonicClockImpl : public MonotonicClock { // Absolve this object of responsibility for state_. void ReleaseState() { - CHECK(state_owned_); + ABSL_CHECK(state_owned_); state_owned_ = false; } @@ -80,7 +81,7 @@ class MonotonicClockImpl : public MonotonicClock { absl::MutexLock m(&state_->lock); // Check consistency of internal data with state_. - CHECK_LE(last_raw_time_, state_->max_time) + ABSL_CHECK_LE(last_raw_time_, state_->max_time) << "non-monotonic behavior: last_raw_time_=" << last_raw_time_ << ", max_time=" << state_->max_time; @@ -107,7 +108,7 @@ class MonotonicClockImpl : public MonotonicClock { // First, update correction metrics. ++correction_count_; absl::Duration delta = state_->max_time - raw_time; - CHECK_LT(absl::ZeroDuration(), delta); + ABSL_CHECK_LT(absl::ZeroDuration(), delta); if (delta > max_correction_) { max_correction_ = delta; } @@ -205,7 +206,7 @@ MonotonicClock* MonotonicClock::CreateSynchronizedMonotonicClock() { // Test access methods. void MonotonicClockAccess::SynchronizedMonotonicClockReset() { - LOG(INFO) << "Resetting SynchronizedMonotonicClock"; + ABSL_LOG(INFO) << "Resetting SynchronizedMonotonicClock"; State* sync_state = GlobalSyncState(); absl::MutexLock m(&sync_state->lock); sync_state->max_time = absl::UnixEpoch(); diff --git a/mediapipe/framework/deps/monotonic_clock_test.cc b/mediapipe/framework/deps/monotonic_clock_test.cc index 0a049392f..9b57ffe51 100644 --- a/mediapipe/framework/deps/monotonic_clock_test.cc +++ b/mediapipe/framework/deps/monotonic_clock_test.cc @@ -21,13 +21,13 @@ #include #include "absl/base/thread_annotations.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/threadpool.h" #include "mediapipe/framework/tool/simulation_clock.h" @@ -254,8 +254,8 @@ TEST_F(MonotonicClockTest, RealTime) { // Just out of curiousity -- did real clock go backwards? int clock_num_corrections; mono_clock->GetCorrectionMetrics(&clock_num_corrections, NULL); - LOG(INFO) << clock_num_corrections << " corrections in " << num_calls - << " calls to mono_clock->Now()"; + ABSL_LOG(INFO) << clock_num_corrections << " corrections in " << num_calls + << " calls to mono_clock->Now()"; delete mono_clock; } @@ -523,13 +523,13 @@ TEST_F(MonotonicClockTest, RealFrenzy) { // Just out of curiousity -- did real clock go backwards?
int clock_num_corrections; m1->GetCorrectionMetrics(&clock_num_corrections, NULL); - LOG_IF(INFO, clock_num_corrections > 0) + ABSL_LOG_IF(INFO, clock_num_corrections > 0) << clock_num_corrections << " corrections"; m2->GetCorrectionMetrics(&clock_num_corrections, NULL); - LOG_IF(INFO, clock_num_corrections > 0) + ABSL_LOG_IF(INFO, clock_num_corrections > 0) << clock_num_corrections << " corrections"; m3->GetCorrectionMetrics(&clock_num_corrections, NULL); - LOG_IF(INFO, clock_num_corrections > 0) + ABSL_LOG_IF(INFO, clock_num_corrections > 0) << clock_num_corrections << " corrections"; delete m1; delete m2; diff --git a/mediapipe/framework/deps/re2.h b/mediapipe/framework/deps/re2.h index 61f7985ee..89dc8fcdb 100644 --- a/mediapipe/framework/deps/re2.h +++ b/mediapipe/framework/deps/re2.h @@ -19,7 +19,7 @@ namespace mediapipe { -// Implementats a subset of RE2 using std::regex_match. +// Implements a subset of RE2 using std::regex_match. class RE2 { public: RE2(const std::string& pattern) : std_regex_(pattern) {} diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index 7965539b6..f974d6896 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -16,7 +16,6 @@ #define MEDIAPIPE_DEPS_REGISTRATION_H_ #include -#include #include #include #include @@ -29,6 +28,8 @@ #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -36,7 +37,6 @@ #include "absl/synchronization/mutex.h" #include "mediapipe/framework/deps/registration_token.h" #include "mediapipe/framework/port/canonical_errors.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/statusor.h" namespace mediapipe { @@ -145,6 +145,23 @@ template <typename T> struct WrapStatusOr<absl::StatusOr<T>> { using type = absl::StatusOr<T>; }; + +// Defining a member of this type causes P to be ODR-used, which forces its +// instantiation if it's a static member of a template. +// Previously we depended on the pointer's value to determine whether the size +// of a character array is 0 or 1, forcing it to be instantiated so the +// compiler can determine the object's layout. But using it as a template +// argument is more compact. +template <auto* P> +struct ForceStaticInstantiation { +#ifdef _MSC_VER + // Just having it as the template argument does not count as a use for + // MSVC.
+ static constexpr bool Use() { return P != nullptr; } + char force_static[Use()]; +#endif // _MSC_VER +}; + } // namespace registration_internal class NamespaceAllowlist { @@ -162,8 +179,7 @@ class FunctionRegistry { FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete; - RegistrationToken Register(absl::string_view name, Function func, - std::string filename, uint64_t line) + RegistrationToken Register(absl::string_view name, Function func) ABSL_LOCKS_EXCLUDED(lock_) { std::string normalized_name = GetNormalizedName(name); absl::WriterMutexLock lock(&lock_); } if (functions_.insert(std::make_pair(normalized_name, std::move(func))) .second) { -#ifndef NDEBUG - locations_.emplace(normalized_name, - std::make_pair(std::move(filename), line)); -#endif return RegistrationToken( [this, normalized_name]() { Unregister(normalized_name); }); } -#ifndef NDEBUG - LOG(FATAL) << "Function with name " << name << " already registered." << " First registration at " - << locations_.at(normalized_name).first << ":" - << locations_.at(normalized_name).second; -#else - LOG(FATAL) << "Function with name " << name << " already registered."; -#endif + ABSL_LOG(FATAL) << "Function with name " << name << " already registered."; return RegistrationToken([]() {}); } @@ -266,7 +271,7 @@ class FunctionRegistry { if (names[0].empty()) { names.erase(names.begin()); } else { - CHECK_EQ(1u, names.size()) + ABSL_CHECK_EQ(1u, names.size()) << "A registered class name must be either fully qualified " << "with a leading :: or unqualified, got: " << name << "."; } @@ -316,11 +321,6 @@ class FunctionRegistry { private: mutable absl::Mutex lock_; absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); -#ifndef NDEBUG - // Stores filename and line number for useful debug log. - absl::flat_hash_map<std::string, std::pair<std::string, uint64_t>> locations_ - ABSL_GUARDED_BY(lock_); -#endif // For names included in NamespaceAllowlist, strips the namespace. std::string GetAdjustedName(absl::string_view name) { @@ -351,10 +351,8 @@ class GlobalFactoryRegistry { public: static RegistrationToken Register(absl::string_view name, - typename Functions::Function func, - std::string filename, uint64_t line) { - return functions()->Register(name, std::move(func), std::move(filename), - line); + typename Functions::Function func) { + return functions()->Register(name, std::move(func)); } // Invokes the specified factory function and returns the result. @@ -411,15 +409,181 @@ class GlobalFactoryRegistry { #define REGISTRY_STATIC_VAR(var_name, line) \ REGISTRY_STATIC_VAR_INNER(var_name, line) -#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \ - static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) = \ - new mediapipe::RegistrationToken( \ - RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__)) +// Disables all static registration in MediaPipe accomplished using: +// - REGISTER_FACTORY_FUNCTION_QUALIFIED +// - MEDIAPIPE_REGISTER_FACTORY_FUNCTION +// - MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE +// +// Which includes: +// - calculators +// - input stream handlers +// - output stream handlers +// - generators +// - anything else registered using above macros +#if !defined(MEDIAPIPE_DISABLE_STATIC_REGISTRATION) +#define MEDIAPIPE_DISABLE_STATIC_REGISTRATION 0 +#endif // !defined(MEDIAPIPE_DISABLE_STATIC_REGISTRATION) +// Enables "Dry Run" for MediaPipe static registration: MediaPipe logs the +// registration code, instead of actual registration.
+// +// The intended use: if you plan to disable static registration using +// MEDIAPIPE_DISABLE_STATIC_REGISTRATION, you may find it useful to build your +// MediaPipe dependency first with only: +// MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN +// and load it to see what manual registration will be required when you build +// with: +// MEDIAPIPE_DISABLE_STATIC_REGISTRATION +#if !defined(MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN) +#define MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN 0 +#endif // !defined(MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN) + +#if MEDIAPIPE_DISABLE_STATIC_REGISTRATION && \ + MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN +static_assert(false, "Cannot do static registration Dry Run as static registration is " "disabled."); +#endif // MEDIAPIPE_DISABLE_STATIC_REGISTRATION && + // MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN + +#if MEDIAPIPE_DISABLE_STATIC_REGISTRATION +// When static registration is disabled, make sure corresponding macros don't do +// any registration. + +#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \ name, ...) +#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \ name, ...) \ + template <typename T> \ + class RegistratorName {}; + +#elif MEDIAPIPE_ENABLE_STATIC_REGISTRATION_DRY_RUN +// When static registration is enabled and running in Dry-Run mode, make sure +// corresponding macros print registration details instead of doing actual +// registration. + +#define INTERNAL_MEDIAPIPE_REGISTER_FACTORY_STRINGIFY_HELPER(x) #x +#define INTERNAL_MEDIAPIPE_REGISTER_FACTORY_STRINGIFY(x) \ + INTERNAL_MEDIAPIPE_REGISTER_FACTORY_STRINGIFY_HELPER(x) + +#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \ name, ...) \ + static mediapipe::RegistrationToken* REGISTRY_STATIC_VAR(var_name, \ __LINE__) = []() { \ + ABSL_RAW_LOG(WARNING, "Registration Dry Run: %s", \ + INTERNAL_MEDIAPIPE_REGISTER_FACTORY_STRINGIFY( \ + RegistryType::Register(name, __VA_ARGS__))); \ + return nullptr; \ + }(); + +#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \ names, ...) \ + template <typename T> \ + struct Internal##RegistratorName { \ + static NoDestructor<mediapipe::RegistrationToken> registration; \ + \ + static mediapipe::RegistrationToken Make() { \ + ABSL_RAW_LOG(WARNING, "Registration Dry Run: %s", \ + INTERNAL_MEDIAPIPE_REGISTER_FACTORY_STRINGIFY( \ + RegistryType::Register(names, __VA_ARGS__))); \ + ABSL_RAW_LOG(WARNING, "Where typeid(T).name() is: %s", \ + typeid(T).name()); \ + return {}; \ + } \ + \ + using RequireStatics = \ + registration_internal::ForceStaticInstantiation<&registration>; \ + }; \ + /* Static members of template classes can be defined in the header. */ \ + template <typename T> \ + NoDestructor<mediapipe::RegistrationToken> \ + Internal##RegistratorName<T>::registration( \ + Internal##RegistratorName<T>::Make()); \ + \ + template <typename T> \ + class RegistratorName { \ + private: \ + /* The member below triggers instantiation of the registration static. */ \ + typename Internal##RegistratorName<T>::RequireStatics register_; \ + }; + +#else +// When static registration is enabled and NOT running in Dry-Run mode, make +// sure corresponding macros do proper static registration. + +#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \ name, ...)
\ + static mediapipe::RegistrationToken* REGISTRY_STATIC_VAR(var_name, \ __LINE__) = \ + new mediapipe::RegistrationToken( \ + RegistryType::Register(name, __VA_ARGS__)); + +// Defines a utility registrator class which can be used to automatically +// register factory functions. +// +// Example: +// === Defining a registry ================================================ +// +// class Component {}; +// +// using ComponentRegistry = GlobalFactoryRegistry<std::unique_ptr<Component>>; +// +// === Defining a registrator ============================================= +// +// MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(ComponentRegistrator, +// ComponentRegistry, T::kName, +// absl::make_unique<T>); +// +// === Defining and registering a new component. ========================== +// +// class MyComponent : public Component, +// private ComponentRegistrator<MyComponent> { +// public: +// static constexpr char kName[] = "MyComponent"; +// ... +// }; +// +// NOTE: +// - MyComponent is automatically registered in ComponentRegistry by +// "MyComponent" name. +// - Every component is required to provide its name (T::kName here.) +#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \ + name, ...) \ + template <typename T> \ + struct Internal##RegistratorName { \ + static NoDestructor<mediapipe::RegistrationToken> registration; \ + \ + static mediapipe::RegistrationToken Make() { \ + return RegistryType::Register(name, __VA_ARGS__); \ + } \ + \ + using RequireStatics = \ + registration_internal::ForceStaticInstantiation<&registration>; \ + }; \ + /* Static members of template classes can be defined in the header. */ \ + template <typename T> \ + NoDestructor<mediapipe::RegistrationToken> \ + Internal##RegistratorName<T>::registration( \ + Internal##RegistratorName<T>::Make()); \ + \ + template <typename T> \ + class RegistratorName { \ + private: \ + /* The member below triggers instantiation of the registration static. */ \ + typename Internal##RegistratorName<T>::RequireStatics register_; \ + }; + +#endif // MEDIAPIPE_DISABLE_STATIC_REGISTRATION + +#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \ + MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \ + RegistryType, registration_##name, #name, __VA_ARGS__) + +// TODO: migrate usages to use +// MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED. #define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) \ - static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \ - new mediapipe::RegistrationToken( \ - RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__)) + MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, #name, \ + __VA_ARGS__) } // namespace mediapipe diff --git a/mediapipe/framework/deps/safe_int.h b/mediapipe/framework/deps/safe_int.h index 4c120bc1b..37d8663cc 100644 --- a/mediapipe/framework/deps/safe_int.h +++ b/mediapipe/framework/deps/safe_int.h @@ -34,7 +34,7 @@ // define any custom policy they desire. // // PolicyTypes: -// LogFatalOnError: LOG(FATAL) when a error occurs. +// LogFatalOnError: ABSL_LOG(FATAL) when an error occurs. #ifndef MEDIAPIPE_DEPS_SAFE_INT_H_ #define MEDIAPIPE_DEPS_SAFE_INT_H_ #include #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/deps/strong_int.h" -#include "mediapipe/framework/port/logging.h" namespace mediapipe { namespace intops { @@ -67,17 +68,17 @@ class SafeIntStrongIntValidator { // Check that the underlying integral type provides a range that is // compatible with two's complement.
if (std::numeric_limits<T>::is_signed) { - CHECK_EQ(-1, - std::numeric_limits<T>::min() + std::numeric_limits<T>::max()) + ABSL_CHECK_EQ( + -1, std::numeric_limits<T>::min() + std::numeric_limits<T>::max()) << "unexpected integral bounds"; } // Check that division truncates towards 0 (implementation defined in // C++'03, but standard in C++'11). - CHECK_EQ(12, 127 / 10) << "division does not truncate towards 0"; - CHECK_EQ(-12, -127 / 10) << "division does not truncate towards 0"; - CHECK_EQ(-12, 127 / -10) << "division does not truncate towards 0"; - CHECK_EQ(12, -127 / -10) << "division does not truncate towards 0"; + ABSL_CHECK_EQ(12, 127 / 10) << "division does not truncate towards 0"; + ABSL_CHECK_EQ(-12, -127 / 10) << "division does not truncate towards 0"; + ABSL_CHECK_EQ(-12, 127 / -10) << "division does not truncate towards 0"; + ABSL_CHECK_EQ(12, -127 / -10) << "division does not truncate towards 0"; } public: @@ -88,10 +89,13 @@ class SafeIntStrongIntValidator { // If the argument is floating point, we can do a simple check to make // sure the value is in range. It is undefined behavior to convert to int - // from a float that is out of range. + // from a float that is out of range. Since large integers will lose some + // precision when being converted to floating point, the integer max and min + // are explicitly converted back to floating point for this comparison, in + // order to satisfy compiler warnings. if (std::is_floating_point<U>::value) { - if (arg < std::numeric_limits<T>::min() || - arg > std::numeric_limits<T>::max()) { + if (arg < static_cast<U>(std::numeric_limits<T>::min()) || + arg > static_cast<U>(std::numeric_limits<T>::max())) { ErrorType::Error("SafeInt: init from out of bounds float", arg, "="); } } else { @@ -281,15 +285,15 @@ class SafeIntStrongIntValidator { } }; -// A SafeIntStrongIntValidator policy class to LOG(FATAL) on errors. +// A SafeIntStrongIntValidator policy class to ABSL_LOG(FATAL) on errors.
struct LogFatalOnError { template <typename Tlhs, typename Trhs> - static void Error(const char *error, Tlhs lhs, Trhs rhs, const char *op) { - LOG(FATAL) << error << ": (" << lhs << " " << op << " " << rhs << ")"; + static void Error(const char* error, Tlhs lhs, Trhs rhs, const char* op) { + ABSL_LOG(FATAL) << error << ": (" << lhs << " " << op << " " << rhs << ")"; } template <typename Tval> - static void Error(const char *error, Tval val, const char *op) { - LOG(FATAL) << error << ": (" << op << val << ")"; + static void Error(const char* error, Tval val, const char* op) { + ABSL_LOG(FATAL) << error << ": (" << op << val << ")"; } }; diff --git a/mediapipe/framework/deps/status.h b/mediapipe/framework/deps/status.h index 492e4d434..8ee38f32d 100644 --- a/mediapipe/framework/deps/status.h +++ b/mediapipe/framework/deps/status.h @@ -21,9 +21,9 @@ #include #include "absl/base/attributes.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -44,7 +44,7 @@ inline std::string* MediaPipeCheckOpHelper(absl::Status v, const char* msg) { #define MEDIAPIPE_DO_CHECK_OK(val, level) \ while (auto _result = mediapipe::MediaPipeCheckOpHelper(val, #val)) \ - LOG(level) << *(_result) + ABSL_LOG(level) << *(_result) #define MEDIAPIPE_CHECK_OK(val) MEDIAPIPE_DO_CHECK_OK(val, FATAL) #define MEDIAPIPE_QCHECK_OK(val) MEDIAPIPE_DO_CHECK_OK(val, QFATAL) @@ -53,7 +53,7 @@ inline std::string* MediaPipeCheckOpHelper(absl::Status v, const char* msg) { #define MEDIAPIPE_DCHECK_OK(val) MEDIAPIPE_CHECK_OK(val) #else #define MEDIAPIPE_DCHECK_OK(val) \ - while (false && (absl::OkStatus() == (val))) LOG(FATAL) + while (false && (absl::OkStatus() == (val))) ABSL_LOG(FATAL) #endif #define CHECK_OK MEDIAPIPE_CHECK_OK diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index 0202b8689..041b8608b 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && { return std::move(SetNoLogging()); } -StatusBuilder::operator Status() const& { +StatusBuilder::operator absl::Status() const& { return StatusBuilder(*this).JoinMessageToStatus(); } -StatusBuilder::operator Status() && { return JoinMessageToStatus(); } +StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); } absl::Status StatusBuilder::JoinMessageToStatus() { if (!impl_) { diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index ae11699d2..935ab7776 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { return std::move(*this << msg); } - operator Status() const&; - operator Status() &&; + operator absl::Status() const&; + operator absl::Status() &&; absl::Status JoinMessageToStatus(); diff --git a/mediapipe/framework/deps/strong_int.h b/mediapipe/framework/deps/strong_int.h index 6f102238f..b4bfef770 100644 --- a/mediapipe/framework/deps/strong_int.h +++ b/mediapipe/framework/deps/strong_int.h @@ -103,6 +103,7 @@ #include #include "absl/base/macros.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/port.h" @@ -134,7 +135,7 @@ struct NullStrongIntValidator { // // template <typename U> // static void ValidateInit(U arg) { - // if (arg < 0) LOG(FATAL) << "arg < 0"; + // if (arg < 0) ABSL_LOG(FATAL) << "arg <
0"; // } // // template @@ -403,11 +404,11 @@ std::ostream &operator<<(std::ostream &os, lhs op## = rhs; \ return lhs; \ } -STRONG_INT_VS_STRONG_INT_BINARY_OP(+); -STRONG_INT_VS_STRONG_INT_BINARY_OP(-); -STRONG_INT_VS_STRONG_INT_BINARY_OP(&); -STRONG_INT_VS_STRONG_INT_BINARY_OP(|); -STRONG_INT_VS_STRONG_INT_BINARY_OP(^); +STRONG_INT_VS_STRONG_INT_BINARY_OP(+) +STRONG_INT_VS_STRONG_INT_BINARY_OP(-) +STRONG_INT_VS_STRONG_INT_BINARY_OP(&) +STRONG_INT_VS_STRONG_INT_BINARY_OP(|) +STRONG_INT_VS_STRONG_INT_BINARY_OP(^) #undef STRONG_INT_VS_STRONG_INT_BINARY_OP // Define operators that take one StrongInt and one native integer argument. @@ -431,12 +432,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^); rhs op## = lhs; \ return rhs; \ } -STRONG_INT_VS_NUMERIC_BINARY_OP(*); -NUMERIC_VS_STRONG_INT_BINARY_OP(*); -STRONG_INT_VS_NUMERIC_BINARY_OP(/); -STRONG_INT_VS_NUMERIC_BINARY_OP(%); -STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators) -STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators) +STRONG_INT_VS_NUMERIC_BINARY_OP(*) +NUMERIC_VS_STRONG_INT_BINARY_OP(*) +STRONG_INT_VS_NUMERIC_BINARY_OP(/) +STRONG_INT_VS_NUMERIC_BINARY_OP(%) +STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators) +STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators) #undef STRONG_INT_VS_NUMERIC_BINARY_OP #undef NUMERIC_VS_STRONG_INT_BINARY_OP @@ -447,12 +448,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators) StrongInt rhs) { \ return lhs.value() op rhs.value(); \ } -STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators) -STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators) +STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators) #undef STRONG_INT_COMPARISON_OP } // namespace intops diff --git a/mediapipe/framework/deps/threadpool_pthread_impl.cc b/mediapipe/framework/deps/threadpool_pthread_impl.cc index d9c32d35e..5033b7522 100644 --- a/mediapipe/framework/deps/threadpool_pthread_impl.cc +++ b/mediapipe/framework/deps/threadpool_pthread_impl.cc @@ -18,6 +18,8 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/deps/threadpool.h" @@ -48,7 +50,7 @@ ThreadPool::WorkerThread::WorkerThread(ThreadPool* pool, const std::string& name_prefix) : pool_(pool), name_prefix_(name_prefix) { int res = pthread_create(&thread_, nullptr, ThreadBody, this); - CHECK_EQ(res, 0) << "pthread_create failed"; + ABSL_CHECK_EQ(res, 0) << "pthread_create failed"; } ThreadPool::WorkerThread::~WorkerThread() {} @@ -67,9 +69,9 @@ void* ThreadPool::WorkerThread::ThreadBody(void* arg) { if (nice(nice_priority_level) != -1 || errno == 0) { VLOG(1) << "Changed the nice priority level by " << nice_priority_level; } else { - LOG(ERROR) << "Error : " << strerror(errno) << std::endl - << "Could not change the nice priority level by " - << nice_priority_level; + ABSL_LOG(ERROR) << "Error : " 
<< strerror(errno) << std::endl + << "Could not change the nice priority level by " + << nice_priority_level; } } if (!selected_cpus.empty()) { @@ -84,27 +86,27 @@ void* ThreadPool::WorkerThread::ThreadBody(void* arg) { VLOG(1) << "Pinned the thread pool executor to processor " << absl::StrJoin(selected_cpus, ", processor ") << "."; } else { - LOG(ERROR) << "Error : " << strerror(errno) << std::endl - << "Failed to set processor affinity. Ignore processor " - "affinity setting for now."; + ABSL_LOG(ERROR) << "Error : " << strerror(errno) << std::endl + << "Failed to set processor affinity. Ignore processor " + "affinity setting for now."; } } int error = pthread_setname_np(pthread_self(), name.c_str()); if (error != 0) { - LOG(ERROR) << "Error : " << strerror(error) << std::endl - << "Failed to set name for thread: " << name; + ABSL_LOG(ERROR) << "Error : " << strerror(error) << std::endl + << "Failed to set name for thread: " << name; } #else const std::string name = internal::CreateThreadName(thread->name_prefix_, 0); if (nice_priority_level != 0 || !selected_cpus.empty()) { - LOG(ERROR) << "Thread priority and processor affinity feature aren't " - "supported on the current platform."; + ABSL_LOG(ERROR) << "Thread priority and processor affinity feature aren't " + "supported on the current platform."; } #if __APPLE__ int error = pthread_setname_np(name.c_str()); if (error != 0) { - LOG(ERROR) << "Error : " << strerror(error) << std::endl - << "Failed to set name for thread: " << name; + ABSL_LOG(ERROR) << "Error : " << strerror(error) << std::endl + << "Failed to set name for thread: " << name; } #endif // __APPLE__ #endif // __linux__ diff --git a/mediapipe/framework/deps/threadpool_std_thread_impl.cc b/mediapipe/framework/deps/threadpool_std_thread_impl.cc index 4a902495d..a5f86eeb6 100644 --- a/mediapipe/framework/deps/threadpool_std_thread_impl.cc +++ b/mediapipe/framework/deps/threadpool_std_thread_impl.cc @@ -17,18 +17,10 @@ #include // NOLINT(build/c++11) -#include "mediapipe/framework/deps/threadpool.h" - -#ifdef _WIN32 -#include -#else -#include -#include -#endif - +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/deps/threadpool.h" namespace mediapipe { @@ -67,8 +59,9 @@ void* ThreadPool::WorkerThread::ThreadBody(void* arg) { thread->pool_->thread_options().nice_priority_level(); const std::set selected_cpus = thread->pool_->thread_options().cpu_set(); if (nice_priority_level != 0 || !selected_cpus.empty()) { - LOG(ERROR) << "Thread priority and processor affinity feature aren't " - "supported by the std::thread threadpool implementation."; + ABSL_LOG(ERROR) + << "Thread priority and processor affinity feature aren't " + "supported by the std::thread threadpool implementation."; } thread->pool_->RunWorker(); return nullptr; diff --git a/mediapipe/framework/deps/topologicalsorter.cc b/mediapipe/framework/deps/topologicalsorter.cc index 67fc6adc4..ba906ea65 100644 --- a/mediapipe/framework/deps/topologicalsorter.cc +++ b/mediapipe/framework/deps/topologicalsorter.cc @@ -16,18 +16,19 @@ #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { TopologicalSorter::TopologicalSorter(int num_nodes) : num_nodes_(num_nodes) { - CHECK_GE(num_nodes_, 0); + ABSL_CHECK_GE(num_nodes_, 0); adjacency_lists_.resize(num_nodes_); } void TopologicalSorter::AddEdge(int from, int to) { - CHECK(!traversal_started_ && from 
< num_nodes_ && to < num_nodes_ && - from >= 0 && to >= 0); + ABSL_CHECK(!traversal_started_ && from < num_nodes_ && to < num_nodes_ && + from >= 0 && to >= 0); adjacency_lists_[from].push_back(to); } diff --git a/mediapipe/framework/deps/topologicalsorter.h b/mediapipe/framework/deps/topologicalsorter.h index d5027477c..2270f2945 100644 --- a/mediapipe/framework/deps/topologicalsorter.h +++ b/mediapipe/framework/deps/topologicalsorter.h @@ -40,7 +40,7 @@ namespace mediapipe { // if (cyclic) { // PrintCycleNodes(cycle_nodes); // } else { -// LOG(INFO) << idx; +// ABSL_LOG(INFO) << idx; // } // } class TopologicalSorter { diff --git a/mediapipe/framework/deps/vector.h b/mediapipe/framework/deps/vector.h index 2d4de82f3..5d1400ef5 100644 --- a/mediapipe/framework/deps/vector.h +++ b/mediapipe/framework/deps/vector.h @@ -24,9 +24,9 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/utility/utility.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" template class Vector2; @@ -78,13 +78,13 @@ class BasicVector { void Clear() { AsD() = D(); } T& operator[](int b) { - DCHECK_GE(b, 0); - DCHECK_LT(b, SIZE); + ABSL_DCHECK_GE(b, 0); + ABSL_DCHECK_LT(b, SIZE); return static_cast(*this).Data()[b]; } T operator[](int b) const { - DCHECK_GE(b, 0); - DCHECK_LT(b, SIZE); + ABSL_DCHECK_GE(b, 0); + ABSL_DCHECK_LT(b, SIZE); return static_cast(*this).Data()[b]; } diff --git a/mediapipe/framework/encode_binary_proto.bzl b/mediapipe/framework/encode_binary_proto.bzl index e849d971f..bf7f0583d 100644 --- a/mediapipe/framework/encode_binary_proto.bzl +++ b/mediapipe/framework/encode_binary_proto.bzl @@ -37,29 +37,33 @@ Args: output: The desired name of the output file. Optional. """ +load("@bazel_skylib//lib:paths.bzl", "paths") + PROTOC = "@com_google_protobuf//:protoc" -def _canonicalize_proto_path_oss(all_protos, genfile_path): - """For the protos from external repository, canonicalize the proto path and the file name. +def _canonicalize_proto_path_oss(f): + if not f.root.path: + return struct( + proto_path = ".", + file_name = f.short_path, + ) - Returns: - Proto path list and proto source file list. - """ - proto_paths = [] - proto_file_names = [] - for s in all_protos.to_list(): - if s.path.startswith(genfile_path): - repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/") + # `f.path` looks like "/external//(_virtual_imports//)?" repo_name, _, file_name = f.path[len(paths.join(f.root.path, "external") + "/"):].partition("/") + if file_name.startswith("_virtual_imports/"): + # This is a virtual import; move "_virtual_imports/" from `repo_name` to `file_name`. + repo_name = paths.join(repo_name, *file_name.split("/", 2)[:2]) + file_name = file_name.split("/", 2)[-1] + return struct( + proto_path = paths.join(f.root.path, "external", repo_name), + file_name = file_name, + ) - # handle virtual imports - if file_name.startswith("_virtual_imports"): - repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2]) - file_name = file_name.split("/", 2)[-1] - proto_paths.append(genfile_path + "/external/" + repo_name) - proto_file_names.append(file_name) - else: - proto_file_names.append(s.path) - return ([" --proto_path=" + path for path in proto_paths], proto_file_names) +def _map_root_path(f): + return _canonicalize_proto_path_oss(f).proto_path + +def _map_short_path(f): + return _canonicalize_proto_path_oss(f).file_name def _get_proto_provider(dep): """Get the provider for protocol buffers from a dependency.
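For orientation, the rewritten action invokes the proto compiler via `${@:3} < $1 > $2`: $1 is the text-format input, $2 the binary output, and $3 onward the compiler plus its flags. Under the argument order set up in the next hunk, the expansion looks roughly like the following (a sketch; `Foo`, `<genfiles>`, `<repo>`, and the file names are placeholders, not part of this change):

    protoc --encode=Foo --proto_path=. --proto_path=<genfiles>/external/<repo> dir/foo.proto < foo.textpb > foo.binarypb

with _map_root_path contributing one deduplicated --proto_path per canonicalized root and _map_short_path the root-relative proto file names.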
@@ -90,25 +94,37 @@ def _encode_binary_proto_impl(ctx): sibling = textpb, ) - path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path) + args = ctx.actions.args() + args.add(textpb) + args.add(binarypb) + args.add(ctx.executable._proto_compiler) + args.add(ctx.attr.message_type, format = "--encode=%s") + args.add("--proto_path=.") + args.add_all( + all_protos, + map_each = _map_root_path, + format_each = "--proto_path=%s", + uniquify = True, + ) + args.add_all( + all_protos, + map_each = _map_short_path, + uniquify = True, + ) # Note: the combination of absolute_paths and proto_path, as well as the exact # order of gendir before ., is needed for the proto compiler to resolve # import statements that reference proto files produced by a genrule. ctx.actions.run_shell( - tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler], - outputs = [binarypb], - command = " ".join( - [ - ctx.executable._proto_compiler.path, - "--encode=" + ctx.attr.message_type, - "--proto_path=" + ctx.genfiles_dir.path, - "--proto_path=" + ctx.bin_dir.path, - "--proto_path=.", - ] + path_list + file_list + - ["<", textpb.path, ">", binarypb.path], + tools = depset( + direct = [textpb, ctx.executable._proto_compiler], + transitive = [all_protos], ), + outputs = [binarypb], + command = "${@:3} < $1 > $2", + arguments = [args], mnemonic = "EncodeProto", + toolchain = None, ) output_depset = depset([binarypb]) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 496094c5f..9a570d524 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -104,7 +104,7 @@ cc_library( srcs = ["deleting_file.cc"], hdrs = ["deleting_file.h"], deps = [ - "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_log", ], ) @@ -119,6 +119,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", "@eigen_archive//:eigen3", ], ) @@ -155,11 +156,12 @@ cc_library( "//mediapipe/framework/port:aligned_malloc_and_free", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:source_location", "//mediapipe/framework/tool:type_util", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ] + select({ @@ -206,7 +208,6 @@ cc_library( "//mediapipe/framework/formats/annotation:locus_cc_proto", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:point", "//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:ret_check", @@ -214,6 +215,8 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -234,6 +237,7 @@ cc_library( ":location", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:opencv_imgproc", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) @@ -339,6 +343,7 @@ cc_library( 
"//mediapipe/framework/port:logging", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer_format", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/synchronization", ] + select({ "//conditions:default": [ @@ -363,6 +368,7 @@ cc_library( ":image_frame_pool", "//mediapipe/framework:port", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ] + select({ @@ -400,6 +406,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/log:absl_check", ], ) @@ -427,6 +434,17 @@ cc_test( ], ) +# Used by vendor processes that don't have access to libandroid.so, but want to use AHardwareBuffer. +config_setting( + name = "android_link_native_window", + define_values = { + "MEDIAPIPE_ANDROID_LINK_NATIVE_WINDOW": "1", + "MEDIAPIPE_NO_JNI": "1", + }, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:private"], +) + cc_library( name = "tensor", srcs = @@ -449,7 +467,13 @@ cc_library( "//conditions:default": [], }), defines = select({ + # Excludes AHardwareBuffer features from vendor processes "//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"], + # unless they're linked against nativewindow. + ":android_link_native_window": [ + "MEDIAPIPE_ANDROID_LINK_NATIVE_WINDOW", + "MEDIAPIPE_NO_JNI", + ], "//conditions:default": [], }), linkopts = select({ @@ -462,11 +486,15 @@ cc_library( "//mediapipe:android": [ "-landroid", ], + ":android_link_native_window": [ + "-lnativewindow", # Provides to vendor processes on Android API >= 26. + ], }), deps = [ "//mediapipe/framework:port", - "//mediapipe/framework/port:logging", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ] + select({ @@ -481,6 +509,7 @@ cc_library( cc_test( name = "tensor_test", srcs = ["tensor_test.cc"], + tags = ["not_run:arm"], deps = [ ":tensor", "//mediapipe/framework/port:gtest_main", @@ -499,7 +528,7 @@ cc_library( hdrs = ["frame_buffer.h"], deps = [ "//mediapipe/framework/port:integral_types", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/framework/formats/body_rig.proto b/mediapipe/framework/formats/body_rig.proto index 5420ccc10..88964d995 100644 --- a/mediapipe/framework/formats/body_rig.proto +++ b/mediapipe/framework/formats/body_rig.proto @@ -19,7 +19,7 @@ package mediapipe; // Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of // the joint and its visibility. message Joint { - // Joint rotation in 6D contineous representation ordered as + // Joint rotation in 6D continuous representation ordered as // [a1, b1, a2, b2, a3, b3]. 
// // Such representation is more suitable for NN model training and can be diff --git a/mediapipe/framework/formats/deleting_file.cc b/mediapipe/framework/formats/deleting_file.cc index 977a78940..b759a5f64 100644 --- a/mediapipe/framework/formats/deleting_file.cc +++ b/mediapipe/framework/formats/deleting_file.cc @@ -17,7 +17,7 @@ #include -#include "mediapipe/framework/port/logging.h" +#include "absl/log/absl_log.h" namespace mediapipe { @@ -27,7 +27,7 @@ DeletingFile::DeletingFile(const std::string& path, bool delete_on_destruction) DeletingFile::~DeletingFile() { if (delete_on_destruction_) { if (remove(path_.c_str()) != 0) { - LOG(ERROR) << "Unable to delete file: " << path_; + ABSL_LOG(ERROR) << "Unable to delete file: " << path_; } } } diff --git a/mediapipe/framework/formats/frame_buffer.cc b/mediapipe/framework/formats/frame_buffer.cc index 930a3651a..743de8121 100644 --- a/mediapipe/framework/formats/frame_buffer.cc +++ b/mediapipe/framework/formats/frame_buffer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/framework/formats/frame_buffer.h b/mediapipe/framework/formats/frame_buffer.h index 32ba41a2d..71e154572 100644 --- a/mediapipe/framework/formats/frame_buffer.h +++ b/mediapipe/framework/formats/frame_buffer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ limitations under the License. #include -#include "absl/log/check.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "mediapipe/framework/port/integral_types.h" @@ -147,15 +147,15 @@ class FrameBuffer { // Returns plane indexed by the input `index`. const Plane& plane(int index) const { - CHECK_GE(index, 0); - CHECK_LT(static_cast(index), planes_.size()); + ABSL_CHECK_GE(index, 0); + ABSL_CHECK_LT(static_cast(index), planes_.size()); return planes_[index]; } // Returns mutable plane indexed by the input `index`. Plane mutable_plane(int index) { - CHECK_GE(index, 0); - CHECK_LT(static_cast(index), planes_.size()); + ABSL_CHECK_GE(index, 0); + ABSL_CHECK_LT(static_cast(index), planes_.size()); return planes_[index]; } diff --git a/mediapipe/framework/formats/image.cc b/mediapipe/framework/formats/image.cc index 1ef7e3cb9..b37d95aad 100644 --- a/mediapipe/framework/formats/image.cc +++ b/mediapipe/framework/formats/image.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/formats/image.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/type_map.h" #if !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h index ffb6362f3..936a3554e 100644 --- a/mediapipe/framework/formats/image.h +++ b/mediapipe/framework/formats/image.h @@ -113,11 +113,11 @@ class Image { #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // !MEDIAPIPE_DISABLE_GPU - // Get a GPU view. Automatically uploads from CPU if needed. - const mediapipe::GpuBuffer GetGpuBuffer() const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_ == false) ConvertToGpu(); -#endif // !MEDIAPIPE_DISABLE_GPU + // Provides access to the underlying GpuBuffer storage.
+ // Automatically uploads from CPU to GPU if needed and requested through the + // `upload_to_gpu` argument. + const mediapipe::GpuBuffer GetGpuBuffer(bool upload_to_gpu = true) const { + if (!use_gpu_ && upload_to_gpu) ConvertToGpu(); return gpu_buffer_; } diff --git a/mediapipe/framework/formats/image_frame.cc b/mediapipe/framework/formats/image_frame.cc index 2de819a35..472da76a9 100644 --- a/mediapipe/framework/formats/image_frame.cc +++ b/mediapipe/framework/formats/image_frame.cc @@ -23,10 +23,11 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/aligned_malloc_and_free.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" namespace mediapipe { @@ -98,8 +99,8 @@ void ImageFrame::Reset(ImageFormat::Format format, int width, int height, format_ = format; width_ = width; height_ = height; - CHECK_NE(ImageFormat::UNKNOWN, format_); - CHECK(IsValidAlignmentNumber(alignment_boundary)); + ABSL_CHECK_NE(ImageFormat::UNKNOWN, format_); + ABSL_CHECK(IsValidAlignmentNumber(alignment_boundary)); width_step_ = width * NumberOfChannels() * ByteDepth(); if (alignment_boundary == 1) { pixel_data_ = {new uint8_t[height * width_step_], @@ -124,8 +125,8 @@ void ImageFrame::AdoptPixelData(ImageFormat::Format format, int width, height_ = height; width_step_ = width_step; - CHECK_NE(ImageFormat::UNKNOWN, format_); - CHECK_GE(width_step_, width * NumberOfChannels() * ByteDepth()); + ABSL_CHECK_NE(ImageFormat::UNKNOWN, format_); + ABSL_CHECK_GE(width_step_, width * NumberOfChannels() * ByteDepth()); pixel_data_ = {pixel_data, deleter}; } @@ -136,8 +137,8 @@ std::unique_ptr ImageFrame::Release() { void ImageFrame::InternalCopyFrom(int width, int height, int width_step, int channel_size, const uint8_t* pixel_data) { - CHECK_EQ(width_, width); - CHECK_EQ(height_, height); + ABSL_CHECK_EQ(width_, width); + ABSL_CHECK_EQ(height_, height); // row_bytes = channel_size * num_channels * width const int row_bytes = channel_size * NumberOfChannels() * width; if (width_step == 0) { @@ -187,8 +188,8 @@ void ImageFrame::SetAlignmentPaddingAreas() { if (!pixel_data_) { return; } - CHECK_GE(width_, 1); - CHECK_GE(height_, 1); + ABSL_CHECK_GE(width_, 1); + ABSL_CHECK_GE(height_, 1); const int pixel_size = ByteDepth() * NumberOfChannels(); const int padding_size = width_step_ - width_ * pixel_size; @@ -222,7 +223,7 @@ bool ImageFrame::IsContiguous() const { } bool ImageFrame::IsAligned(uint32_t alignment_boundary) const { - CHECK(IsValidAlignmentNumber(alignment_boundary)); + ABSL_CHECK(IsValidAlignmentNumber(alignment_boundary)); if (!pixel_data_) { return false; } @@ -287,7 +288,7 @@ int ImageFrame::NumberOfChannelsForFormat(ImageFormat::Format format) { case ImageFormat::SBGRA: return 4; default: - LOG(FATAL) << InvalidFormatString(format); + ABSL_LOG(FATAL) << InvalidFormatString(format); } } @@ -318,7 +319,7 @@ int ImageFrame::ChannelSizeForFormat(ImageFormat::Format format) { case ImageFormat::SBGRA: return sizeof(uint8_t); default: - LOG(FATAL) << InvalidFormatString(format); + ABSL_LOG(FATAL) << InvalidFormatString(format); } } @@ -349,7 +350,7 @@ int ImageFrame::ByteDepthForFormat(ImageFormat::Format format) { case ImageFormat::SBGRA: return 1; default: - LOG(FATAL) << InvalidFormatString(format); + ABSL_LOG(FATAL) << InvalidFormatString(format); } } @@ -359,7 +360,7 @@ void ImageFrame::CopyFrom(const ImageFrame& 
image_frame, Reset(image_frame.Format(), image_frame.Width(), image_frame.Height(), alignment_boundary); - CHECK_EQ(format_, image_frame.Format()); + ABSL_CHECK_EQ(format_, image_frame.Format()); InternalCopyFrom(image_frame.Width(), image_frame.Height(), image_frame.WidthStep(), image_frame.ChannelSize(), image_frame.PixelData()); @@ -382,10 +383,10 @@ void ImageFrame::CopyPixelData(ImageFormat::Format format, int width, } void ImageFrame::CopyToBuffer(uint8_t* buffer, int buffer_size) const { - CHECK(buffer); - CHECK_EQ(1, ByteDepth()); + ABSL_CHECK(buffer); + ABSL_CHECK_EQ(1, ByteDepth()); const int data_size = width_ * height_ * NumberOfChannels(); - CHECK_LE(data_size, buffer_size); + ABSL_CHECK_LE(data_size, buffer_size); if (IsContiguous()) { // The data is stored contiguously, we can just copy. const uint8_t* src = reinterpret_cast(pixel_data_.get()); @@ -397,10 +398,10 @@ void ImageFrame::CopyToBuffer(uint8_t* buffer, int buffer_size) const { } void ImageFrame::CopyToBuffer(uint16_t* buffer, int buffer_size) const { - CHECK(buffer); - CHECK_EQ(2, ByteDepth()); + ABSL_CHECK(buffer); + ABSL_CHECK_EQ(2, ByteDepth()); const int data_size = width_ * height_ * NumberOfChannels(); - CHECK_LE(data_size, buffer_size); + ABSL_CHECK_LE(data_size, buffer_size); if (IsContiguous()) { // The data is stored contiguously, we can just copy. const uint16_t* src = reinterpret_cast(pixel_data_.get()); @@ -412,10 +413,10 @@ void ImageFrame::CopyToBuffer(uint16_t* buffer, int buffer_size) const { } void ImageFrame::CopyToBuffer(float* buffer, int buffer_size) const { - CHECK(buffer); - CHECK_EQ(4, ByteDepth()); + ABSL_CHECK(buffer); + ABSL_CHECK_EQ(4, ByteDepth()); const int data_size = width_ * height_ * NumberOfChannels(); - CHECK_LE(data_size, buffer_size); + ABSL_CHECK_LE(data_size, buffer_size); if (IsContiguous()) { // The data is stored contiguously, we can just copy. const float* src = reinterpret_cast(pixel_data_.get()); diff --git a/mediapipe/framework/formats/image_multi_pool.cc b/mediapipe/framework/formats/image_multi_pool.cc index 655064d36..a38e30a67 100644 --- a/mediapipe/framework/formats/image_multi_pool.cc +++ b/mediapipe/framework/formats/image_multi_pool.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" @@ -43,7 +44,7 @@ ImageMultiPool::SimplePoolGpu ImageMultiPool::MakeSimplePoolGpu( IBufferSpec spec) { OSType cv_format = mediapipe::CVPixelFormatForGpuBufferFormat( GpuBufferFormatForImageFormat(spec.format)); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; + ABSL_CHECK_NE(cv_format, -1) << "unsupported pixel format"; return MakeCFHolderAdopting(mediapipe::CreateCVPixelBufferPool( spec.width, spec.height, cv_format, kKeepCount, 0.1 /* max age in seconds */)); @@ -61,11 +62,11 @@ Image ImageMultiPool::GetBufferFromSimplePool( // pool to give us contiguous data. 
OSType cv_format = mediapipe::CVPixelFormatForGpuBufferFormat( mediapipe::GpuBufferFormatForImageFormat(spec.format)); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; + ABSL_CHECK_NE(cv_format, -1) << "unsupported pixel format"; CVPixelBufferRef buffer; CVReturn err = mediapipe::CreateCVPixelBufferWithoutPool( spec.width, spec.height, cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; + ABSL_CHECK(!err) << "Error creating pixel buffer: " << err; return Image(MakeCFHolderAdopting(buffer)); #else CVPixelBufferRef buffer; @@ -87,7 +88,7 @@ Image ImageMultiPool::GetBufferFromSimplePool( } }, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; + ABSL_CHECK(!err) << "Error creating pixel buffer: " << err; return Image(MakeCFHolderAdopting(buffer)); #endif // TARGET_IPHONE_SIMULATOR } @@ -188,7 +189,7 @@ Image ImageMultiPool::GetBuffer(int width, int height, bool use_gpu, ImageMultiPool::~ImageMultiPool() { #if !MEDIAPIPE_DISABLE_GPU #ifdef __APPLE__ - CHECK_EQ(texture_caches_.size(), 0) + ABSL_CHECK_EQ(texture_caches_.size(), 0) << "Failed to unregister texture caches before deleting pool"; #endif // defined(__APPLE__) #endif // !MEDIAPIPE_DISABLE_GPU @@ -199,8 +200,8 @@ ImageMultiPool::~ImageMultiPool() { void ImageMultiPool::RegisterTextureCache(mediapipe::CVTextureCacheType cache) { absl::MutexLock lock(&mutex_gpu_); - CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) + ABSL_CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == + texture_caches_.end()) << "Attempting to register a texture cache twice"; texture_caches_.emplace_back(cache); } @@ -210,7 +211,7 @@ void ImageMultiPool::UnregisterTextureCache( absl::MutexLock lock(&mutex_gpu_); auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); - CHECK(it != texture_caches_.end()) + ABSL_CHECK(it != texture_caches_.end()) << "Attempting to unregister an unknown texture cache"; texture_caches_.erase(it); } diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc index 498c7831f..387afb5e8 100644 --- a/mediapipe/framework/formats/image_opencv.cc +++ b/mediapipe/framework/formats/image_opencv.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/formats/image_opencv.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/logging.h" @@ -100,7 +101,7 @@ std::shared_ptr MatView(const mediapipe::Image* image) { auto owner = std::make_shared(const_cast(image)); uint8_t* data_ptr = owner->lock.Pixels(); - CHECK(data_ptr != nullptr); + ABSL_CHECK(data_ptr != nullptr); // Use Image to initialize in-place. Image still owns memory. 
if (steps[0] == sizes[1] * image->channels() * ImageFrame::ByteDepthForFormat(image->image_format())) { diff --git a/mediapipe/framework/formats/location.cc b/mediapipe/framework/formats/location.cc index 205edf191..b9dd97e74 100644 --- a/mediapipe/framework/formats/location.cc +++ b/mediapipe/framework/formats/location.cc @@ -18,13 +18,14 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/substitute.h" #include "mediapipe/framework/formats/annotation/locus.pb.h" #include "mediapipe/framework/formats/annotation/rasterization.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/point2.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -39,7 +40,7 @@ namespace { // the location_data, the tightest bounding box, that contains all pixels // encoded in the rasterizations. Rectangle_i MaskToRectangle(const LocationData& location_data) { - CHECK(location_data.mask().has_rasterization()); + ABSL_CHECK(location_data.mask().has_rasterization()); const auto& rasterization = location_data.mask().rasterization(); if (rasterization.interval_size() == 0) { return Rectangle_i(0, 0, 0, 0); @@ -63,7 +64,7 @@ Location::Location() {} Location::Location(const LocationData& location_data) : location_data_(location_data) { - CHECK(IsValidLocationData(location_data_)); + ABSL_CHECK(IsValidLocationData(location_data_)); } Location Location::CreateGlobalLocation() { @@ -152,15 +153,15 @@ bool Location::IsValidLocationData(const LocationData& location_data) { template <> Rectangle_i Location::GetBBox() const { - CHECK_EQ(LocationData::BOUNDING_BOX, location_data_.format()); + ABSL_CHECK_EQ(LocationData::BOUNDING_BOX, location_data_.format()); const auto& box = location_data_.bounding_box(); return Rectangle_i(box.xmin(), box.ymin(), box.width(), box.height()); } Location& Location::Scale(const float scale) { - CHECK(!location_data_.has_mask()) + ABSL_CHECK(!location_data_.has_mask()) << "Location mask scaling is not implemented."; - CHECK_GT(scale, 0.0f); + ABSL_CHECK_GT(scale, 0.0f); switch (location_data_.format()) { case LocationData::GLOBAL: { // Do nothing. @@ -187,7 +188,8 @@ Location& Location::Scale(const float scale) { break; } case LocationData::MASK: { - LOG(FATAL) << "Scaling for location data of type MASK is not supported."; + ABSL_LOG(FATAL) + << "Scaling for location data of type MASK is not supported."; break; } } @@ -232,7 +234,8 @@ Location& Location::Square(int image_width, int image_height) { break; } case LocationData::MASK: { - LOG(FATAL) << "Squaring for location data of type MASK is not supported."; + ABSL_LOG(FATAL) + << "Squaring for location data of type MASK is not supported."; break; } } @@ -247,7 +250,7 @@ namespace { // This function is intended to shift boundaries of intervals such that they // best fit within an image.
float BestShift(float min_value, float max_value, float range) { - CHECK_LE(min_value, max_value); + ABSL_CHECK_LE(min_value, max_value); const float value_range = max_value - min_value; if (value_range > range) { return 0.5f * (range - min_value - max_value); @@ -294,8 +297,8 @@ Location& Location::ShiftToFitBestIntoImage(int image_width, int image_height) { const float y_shift = BestShift(mask_bounding_box.xmin(), mask_bounding_box.xmax(), image_height); auto* mask = location_data_.mutable_mask(); - CHECK_EQ(image_width, mask->width()); - CHECK_EQ(image_height, mask->height()); + ABSL_CHECK_EQ(image_width, mask->width()); + ABSL_CHECK_EQ(image_height, mask->height()); for (auto& interval : *mask->mutable_rasterization()->mutable_interval()) { interval.set_y(interval.y() + y_shift); @@ -327,7 +330,7 @@ Location& Location::Crop(const Rectangle_i& crop_box) { break; } case LocationData::RELATIVE_BOUNDING_BOX: - LOG(FATAL) + ABSL_LOG(FATAL) << "Can't crop a relative bounding box using absolute coordinates. " "Use the 'Rectangle_f version of Crop() instead"; case LocationData::MASK: { @@ -361,7 +364,7 @@ Location& Location::Crop(const Rectangle_f& crop_box) { // Do nothing. break; case LocationData::BOUNDING_BOX: - LOG(FATAL) + ABSL_LOG(FATAL) << "Can't crop an absolute bounding box using relative coordinates. " "Use the 'Rectangle_i version of Crop() instead"; case LocationData::RELATIVE_BOUNDING_BOX: { @@ -377,8 +380,9 @@ Location& Location::Crop(const Rectangle_f& crop_box) { break; } case LocationData::MASK: - LOG(FATAL) << "Can't crop a mask using relative coordinates. Use the " - "'Rectangle_i' version of Crop() instead"; + ABSL_LOG(FATAL) + << "Can't crop a mask using relative coordinates. Use the " + "'Rectangle_i' version of Crop() instead"; } return *this; } @@ -418,7 +422,7 @@ Rectangle_i Location::ConvertToBBox(int image_width, } Rectangle_f Location::GetRelativeBBox() const { - CHECK_EQ(LocationData::RELATIVE_BOUNDING_BOX, location_data_.format()); + ABSL_CHECK_EQ(LocationData::RELATIVE_BOUNDING_BOX, location_data_.format()); const auto& box = location_data_.relative_bounding_box(); return Rectangle_f(box.xmin(), box.ymin(), box.width(), box.height()); } @@ -457,7 +461,7 @@ Rectangle_f Location::ConvertToRelativeBBox(int image_width, template <> ::mediapipe::BoundingBox Location::GetBBox<::mediapipe::BoundingBox>() const { - CHECK_EQ(LocationData::BOUNDING_BOX, location_data_.format()); + ABSL_CHECK_EQ(LocationData::BOUNDING_BOX, location_data_.format()); const auto& box = location_data_.bounding_box(); ::mediapipe::BoundingBox bounding_box; bounding_box.set_left_x(box.xmin()); @@ -480,7 +484,7 @@ template <> } std::vector Location::GetRelativeKeypoints() const { - CHECK_EQ(LocationData::RELATIVE_BOUNDING_BOX, location_data_.format()); + ABSL_CHECK_EQ(LocationData::RELATIVE_BOUNDING_BOX, location_data_.format()); std::vector keypoints; for (const auto& keypoint : location_data_.relative_keypoints()) { keypoints.emplace_back(Point2_f(keypoint.x(), keypoint.y())); diff --git a/mediapipe/framework/formats/location_opencv.cc b/mediapipe/framework/formats/location_opencv.cc index 6e15b299a..4b69cc6dc 100644 --- a/mediapipe/framework/formats/location_opencv.cc +++ b/mediapipe/framework/formats/location_opencv.cc @@ -14,11 +14,12 @@ #include "mediapipe/framework/formats/location_opencv.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/substitute.h" #include 
"mediapipe/framework/formats/annotation/rasterization.pb.h" #include "mediapipe/framework/formats/location.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/statusor.h" @@ -26,7 +27,7 @@ namespace mediapipe { namespace { Rectangle_i MaskToRectangle(const LocationData& location_data) { - CHECK(location_data.mask().has_rasterization()); + ABSL_CHECK(location_data.mask().has_rasterization()); const auto& rasterization = location_data.mask().rasterization(); if (rasterization.interval_size() == 0) { return Rectangle_i(0, 0, 0, 0); @@ -85,7 +86,7 @@ Location CreateBBoxLocation(const cv::Rect& rect) { std::unique_ptr GetCvMask(const Location& location) { const auto location_data = location.ConvertToProto(); - CHECK_EQ(LocationData::MASK, location_data.format()); + ABSL_CHECK_EQ(LocationData::MASK, location_data.format()); const auto& mask = location_data.mask(); std::unique_ptr mat( new cv::Mat(mask.height(), mask.width(), CV_8UC1, cv::Scalar(0))); @@ -108,7 +109,7 @@ std::unique_ptr ConvertToCvMask(const Location& location, image_width, image_height, location.ConvertToBBox(image_width, image_height)); if (!status_or_mat.ok()) { - LOG(ERROR) << status_or_mat.status().message(); + ABSL_LOG(ERROR) << status_or_mat.status().message(); return nullptr; } return std::move(status_or_mat).value(); @@ -120,15 +121,15 @@ std::unique_ptr ConvertToCvMask(const Location& location, // This should never happen; a new LocationData::Format enum was introduced // without updating this function's switch(...) to support it. #if !defined(MEDIAPIPE_MOBILE) && !defined(MEDIAPIPE_LITE) - LOG(ERROR) << "Location's LocationData has format not supported by " - "Location::ConvertToMask: " - << location_data.DebugString(); + ABSL_LOG(ERROR) << "Location's LocationData has format not supported by " + "Location::ConvertToMask: " + << location_data.DebugString(); #endif return nullptr; } void EnlargeLocation(Location& location, const float factor) { - CHECK_GT(factor, 0.0f); + ABSL_CHECK_GT(factor, 0.0f); if (factor == 1.0f) return; auto location_data = location.ConvertToProto(); switch (location_data.format()) { @@ -183,7 +184,7 @@ void EnlargeLocation(Location& location, const float factor) { template Location CreateCvMaskLocation(const cv::Mat_& mask) { - CHECK_EQ(1, mask.channels()) + ABSL_CHECK_EQ(1, mask.channels()) << "The specified cv::Mat mask should be single-channel."; LocationData location_data; diff --git a/mediapipe/framework/formats/matrix.cc b/mediapipe/framework/formats/matrix.cc index 42f2df5f8..34ffc6e74 100644 --- a/mediapipe/framework/formats/matrix.cc +++ b/mediapipe/framework/formats/matrix.cc @@ -15,6 +15,7 @@ #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" @@ -33,8 +34,8 @@ void MatrixDataProtoFromMatrix(const Matrix& matrix, MatrixData* matrix_data) { } void MatrixFromMatrixDataProto(const MatrixData& matrix_data, Matrix* matrix) { - CHECK_EQ(matrix_data.rows() * matrix_data.cols(), - matrix_data.packed_data_size()); + ABSL_CHECK_EQ(matrix_data.rows() * matrix_data.cols(), + matrix_data.packed_data_size()); if (matrix_data.layout() == MatrixData::ROW_MAJOR) { matrix->resize(matrix_data.cols(), matrix_data.rows()); } else { @@ -56,9 +57,9 @@ std::string MatrixAsTextProto(const Matrix& matrix) { } void MatrixFromTextProto(const std::string& text_proto, Matrix* matrix) { - 
CHECK(matrix); + ABSL_CHECK(matrix); MatrixData matrix_data; - CHECK(proto_ns::TextFormat::ParseFromString(text_proto, &matrix_data)); + ABSL_CHECK(proto_ns::TextFormat::ParseFromString(text_proto, &matrix_data)); MatrixFromMatrixDataProto(matrix_data, matrix); } #endif // !defined(MEDIAPIPE_MOBILE) && !defined(MEDIAPIPE_LITE) diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 919b82406..8f40202cf 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -39,11 +39,12 @@ cc_library( "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:point", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:framework", ], @@ -61,8 +62,9 @@ cc_test( "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/core:framework", ], ) diff --git a/mediapipe/framework/formats/motion/optical_flow_field.cc b/mediapipe/framework/formats/motion/optical_flow_field.cc index a96504192..fd9b8e300 100644 --- a/mediapipe/framework/formats/motion/optical_flow_field.cc +++ b/mediapipe/framework/formats/motion/optical_flow_field.cc @@ -18,6 +18,8 @@ #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/deps/mathutil.h" @@ -25,7 +27,6 @@ #include "mediapipe/framework/formats/location_opencv.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/point2.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/type_map.h" @@ -40,8 +41,8 @@ const float kFloFileHeaderOnRead = 202021.25; void CartesianToPolarCoordinates(const cv::Mat& cartesian, cv::Mat* magnitudes, cv::Mat* angles) { - CHECK(magnitudes != nullptr); - CHECK(angles != nullptr); + ABSL_CHECK(magnitudes != nullptr); + ABSL_CHECK(angles != nullptr); cv::Mat cartesian_components[2]; cv::split(cartesian, cartesian_components); cv::cartToPolar(cartesian_components[0], cartesian_components[1], *magnitudes, @@ -105,7 +106,7 @@ cv::Mat OpticalFlowField::GetVisualizationInternal( std::max(std::numeric_limits::epsilon(), MaxAbsoluteValueIgnoringHuge(magnitudes, kHugeToIgnore)); } - CHECK_LT(0, max_magnitude); + ABSL_CHECK_LT(0, max_magnitude); cv::Mat hsv = MakeVisualizationHsv(angles, magnitudes, max_magnitude); cv::Mat viz; cv::cvtColor(hsv, viz, 71 /*cv::COLOR_HSV2RGB_FULL*/); @@ -119,7 +120,7 @@ cv::Mat OpticalFlowField::GetVisualization() const { cv::Mat OpticalFlowField::GetVisualizationSaturatedAt( float max_magnitude) const { - CHECK_LT(0, max_magnitude) + ABSL_CHECK_LT(0, max_magnitude) << "Specified saturation magnitude must be positive."; return GetVisualizationInternal(max_magnitude, true); } @@ -147,9 +148,9 @@ void OpticalFlowField::Resize(int new_width, int 
new_height) { } void OpticalFlowField::CopyFromTensor(const tensorflow::Tensor& tensor) { - CHECK_EQ(tensorflow::DT_FLOAT, tensor.dtype()); - CHECK_EQ(3, tensor.dims()) << "Tensor must be height x width x 2."; - CHECK_EQ(2, tensor.dim_size(2)) << "Tensor must be height x width x 2."; + ABSL_CHECK_EQ(tensorflow::DT_FLOAT, tensor.dtype()); + ABSL_CHECK_EQ(3, tensor.dims()) << "Tensor must be height x width x 2."; + ABSL_CHECK_EQ(2, tensor.dim_size(2)) << "Tensor must be height x width x 2."; const int height = tensor.dim_size(0); const int width = tensor.dim_size(1); Allocate(width, height); @@ -163,8 +164,8 @@ void OpticalFlowField::CopyFromTensor(const tensorflow::Tensor& tensor) { } void OpticalFlowField::SetFromProto(const OpticalFlowFieldData& proto) { - CHECK_EQ(proto.width() * proto.height(), proto.dx_size()); - CHECK_EQ(proto.width() * proto.height(), proto.dy_size()); + ABSL_CHECK_EQ(proto.width() * proto.height(), proto.dx_size()); + ABSL_CHECK_EQ(proto.width() * proto.height(), proto.dy_size()); flow_data_.create(proto.height(), proto.width()); int i = 0; for (int r = 0; r < flow_data_.rows; ++r) { @@ -191,8 +192,8 @@ void OpticalFlowField::ConvertToProto(OpticalFlowFieldData* proto) const { bool OpticalFlowField::FollowFlow(float x, float y, float* new_x, float* new_y) const { - CHECK(new_x); - CHECK(new_y); + ABSL_CHECK(new_x); + ABSL_CHECK(new_y); if (x < 0 || x > flow_data_.cols - 1 || // horizontal bounds y < 0 || y > flow_data_.rows - 1) { // vertical bounds return false; @@ -205,10 +206,10 @@ bool OpticalFlowField::FollowFlow(float x, float y, float* new_x, cv::Point2f OpticalFlowField::InterpolatedFlowAt(float x, float y) const { // Sanity bounds checks. - CHECK_GE(x, 0); - CHECK_GE(y, 0); - CHECK_LE(x, flow_data_.cols - 1); - CHECK_LE(y, flow_data_.rows - 1); + ABSL_CHECK_GE(x, 0); + ABSL_CHECK_GE(y, 0); + ABSL_CHECK_LE(x, flow_data_.cols - 1); + ABSL_CHECK_LE(y, flow_data_.rows - 1); const int x0 = static_cast(std::floor(x)); const int y0 = static_cast(std::floor(y)); @@ -253,7 +254,7 @@ bool OpticalFlowField::AllWithinMargin(const OpticalFlowField& other, const cv::Point2f& other_motion = other.flow_data().at(r, c); if (!MathUtil::WithinMargin(this_motion.x, other_motion.x, margin) || !MathUtil::WithinMargin(this_motion.y, other_motion.y, margin)) { - LOG(INFO) << "First failure at" << r << " " << c; + ABSL_LOG(INFO) << "First failure at" << r << " " << c; return false; } } @@ -265,9 +266,9 @@ void OpticalFlowField::EstimateMotionConsistencyOcclusions( const OpticalFlowField& forward, const OpticalFlowField& backward, double spatial_distance_threshold, Location* occluded_mask, Location* disoccluded_mask) { - CHECK_EQ(forward.width(), backward.width()) + ABSL_CHECK_EQ(forward.width(), backward.width()) << "Flow fields have different widths."; - CHECK_EQ(forward.height(), backward.height()) + ABSL_CHECK_EQ(forward.height(), backward.height()) << "Flow fields have different heights."; if (occluded_mask != nullptr) { *occluded_mask = FindMotionInconsistentPixels(forward, backward, diff --git a/mediapipe/framework/formats/motion/optical_flow_field_test.cc b/mediapipe/framework/formats/motion/optical_flow_field_test.cc index fdce418fa..2647c2613 100644 --- a/mediapipe/framework/formats/motion/optical_flow_field_test.cc +++ b/mediapipe/framework/formats/motion/optical_flow_field_test.cc @@ -19,12 +19,13 @@ #include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/deps/file_path.h" #include 
"mediapipe/framework/formats/location_opencv.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "tensorflow/core/framework/tensor.h" namespace mediapipe { diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 9d75bcbaf..2f2bfaae4 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -17,9 +17,10 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" -#include "mediapipe/framework/port/logging.h" #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 @@ -81,7 +82,7 @@ void* AllocateVirtualMemory(size_t size) { vm_address_t data; auto error = vm_allocate(mach_task_self(), &data, AlignToPageSize(size), VM_FLAGS_ANYWHERE); - LOG_IF(FATAL, error != KERN_SUCCESS) + ABSL_LOG_IF(FATAL, error != KERN_SUCCESS) << "Can't allocate virtual memory for Tensor."; return reinterpret_cast(data); } @@ -113,12 +114,12 @@ void MtlBufferView::AllocateMtlBuffer(const Tensor& tensor, MtlBufferView MtlBufferView::GetReadView(const Tensor& tensor, id command_buffer) { - LOG_IF(FATAL, tensor.valid_ == Tensor::kValidNone) + ABSL_LOG_IF(FATAL, tensor.valid_ == Tensor::kValidNone) << "Tensor must be written prior to read from."; - LOG_IF(FATAL, - !(tensor.valid_ & (Tensor::kValidCpu | Tensor::kValidMetalBuffer))) - << "Tensor conversion between different GPU resources is not supported " - "yet."; + ABSL_LOG_IF( + FATAL, !(tensor.valid_ & (Tensor::kValidCpu | Tensor::kValidMetalBuffer))) + << "Tensor conversion between different GPU backing formats is not " + "supported yet."; auto lock(absl::make_unique(&tensor.view_mutex_)); tensor.valid_ |= Tensor::kValidMetalBuffer; AllocateMtlBuffer(tensor, [command_buffer device]); @@ -152,7 +153,7 @@ bool Tensor::NeedsHalfFloatRenderTarget() const { if (!has_color_buffer_float) { static bool has_color_buffer_half_float = gl_context_->HasGlExtension("EXT_color_buffer_half_float"); - LOG_IF(FATAL, !has_color_buffer_half_float) + ABSL_LOG_IF(FATAL, !has_color_buffer_half_float) << "EXT_color_buffer_half_float or WEBGL_color_buffer_float " << "required on web to use MP tensor"; return true; @@ -161,11 +162,11 @@ bool Tensor::NeedsHalfFloatRenderTarget() const { } Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const { - LOG_IF(FATAL, valid_ == kValidNone) + ABSL_LOG_IF(FATAL, valid_ == kValidNone) << "Tensor must be written prior to read from."; - LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidOpenGlTexture2d))) - << "Tensor conversion between different GPU resources is not supported " - "yet."; + ABSL_LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidOpenGlTexture2d))) + << "Tensor conversion between different GPU backing formats is not " + "supported yet."; auto lock = absl::make_unique(&view_mutex_); AllocateOpenGlTexture2d(); if (!(valid_ & kValidOpenGlTexture2d)) { @@ -266,7 +267,7 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, float power = std::log2(std::sqrt(static_cast(num_pixels))); w = 1 << static_cast(power); int h = (num_pixels + w - 1) / w; - LOG_IF(FATAL, w > max_size || h > max_size) + ABSL_LOG_IF(FATAL, w > max_size || h > max_size) << "The tensor can't fit into OpenGL Texture2D 
View."; *width = w; *height = h; @@ -276,7 +277,7 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, void Tensor::AllocateOpenGlTexture2d() const { if (opengl_texture2d_ == GL_INVALID_INDEX) { gl_context_ = mediapipe::GlContext::GetCurrent(); - LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; + ABSL_LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenTextures(1, &opengl_texture2d_); glBindTexture(GL_TEXTURE_2D, opengl_texture2d_); // Texture2D represents a buffer with computable data so should be fetched @@ -302,7 +303,7 @@ void Tensor::AllocateOpenGlTexture2d() const { // once for OES_texture_float extension, to save time. static bool has_oes_extension = gl_context_->HasGlExtension("OES_texture_float"); - LOG_IF(FATAL, !has_oes_extension) + ABSL_LOG_IF(FATAL, !has_oes_extension) << "OES_texture_float extension required in order to use MP tensor " << "with GLES 2.0"; // Allocate the image data; note that it's no longer RGBA32F, so will be @@ -328,14 +329,15 @@ void Tensor::AllocateOpenGlTexture2d() const { #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { - LOG_IF(FATAL, valid_ == kValidNone) + ABSL_LOG_IF(FATAL, valid_ == kValidNone) << "Tensor must be written prior to read from."; - LOG_IF(FATAL, !(valid_ & (kValidCpu | + ABSL_LOG_IF(FATAL, !(valid_ & (kValidCpu | #ifdef MEDIAPIPE_TENSOR_USE_AHWB - kValidAHardwareBuffer | + kValidAHardwareBuffer | #endif // MEDIAPIPE_TENSOR_USE_AHWB - kValidOpenGlBuffer))) - << "Tensor conversion between different GPU resources is not supported."; + kValidOpenGlBuffer))) + << "Tensor conversion between different GPU backing formats is not " + "supported yet."; auto lock(absl::make_unique(&view_mutex_)); AllocateOpenGlBuffer(); if (!(valid_ & kValidOpenGlBuffer)) { @@ -346,7 +348,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { void* ptr = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT); - CHECK(ptr) << "glMapBufferRange failed: " << glGetError(); + ABSL_CHECK(ptr) << "glMapBufferRange failed: " << glGetError(); std::memcpy(ptr, cpu_buffer_, bytes()); glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); } @@ -373,7 +375,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView( void Tensor::AllocateOpenGlBuffer() const { if (opengl_buffer_ == GL_INVALID_INDEX) { gl_context_ = mediapipe::GlContext::GetCurrent(); - LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; + ABSL_LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; glGenBuffers(1, &opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); if (!use_ahwb_ || !AllocateAhwbMapToSsbo()) { @@ -527,7 +529,7 @@ void Tensor::Invalidate() { Tensor::CpuReadView Tensor::GetCpuReadView() const { auto lock = absl::make_unique(&view_mutex_); - LOG_IF(FATAL, valid_ == kValidNone) + ABSL_LOG_IF(FATAL, valid_ == kValidNone) << "Tensor must be written prior to read from."; #ifdef MEDIAPIPE_TENSOR_USE_AHWB if (__builtin_available(android 26, *)) { @@ -536,7 +538,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { valid_ |= kValidCpu; return {ptr, std::move(lock), [ahwb = ahwb_] { auto error = AHardwareBuffer_unlock(ahwb, nullptr); - CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + ABSL_CHECK(error == 0) << "AHardwareBuffer_unlock " << error; }}; } } @@ -547,7 +549,7 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { // 
GPU-to-CPU synchronization and read-back. #if MEDIAPIPE_METAL_ENABLED if (valid_ & kValidMetalBuffer) { - LOG_IF(FATAL, !mtl_resources_->command_buffer) + ABSL_LOG_IF(FATAL, !mtl_resources_->command_buffer) << "Metal -> CPU synchronization " "requires MTLCommandBuffer to be set."; if (mtl_resources_->command_buffer) { @@ -620,7 +622,7 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView( if (ptr) { return {ptr, std::move(lock), [ahwb = ahwb_, fence_fd = &fence_fd_] { auto error = AHardwareBuffer_unlock(ahwb, fence_fd); - CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + ABSL_CHECK(error == 0) << "AHardwareBuffer_unlock " << error; }}; } } diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 1d670d805..701707ded 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -25,16 +25,21 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/tensor/internal.h" #include "mediapipe/framework/port.h" -#ifndef MEDIAPIPE_NO_JNI +// Supported use cases for tensor_ahwb: +// 1. Native code running in Android apps. +// 2. Android vendor processes linked against nativewindow. +#if !defined(MEDIAPIPE_NO_JNI) || defined(MEDIAPIPE_ANDROID_LINK_NATIVE_WINDOW) #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) #define MEDIAPIPE_TENSOR_USE_AHWB 1 #endif // __ANDROID_API__ >= 26 || // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) -#endif // MEDIAPIPE_NO_JNI +#endif // !defined(MEDIAPIPE_NO_JNI) || + // defined(MEDIAPIPE_ANDROID_LINK_NATIVE_WINDOW) #ifdef MEDIAPIPE_TENSOR_USE_AHWB #include @@ -117,11 +122,18 @@ class Tensor { Shape() = default; Shape(std::initializer_list dimensions) : dims(dimensions) {} Shape(const std::vector& dimensions) : dims(dimensions) {} + Shape(std::initializer_list dimensions, bool is_dynamic) + : dims(dimensions), is_dynamic(is_dynamic) {} + Shape(const std::vector& dimensions, bool is_dynamic) + : dims(dimensions), is_dynamic(is_dynamic) {} int num_elements() const { return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); } std::vector dims; + // The Tensor has dynamic rather than static shape so the TFLite interpreter + // needs to be reallocated. Only relevant for CPU. + bool is_dynamic = false; }; // Quantization parameters corresponding to the zero_point and scale value // made available by TfLite quantized (uint8/int8) tensors. 
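Before the next hunk, a usage sketch of the dynamic-shape constructors just added (it mirrors the TestDynamic case added to tensor_test.cc later in this change; the function name is illustrative):

    #include "mediapipe/framework/formats/tensor.h"

    // A shape flagged as dynamic tells CPU/TFLite consumers that the
    // interpreter may need to be reallocated when dimensions change,
    // rather than reusing statically sized buffers.
    mediapipe::Tensor MakeDynamicTensor() {
      return mediapipe::Tensor(
          mediapipe::Tensor::ElementType::kFloat32,
          mediapipe::Tensor::Shape({1, 2, 3, 4}, /*is_dynamic=*/true));
    }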
@@ -193,12 +205,12 @@ class Tensor { } int file_descriptor() const { return file_descriptor_; } void SetReadingFinishedFunc(FinishingFunc&& func) { - CHECK(ahwb_written_) + ABSL_CHECK(ahwb_written_) << "AHWB write view can't accept 'reading finished callback'"; *ahwb_written_ = std::move(func); } void SetWritingFinishedFD(int fd, FinishingFunc func = nullptr) { - CHECK(fence_fd_) + ABSL_CHECK(fence_fd_) << "AHWB read view can't accept 'writing finished file descriptor'"; *fence_fd_ = fd; *ahwb_written_ = std::move(func); diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 525f05f31..339148e94 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -7,9 +7,10 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gl_base.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB @@ -208,12 +209,13 @@ class DelayedReleaser { Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const { auto lock(absl::make_unique(&view_mutex_)); - CHECK(valid_ != kValidNone) << "Tensor must be written prior to read from."; - CHECK(!(valid_ & kValidOpenGlTexture2d)) + ABSL_CHECK(valid_ != kValidNone) + << "Tensor must be written prior to read from."; + ABSL_CHECK(!(valid_ & kValidOpenGlTexture2d)) << "Tensor conversion between OpenGL texture and AHardwareBuffer is not " "supported."; bool transfer = !ahwb_; - CHECK(AllocateAHardwareBuffer()) + ABSL_CHECK(AllocateAHardwareBuffer()) << "AHardwareBuffer is not supported on the target system."; valid_ |= kValidAHardwareBuffer; if (transfer) { @@ -253,7 +255,7 @@ void Tensor::CreateEglSyncAndFd() const { Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( int size_alignment) const { auto lock(absl::make_unique(&view_mutex_)); - CHECK(AllocateAHardwareBuffer(size_alignment)) + ABSL_CHECK(AllocateAHardwareBuffer(size_alignment)) << "AHardwareBuffer is not supported on the target system."; valid_ = kValidAHardwareBuffer; return {ahwb_, @@ -319,7 +321,7 @@ void Tensor::MoveCpuOrSsboToAhwb() const { if (__builtin_available(android 26, *)) { auto error = AHardwareBuffer_lock( ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest); - CHECK(error == 0) << "AHardwareBuffer_lock " << error; + ABSL_CHECK(error == 0) << "AHardwareBuffer_lock " << error; } if (valid_ & kValidCpu) { std::memcpy(dest, cpu_buffer_, bytes()); @@ -342,11 +344,12 @@ void Tensor::MoveCpuOrSsboToAhwb() const { // of the Ahwb at the next request to the OpenGlBufferView. valid_ &= ~kValidOpenGlBuffer; } else { - LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB."; + ABSL_LOG(FATAL) << "Can't convert tensor with mask " << valid_ + << " into AHWB."; } if (__builtin_available(android 26, *)) { auto error = AHardwareBuffer_unlock(ahwb_, nullptr); - CHECK(error == 0) << "AHardwareBuffer_unlock " << error; + ABSL_CHECK(error == 0) << "AHardwareBuffer_unlock " << error; } } @@ -421,9 +424,10 @@ void* Tensor::MapAhwbToCpuRead() const { // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. 
gl_context_->Run([]() { glFinish(); }); } else if (valid_ & kValidAHardwareBuffer) { - CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the " - "completion function to be set"; - CHECK(ahwb_written_(true)) + ABSL_CHECK(ahwb_written_) + << "Ahwb-to-Cpu synchronization requires the " + "completion function to be set"; + ABSL_CHECK(ahwb_written_(true)) << "An error occurred while waiting for the buffer to be written"; } } @@ -431,7 +435,7 @@ void* Tensor::MapAhwbToCpuRead() const { auto error = AHardwareBuffer_lock(ahwb_, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, ssbo_written_, nullptr, &ptr); - CHECK(error == 0) << "AHardwareBuffer_lock " << error; + ABSL_CHECK(error == 0) << "AHardwareBuffer_lock " << error; close(ssbo_written_); ssbo_written_ = -1; return ptr; @@ -449,7 +453,7 @@ void* Tensor::MapAhwbToCpuWrite() const { void* ptr; auto error = AHardwareBuffer_lock( ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, -1, nullptr, &ptr); - CHECK(error == 0) << "AHardwareBuffer_lock " << error; + ABSL_CHECK(error == 0) << "AHardwareBuffer_lock " << error; return ptr; } } diff --git a/mediapipe/framework/formats/tensor_test.cc b/mediapipe/framework/formats/tensor_test.cc index 4ad4e18eb..468af4ab9 100644 --- a/mediapipe/framework/formats/tensor_test.cc +++ b/mediapipe/framework/formats/tensor_test.cc @@ -2,6 +2,7 @@ #include #include +#include #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -34,6 +35,17 @@ TEST(General, TestDataTypes) { EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool)); } +TEST(General, TestDynamic) { + Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true)); + EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4); + EXPECT_TRUE(t1.shape().is_dynamic); + + std::vector t2_dims = {4, 3, 2, 3}; + Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true)); + EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3); + EXPECT_TRUE(t2.shape().is_dynamic); +} + TEST(Cpu, TestMemoryAllocation) { Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3}); auto v1 = t1.GetCpuWriteView(); diff --git a/mediapipe/framework/graph_output_stream.cc b/mediapipe/framework/graph_output_stream.cc index de024dfe5..e456c6535 100644 --- a/mediapipe/framework/graph_output_stream.cc +++ b/mediapipe/framework/graph_output_stream.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/graph_output_stream.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/status.h" @@ -153,7 +154,7 @@ void OutputStreamPollerImpl::Reset() { } void OutputStreamPollerImpl::SetMaxQueueSize(int queue_size) { - CHECK(queue_size >= -1) + ABSL_CHECK(queue_size >= -1) << "Max queue size must be either -1 or non-negative."; input_stream_handler_->SetMaxQueueSize(queue_size); } @@ -175,7 +176,7 @@ void OutputStreamPollerImpl::NotifyError() { } bool OutputStreamPollerImpl::Next(Packet* packet) { - CHECK(packet); + ABSL_CHECK(packet); bool empty_queue = true; bool timestamp_bound_changed = false; Timestamp min_timestamp = Timestamp::Unset(); @@ -212,7 +213,7 @@ bool OutputStreamPollerImpl::Next(Packet* packet) { bool stream_is_done = false; *packet = input_stream_->PopPacketAtTimestamp( min_timestamp, &num_packets_dropped, &stream_is_done); - CHECK_EQ(num_packets_dropped, 0) + ABSL_CHECK_EQ(num_packets_dropped, 0) << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", num_packets_dropped, input_stream_->Name()); } else if (timestamp_bound_changed) { diff --git
a/mediapipe/framework/graph_output_stream.h b/mediapipe/framework/graph_output_stream.h index b541aec12..7308be111 100644 --- a/mediapipe/framework/graph_output_stream.h +++ b/mediapipe/framework/graph_output_stream.h @@ -22,6 +22,7 @@ #include "absl/base/attributes.h" #include "absl/base/thread_annotations.h" +#include "absl/log/absl_log.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/input_stream_handler.h" @@ -30,7 +31,6 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/timestamp.h" @@ -76,7 +76,7 @@ class GraphOutputStream { // TODO: Simplify this. We are forced to use an ISH just to // receive a packet, even though we do not need to do any of the things an ISH // normally does. The fact that we have to disable required overrides with - // LOG(FATAL) shows that this is the wrong interface. + // ABSL_LOG(FATAL) shows that this is the wrong interface. class GraphOutputStreamHandler : public InputStreamHandler { public: GraphOutputStreamHandler(std::shared_ptr<tool::TagMap> tag_map, @@ -88,15 +88,15 @@ class GraphOutputStream { protected: NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override { - LOG(FATAL) << "GraphOutputStreamHandler::GetNodeReadiness should " - "never be invoked."; + ABSL_LOG(FATAL) << "GraphOutputStreamHandler::GetNodeReadiness should " + "never be invoked."; return NodeReadiness::kNotReady; } void FillInputSet(Timestamp input_timestamp, InputStreamShardSet* input_set) override { - LOG(FATAL) << "GraphOutputStreamHandler::FillInputSet should " - "never be invoked."; + ABSL_LOG(FATAL) << "GraphOutputStreamHandler::FillInputSet should " + "never be invoked."; } }; diff --git a/mediapipe/framework/graph_service.h b/mediapipe/framework/graph_service.h index 51caf31f2..95f55bbd1 100644 --- a/mediapipe/framework/graph_service.h +++ b/mediapipe/framework/graph_service.h @@ -19,6 +19,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/status.h" @@ -44,7 +45,6 @@ class GraphServiceBase { constexpr GraphServiceBase(const char* key) : key(key) {} - virtual ~GraphServiceBase() = default; inline virtual absl::StatusOr<Packet> CreateDefaultObject() const { return DefaultInitializationUnsupported(); } @@ -52,14 +52,32 @@ class GraphServiceBase { const char* key; protected: + // `GraphService` objects, which derive from `GraphServiceBase`, are designed + // to be global constants and are never deleted through `GraphServiceBase`. + // Hence the protected, non-virtual destructor, which helps to make + // `GraphService` trivially destructible and properly definable as a global + // constant. + // + // A class with any virtual functions should have a destructor that is either + // public and virtual or else protected and non-virtual.
+ // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-virtual + ~GraphServiceBase() = default; + absl::Status DefaultInitializationUnsupported() const { return absl::UnimplementedError(absl::StrCat( "Graph service '", key, "' does not support default initialization")); } }; +// A global constant to refer to a service: +// - Requesting `CalculatorContract::UseService` from a calculator +// - Accessing `Calculator/SubgraphContext::Service` from a calculator/subgraph +// - Setting it before graph initialization via `CalculatorGraph::SetServiceObject` +// +// NOTE: In headers, define your graph service reference safely as follows: +// `inline constexpr GraphService<YourService> kYourService("YourService");` +// template <typename T> -class GraphService : public GraphServiceBase { +class GraphService final : public GraphServiceBase { public: using type = T; using packet_type = std::shared_ptr<T>; @@ -68,7 +86,7 @@ class GraphService : public GraphServiceBase { kDisallowDefaultInitialization) : GraphServiceBase(my_key), default_init_(default_init) {} - absl::StatusOr<Packet> CreateDefaultObject() const override { + absl::StatusOr<Packet> CreateDefaultObject() const final { if (default_init_ != kAllowDefaultInitialization) { return DefaultInitializationUnsupported(); } @@ -108,7 +126,7 @@ class ServiceBinding { public: bool IsAvailable() { return service_ != nullptr; } T& GetObject() { - CHECK(service_) << "Service is unavailable."; + ABSL_CHECK(service_) << "Service is unavailable."; return *service_; } diff --git a/mediapipe/framework/graph_service_manager_test.cc b/mediapipe/framework/graph_service_manager_test.cc index 1895a6f70..23d4af0df 100644 --- a/mediapipe/framework/graph_service_manager_test.cc +++ b/mediapipe/framework/graph_service_manager_test.cc @@ -7,7 +7,7 @@ namespace mediapipe { namespace { -const GraphService<int> kIntService("mediapipe::IntService"); +constexpr GraphService<int> kIntService("mediapipe::IntService"); } // namespace TEST(GraphServiceManager, SetGetServiceObject) { diff --git a/mediapipe/framework/graph_service_test.cc b/mediapipe/framework/graph_service_test.cc index 69992f212..0556aac63 100644 --- a/mediapipe/framework/graph_service_test.cc +++ b/mediapipe/framework/graph_service_test.cc @@ -14,6 +14,8 @@ #include "mediapipe/framework/graph_service.h" +#include <type_traits> + #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -159,7 +161,7 @@ TEST_F(GraphServiceTest, CreateDefault) { struct TestServiceData {}; -const GraphService<TestServiceData> kTestServiceAllowDefaultInitialization( +constexpr GraphService<TestServiceData> kTestServiceAllowDefaultInitialization( "kTestServiceAllowDefaultInitialization", GraphServiceBase::kAllowDefaultInitialization); @@ -272,9 +274,13 @@ TEST(AllowDefaultInitializationGraphServiceTest, HasSubstr("Service is unavailable."))); } -const GraphService<TestServiceData> kTestServiceDisallowDefaultInitialization( - "kTestServiceDisallowDefaultInitialization", - GraphServiceBase::kDisallowDefaultInitialization); +constexpr GraphService<TestServiceData> + kTestServiceDisallowDefaultInitialization( + "kTestServiceDisallowDefaultInitialization", + GraphServiceBase::kDisallowDefaultInitialization); + +static_assert(std::is_trivially_destructible_v<GraphService<TestServiceData>>, + "GraphService is not trivially destructible"); class FailOnUnavailableOptionalDisallowDefaultInitServiceCalculator : public CalculatorBase { diff --git a/mediapipe/framework/graph_validation_test.cc b/mediapipe/framework/graph_validation_test.cc index c98983838..3982adbe5 100644 ---
a/mediapipe/framework/graph_validation_test.cc +++ b/mediapipe/framework/graph_validation_test.cc @@ -19,6 +19,7 @@ #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -121,7 +122,7 @@ TEST(GraphValidationTest, InitializeGraphFromLinker) { TEST(GraphValidationTest, InitializeTemplateFromProtos) { mediapipe::tool::TemplateParser::Parser parser; CalculatorGraphTemplate config_1; - CHECK(parser.ParseFromString(R"( + ABSL_CHECK(parser.ParseFromString(R"( type: "PassThroughGraph" input_stream: % "INPUT:" + in_name % output_stream: "OUTPUT:stream_2" @@ -132,7 +133,7 @@ TEST(GraphValidationTest, InitializeTemplateFromProtos) { output_stream: "stream_2" # Same as input. } )", - &config_1)); + &config_1)); auto config_2 = ParseTextProtoOrDie(R"pb( input_stream: "INPUT:stream_1" output_stream: "OUTPUT:stream_2" diff --git a/mediapipe/framework/input_side_packet_handler.cc b/mediapipe/framework/input_side_packet_handler.cc index 9b01cc31a..b2eccf0db 100644 --- a/mediapipe/framework/input_side_packet_handler.cc +++ b/mediapipe/framework/input_side_packet_handler.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/input_side_packet_handler.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status_builder.h" @@ -82,7 +83,7 @@ absl::Status InputSidePacketHandler::SetInternal(CollectionItemId id, void InputSidePacketHandler::TriggerErrorCallback( const absl::Status& status) const { - CHECK(error_callback_); + ABSL_CHECK(error_callback_); error_callback_(status); } diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index a7bd9ef43..e222c2e6c 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/input_stream_handler.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "mediapipe/framework/collection_item_id.h" @@ -102,7 +103,7 @@ void InputStreamHandler::SetHeader(CollectionItemId id, const Packet& header) { return; } if (!input_stream_managers_.Get(id)->BackEdge()) { - CHECK_GT(unset_header_count_, 0); + ABSL_CHECK_GT(unset_header_count_, 0); if (unset_header_count_.fetch_sub(1, std::memory_order_acq_rel) == 1) { headers_ready_callback_(); } @@ -111,7 +112,7 @@ void InputStreamHandler::SetHeader(CollectionItemId id, const Packet& header) { void InputStreamHandler::UpdateInputShardHeaders( InputStreamShardSet* input_shards) { - CHECK(input_shards); + ABSL_CHECK(input_shards); for (CollectionItemId id = input_stream_managers_.BeginId(); id < input_stream_managers_.EndId(); ++id) { input_shards->Get(id).SetHeader(input_stream_managers_.Get(id)->Header()); @@ -198,7 +199,7 @@ bool InputStreamHandler::ScheduleInvocations(int max_allowance, TraceEvent(TraceEvent::READY_FOR_PROCESS) .set_node_id(calculator_context->NodeId())); } else { - CHECK(node_readiness == NodeReadiness::kReadyForClose); + ABSL_CHECK(node_readiness == NodeReadiness::kReadyForClose); // If any parallel invocations are in progress or a calculator context has // been prepared for Close(), we shouldn't prepare another calculator // context for Close(). 
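The hunks above and below all follow one mechanical pattern: drop the `mediapipe/framework/port/logging.h` shim, include the Abseil logging headers directly, and replace each `CHECK`/`LOG` macro with its `ABSL_`-prefixed equivalent (re-wrapping lines where the longer macro name overflows the 80-column limit). A minimal sketch of the resulting shape, using a hypothetical `Frobnicate()` helper that is not part of this patch:

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"

void Frobnicate(int* value) {
  // ABSL_CHECK* aborts with the streamed message when the condition fails,
  // exactly like the legacy CHECK* macros it replaces.
  ABSL_CHECK(value != nullptr) << "value must not be null.";
  ABSL_CHECK_GE(*value, 0) << "value must be non-negative.";
  ABSL_LOG(INFO) << "Frobnicating " << *value;
}

Behavior is unchanged; the point of the migration is to depend on the Abseil macros explicitly rather than through the MediaPipe port header.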
@@ -302,7 +303,7 @@ void InputStreamHandler::SetNextTimestampBound(CollectionItemId id, void InputStreamHandler::ClearCurrentInputs( CalculatorContext* calculator_context) { - CHECK(calculator_context); + ABSL_CHECK(calculator_context); calculator_context_manager_->PopInputTimestampFromContext(calculator_context); for (auto& input : calculator_context->Inputs()) { // Invokes InputStreamShard's private method to clear packet. @@ -317,18 +318,20 @@ void InputStreamHandler::Close() { } void InputStreamHandler::SetBatchSize(int batch_size) { - CHECK(!calculator_run_in_parallel_ || batch_size == 1) + ABSL_CHECK(!calculator_run_in_parallel_ || batch_size == 1) << "Batching cannot be combined with parallel execution."; - CHECK(!late_preparation_ || batch_size == 1) + ABSL_CHECK(!late_preparation_ || batch_size == 1) << "Batching cannot be combined with late preparation."; - CHECK_GE(batch_size, 1) << "Batch size has to be greater than or equal to 1."; + ABSL_CHECK_GE(batch_size, 1) + << "Batch size has to be greater than or equal to 1."; // Source nodes shouldn't specify batch_size even if it's set to 1. - CHECK_GE(NumInputStreams(), 0) << "Source nodes cannot batch input packets."; + ABSL_CHECK_GE(NumInputStreams(), 0) + << "Source nodes cannot batch input packets."; batch_size_ = batch_size; } void InputStreamHandler::SetLatePreparation(bool late_preparation) { - CHECK(batch_size_ == 1 || !late_preparation_) + ABSL_CHECK(batch_size_ == 1 || !late_preparation_) << "Batching cannot be combined with late preparation."; late_preparation_ = late_preparation; } @@ -404,15 +407,15 @@ Timestamp SyncSet::MinPacketTimestamp() const { void SyncSet::FillInputSet(Timestamp input_timestamp, InputStreamShardSet* input_set) { - CHECK(input_timestamp.IsAllowedInStream()); - CHECK(input_set); + ABSL_CHECK(input_timestamp.IsAllowedInStream()); + ABSL_CHECK(input_set); for (CollectionItemId id : stream_ids_) { const auto& stream = input_stream_handler_->input_stream_managers_.Get(id); int num_packets_dropped = 0; bool stream_is_done = false; Packet current_packet = stream->PopPacketAtTimestamp( input_timestamp, &num_packets_dropped, &stream_is_done); - CHECK_EQ(num_packets_dropped, 0) + ABSL_CHECK_EQ(num_packets_dropped, 0) << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", num_packets_dropped, stream->Name()); input_stream_handler_->AddPacketToShard( diff --git a/mediapipe/framework/input_stream_manager.cc b/mediapipe/framework/input_stream_manager.cc index 1af2e2cc8..fe63b62e3 100644 --- a/mediapipe/framework/input_stream_manager.cc +++ b/mediapipe/framework/input_stream_manager.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/packet.h" @@ -244,7 +245,7 @@ Timestamp InputStreamManager::MinTimestampOrBoundHelper() const Packet InputStreamManager::PopPacketAtTimestamp(Timestamp timestamp, int* num_packets_dropped, bool* stream_is_done) { - CHECK(enable_timestamps_); + ABSL_CHECK(enable_timestamps_); *num_packets_dropped = -1; *stream_is_done = false; bool queue_became_non_full = false; @@ -252,7 +253,7 @@ Packet InputStreamManager::PopPacketAtTimestamp(Timestamp timestamp, { absl::MutexLock stream_lock(&stream_mutex_); // Make sure timestamp didn't decrease from last time. 
- CHECK_LE(last_select_timestamp_, timestamp); + ABSL_CHECK_LE(last_select_timestamp_, timestamp); last_select_timestamp_ = timestamp; // Make sure AddPacket and SetNextTimestampBound are not called with @@ -299,7 +300,7 @@ Packet InputStreamManager::PopPacketAtTimestamp(Timestamp timestamp, } Packet InputStreamManager::PopQueueHead(bool* stream_is_done) { - CHECK(!enable_timestamps_); + ABSL_CHECK(!enable_timestamps_); *stream_is_done = false; bool queue_became_non_full = false; Packet packet; diff --git a/mediapipe/framework/input_stream_shard.cc b/mediapipe/framework/input_stream_shard.cc index 8e3348dd6..c7d1df8a3 100644 --- a/mediapipe/framework/input_stream_shard.cc +++ b/mediapipe/framework/input_stream_shard.cc @@ -14,12 +14,14 @@ #include "mediapipe/framework/input_stream_shard.h" +#include "absl/log/absl_check.h" + namespace mediapipe { void InputStreamShard::AddPacket(Packet&& value, bool is_done) { // A packet can be added if the shard is still active or the packet being // added is empty. An empty packet corresponds to absence of a packet. - CHECK(!is_done_ || value.IsEmpty()); + ABSL_CHECK(!is_done_ || value.IsEmpty()); packet_queue_.emplace(std::move(value)); is_done_ = is_done; } diff --git a/mediapipe/framework/legacy_calculator_support.h b/mediapipe/framework/legacy_calculator_support.h index 9378d14f0..6ec0d953b 100644 --- a/mediapipe/framework/legacy_calculator_support.h +++ b/mediapipe/framework/legacy_calculator_support.h @@ -66,7 +66,7 @@ class LegacyCalculatorSupport { }; }; -#if !defined(_MSC_VER) +#if !defined(_MSC_VER) || defined(__clang__) // We only declare this variable for two specializations of the template because // it is only meant to be used for these two types. // Note that, since these variables are members of specific template diff --git a/mediapipe/framework/mediapipe_cc_test.bzl b/mediapipe/framework/mediapipe_cc_test.bzl index fe0d44e0c..5e1daca7b 100644 --- a/mediapipe/framework/mediapipe_cc_test.bzl +++ b/mediapipe/framework/mediapipe_cc_test.bzl @@ -15,11 +15,12 @@ def mediapipe_cc_test( platforms = ["linux", "android", "ios", "wasm"], exclude_platforms = None, # ios_unit_test arguments - ios_minimum_os_version = "11.0", + ios_minimum_os_version = "12.0", # android_cc_test arguments open_gl_driver = None, emulator_mini_boot = True, requires_full_emulation = True, + android_devices = {}, # wasm_web_test arguments browsers = None, **kwargs): diff --git a/mediapipe/framework/output_side_packet_impl.cc b/mediapipe/framework/output_side_packet_impl.cc index 94bc518f8..dcb541408 100644 --- a/mediapipe/framework/output_side_packet_impl.cc +++ b/mediapipe/framework/output_side_packet_impl.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/output_side_packet_impl.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status_builder.h" @@ -42,7 +43,7 @@ void OutputSidePacketImpl::Set(const Packet& packet) { void OutputSidePacketImpl::AddMirror( InputSidePacketHandler* input_side_packet_handler, CollectionItemId id) { - CHECK(input_side_packet_handler); + ABSL_CHECK(input_side_packet_handler); mirrors_.emplace_back(input_side_packet_handler, id); } @@ -81,7 +82,7 @@ absl::Status OutputSidePacketImpl::SetInternal(const Packet& packet) { void OutputSidePacketImpl::TriggerErrorCallback( const absl::Status& status) const { - CHECK(error_callback_); + ABSL_CHECK(error_callback_); error_callback_(status); } diff --git 
a/mediapipe/framework/output_stream_handler.cc b/mediapipe/framework/output_stream_handler.cc index ba8f46718..377de6c88 100644 --- a/mediapipe/framework/output_stream_handler.cc +++ b/mediapipe/framework/output_stream_handler.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/output_stream_handler.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/output_stream_shard.h" @@ -31,7 +32,7 @@ absl::Status OutputStreamHandler::InitializeOutputStreamManagers( absl::Status OutputStreamHandler::SetupOutputShards( OutputStreamShardSet* output_shards) { - CHECK(output_shards); + ABSL_CHECK(output_shards); for (CollectionItemId id = output_stream_managers_.BeginId(); id < output_stream_managers_.EndId(); ++id) { OutputStreamManager* manager = output_stream_managers_.Get(id); @@ -52,7 +53,7 @@ void OutputStreamHandler::PrepareForRun( } void OutputStreamHandler::Open(OutputStreamShardSet* output_shards) { - CHECK(output_shards); + ABSL_CHECK(output_shards); PropagateOutputPackets(Timestamp::Unstarted(), output_shards); for (auto& manager : output_stream_managers_) { manager->PropagateHeader(); @@ -62,7 +63,7 @@ void OutputStreamHandler::Open(OutputStreamShardSet* output_shards) { void OutputStreamHandler::PrepareOutputs(Timestamp input_timestamp, OutputStreamShardSet* output_shards) { - CHECK(output_shards); + ABSL_CHECK(output_shards); for (CollectionItemId id = output_stream_managers_.BeginId(); id < output_stream_managers_.EndId(); ++id) { output_stream_managers_.Get(id)->ResetShard(&output_shards->Get(id)); @@ -79,7 +80,7 @@ void OutputStreamHandler::UpdateTaskTimestampBound(Timestamp timestamp) { if (task_timestamp_bound_ == timestamp) { return; } - CHECK_GT(timestamp, task_timestamp_bound_); + ABSL_CHECK_GT(timestamp, task_timestamp_bound_); task_timestamp_bound_ = timestamp; if (propagation_state_ == kPropagatingBound) { propagation_state_ = kPropagationPending; @@ -149,7 +150,7 @@ void OutputStreamHandler::Close(OutputStreamShardSet* output_shards) { void OutputStreamHandler::PropagateOutputPackets( Timestamp input_timestamp, OutputStreamShardSet* output_shards) { - CHECK(output_shards); + ABSL_CHECK(output_shards); for (CollectionItemId id = output_stream_managers_.BeginId(); id < output_stream_managers_.EndId(); ++id) { OutputStreamManager* manager = output_stream_managers_.Get(id); diff --git a/mediapipe/framework/output_stream_handler.h b/mediapipe/framework/output_stream_handler.h index 0b8dbed2c..cb6b2d6e1 100644 --- a/mediapipe/framework/output_stream_handler.h +++ b/mediapipe/framework/output_stream_handler.h @@ -25,6 +25,7 @@ // TODO: Move protos in another CL after the C++ code migration. 
#include "absl/base/thread_annotations.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator_context_manager.h" #include "mediapipe/framework/collection.h" @@ -63,7 +64,7 @@ class OutputStreamHandler { calculator_context_manager_(calculator_context_manager), options_(options), calculator_run_in_parallel_(calculator_run_in_parallel) { - CHECK(calculator_context_manager_); + ABSL_CHECK(calculator_context_manager_); } virtual ~OutputStreamHandler() = default; diff --git a/mediapipe/framework/output_stream_manager.cc b/mediapipe/framework/output_stream_manager.cc index b092313e2..0cb592943 100644 --- a/mediapipe/framework/output_stream_manager.cc +++ b/mediapipe/framework/output_stream_manager.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/output_stream_manager.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/input_stream_handler.h" #include "mediapipe/framework/port/status_builder.h" @@ -80,7 +81,7 @@ void OutputStreamManager::PropagateHeader() { void OutputStreamManager::AddMirror(InputStreamHandler* input_stream_handler, CollectionItemId id) { - CHECK(input_stream_handler); + ABSL_CHECK(input_stream_handler); mirrors_.emplace_back(input_stream_handler, id); } @@ -163,7 +164,7 @@ Timestamp OutputStreamManager::ComputeOutputTimestampBound( // TODO Consider moving the propagation logic to OutputStreamHandler. void OutputStreamManager::PropagateUpdatesToMirrors( Timestamp next_timestamp_bound, OutputStreamShard* output_stream_shard) { - CHECK(output_stream_shard); + ABSL_CHECK(output_stream_shard); { if (next_timestamp_bound != Timestamp::Unset()) { absl::MutexLock lock(&stream_mutex_); diff --git a/mediapipe/framework/output_stream_poller.h b/mediapipe/framework/output_stream_poller.h index 26c0e72b2..98ebda313 100644 --- a/mediapipe/framework/output_stream_poller.h +++ b/mediapipe/framework/output_stream_poller.h @@ -17,6 +17,7 @@ #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/graph_output_stream.h" namespace mediapipe { @@ -34,7 +35,7 @@ class OutputStreamPoller { // Resets OutputStramPollerImpl and cleans the internal packet queue. void Reset() { auto poller = internal_poller_impl_.lock(); - CHECK(poller) << "OutputStreamPollerImpl is already destroyed."; + ABSL_CHECK(poller) << "OutputStreamPollerImpl is already destroyed."; poller->Reset(); } @@ -50,14 +51,14 @@ class OutputStreamPoller { void SetMaxQueueSize(int queue_size) { auto poller = internal_poller_impl_.lock(); - CHECK(poller) << "OutputStreamPollerImpl is already destroyed."; + ABSL_CHECK(poller) << "OutputStreamPollerImpl is already destroyed."; return poller->SetMaxQueueSize(queue_size); } // Returns the number of packets in the queue. 
int QueueSize() { auto poller = internal_poller_impl_.lock(); - CHECK(poller) << "OutputStreamPollerImpl is already destroyed."; + ABSL_CHECK(poller) << "OutputStreamPollerImpl is already destroyed."; return poller->QueueSize(); } diff --git a/mediapipe/framework/output_stream_shard.cc b/mediapipe/framework/output_stream_shard.cc index 1b096efbb..3b24321fb 100644 --- a/mediapipe/framework/output_stream_shard.cc +++ b/mediapipe/framework/output_stream_shard.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/output_stream_shard.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" @@ -23,7 +24,7 @@ namespace mediapipe { OutputStreamShard::OutputStreamShard() : closed_(false) {} void OutputStreamShard::SetSpec(OutputStreamSpec* output_stream_spec) { - CHECK(output_stream_spec); + ABSL_CHECK(output_stream_spec); output_stream_spec_ = output_stream_spec; } @@ -94,7 +95,7 @@ const Packet& OutputStreamShard::Header() const { // binary. This function can be defined in the .cc file because only two // versions are ever instantiated, and all call sites are within this .cc file. template -Status OutputStreamShard::AddPacketInternal(T&& packet) { +absl::Status OutputStreamShard::AddPacketInternal(T&& packet) { if (IsClosed()) { return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) << "Packet sent to closed stream \"" << Name() << "\"."; @@ -113,7 +114,7 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) { << timestamp.DebugString(); } - Status result = output_stream_spec_->packet_type->Validate(packet); + absl::Status result = output_stream_spec_->packet_type->Validate(packet); if (!result.ok()) { return StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend() << absl::StrCat( "Packet type mismatch on calculator outputting to stream \"", @@ -132,14 +133,14 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) { } void OutputStreamShard::AddPacket(const Packet& packet) { - Status status = AddPacketInternal(packet); + absl::Status status = AddPacketInternal(packet); if (!status.ok()) { output_stream_spec_->TriggerErrorCallback(status); } } void OutputStreamShard::AddPacket(Packet&& packet) { - Status status = AddPacketInternal(std::move(packet)); + absl::Status status = AddPacketInternal(std::move(packet)); if (!status.ok()) { output_stream_spec_->TriggerErrorCallback(status); } diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h index 718174c45..81a897591 100644 --- a/mediapipe/framework/output_stream_shard.h +++ b/mediapipe/framework/output_stream_shard.h @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/output_stream.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" @@ -34,7 +35,7 @@ struct OutputStreamSpec { // Triggers the error callback with absl::Status info when an error // occurs. 
void TriggerErrorCallback(const absl::Status& status) const { - CHECK(error_callback); + ABSL_CHECK(error_callback); error_callback(status); } diff --git a/mediapipe/framework/packet.cc b/mediapipe/framework/packet.cc index 05d3c6c52..edcdaf19f 100644 --- a/mediapipe/framework/packet.cc +++ b/mediapipe/framework/packet.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/packet.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -135,10 +136,11 @@ absl::Status Packet::ValidateAsProtoMessageLite() const { } const proto_ns::MessageLite& Packet::GetProtoMessageLite() const { - CHECK(holder_ != nullptr) << "The packet is empty."; + ABSL_CHECK(holder_ != nullptr) << "The packet is empty."; const proto_ns::MessageLite* proto = holder_->GetProtoMessageLite(); - CHECK(proto != nullptr) << "The Packet stores '" << holder_->DebugTypeName() - << "', it cannot be converted to MessageLite type."; + ABSL_CHECK(proto != nullptr) + << "The Packet stores '" << holder_->DebugTypeName() + << "', it cannot be converted to MessageLite type."; return *proto; } diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index af2ec5a98..770dd9d4c 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -18,11 +18,14 @@ #define MEDIAPIPE_FRAMEWORK_PACKET_H_ #include +#include #include #include #include #include "absl/base/macros.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" @@ -368,11 +371,14 @@ class HolderBase { } // Returns a printable string identifying the type stored in the holder. virtual const std::string DebugTypeName() const = 0; + // Returns debug data id. + virtual int64_t DebugDataId() const = 0; // Returns the registered type name if it's available, otherwise the // empty string. virtual const std::string RegisteredTypeName() const = 0; // Get the type id of the underlying data type. virtual TypeId GetTypeId() const = 0; + // Downcasts this to Holder. Returns nullptr if deserialization // failed or if the requested type is not what is stored. template @@ -451,61 +457,37 @@ struct is_concrete_proto_t !std::is_same{} && !std::is_same{}> {}; -// Registers a message type. T must be a non-cv-qualified concrete proto type. template -struct MessageRegistrationImpl { - static NoDestructor registration; - // This could have been a lambda inside registration's initializer below, but - // MSVC has a bug with lambdas, so we put it here as a workaround. - static std::unique_ptr> CreateMessageHolder() { - return absl::make_unique>(new T); - } -}; +std::unique_ptr CreateMessageHolder() { + return absl::make_unique>(new T); +} -// Static members of template classes can be defined in the header. -template -NoDestructor - MessageRegistrationImpl::registration(MessageHolderRegistry::Register( - T{}.GetTypeName(), MessageRegistrationImpl::CreateMessageHolder, - __FILE__, __LINE__)); +// Registers a message type. T must be a non-cv-qualified concrete proto type. +MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry, + T{}.GetTypeName(), CreateMessageHolder) // For non-Message payloads, this does nothing. 
template -struct HolderSupport { - static void EnsureStaticInit() {} -}; +struct HolderPayloadRegistrator {}; // This template ensures that, for each concrete MessageLite subclass that is // stored in a Packet, we register a function that allows us to create a // Holder with the correct payload type from the proto's type name. +// +// We must use std::remove_cv to ensure we don't try to register Foo twice if +// there are Holder and Holder. TODO: lift this +// up to Holder? template -struct HolderSupport{}>::type> { - // We must use std::remove_cv to ensure we don't try to register Foo twice if - // there are Holder and Holder. TODO: lift this - // up to Holder? - using R = MessageRegistrationImpl::type>; - // For the registration static member to be instantiated, it needs to be - // referenced in a context that requires the definition to exist (see ISO/IEC - // C++ 2003 standard, 14.7.1). Calling this ensures that's the case. - // We need two different call-sites to cover proto types for which packets - // are only ever created (i.e. the protos are only produced by calculators) - // and proto types for which packets are only ever consumed (i.e. the protos - // are only consumed by calculators). - static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); } -}; +struct HolderPayloadRegistrator< + T, typename std::enable_if{}>::type> + : private MessageRegistrator::type> {}; template -class Holder : public HolderBase { +class Holder : public HolderBase, private HolderPayloadRegistrator { public: - explicit Holder(const T* ptr) : ptr_(ptr) { - HolderSupport::EnsureStaticInit(); - } + explicit Holder(const T* ptr) : ptr_(ptr) {} ~Holder() override { delete_helper(); } - const T& data() const { - HolderSupport::EnsureStaticInit(); - return *ptr_; - } + const T& data() const { return *ptr_; } TypeId GetTypeId() const final { return kTypeId; } // Releases the underlying data pointer and transfers the ownership to a // unique pointer. @@ -535,6 +517,7 @@ class Holder : public HolderBase { const std::string DebugTypeName() const final { return MediaPipeTypeStringOrDemangled(); } + int64_t DebugDataId() const final { return reinterpret_cast(ptr_); } const std::string RegisteredTypeName() const final { const std::string* type_string = MediaPipeTypeString(); if (type_string) { @@ -743,7 +726,7 @@ inline Packet& Packet::operator=(Packet&& packet) { inline bool Packet::IsEmpty() const { return holder_ == nullptr; } inline TypeId Packet::GetTypeId() const { - CHECK(holder_); + ABSL_CHECK(holder_); return holder_->GetTypeId(); } @@ -753,7 +736,7 @@ inline const T& Packet::Get() const { if (holder == nullptr) { // Produce a good error message. 
absl::Status status = ValidateAsType(); - LOG(FATAL) << "Packet::Get() failed: " << status.message(); + ABSL_LOG(FATAL) << "Packet::Get() failed: " << status.message(); } return holder->data(); } @@ -762,13 +745,13 @@ inline Timestamp Packet::Timestamp() const { return timestamp_; } template Packet Adopt(const T* ptr) { - CHECK(ptr != nullptr); + ABSL_CHECK(ptr != nullptr); return packet_internal::Create(new packet_internal::Holder(ptr)); } template Packet PointToForeign(const T* ptr) { - CHECK(ptr != nullptr); + ABSL_CHECK(ptr != nullptr); return packet_internal::Create(new packet_internal::ForeignHolder(ptr)); } diff --git a/mediapipe/framework/packet_registration_test.cc b/mediapipe/framework/packet_registration_test.cc index 30c7c7893..7b2ea1f79 100644 --- a/mediapipe/framework/packet_registration_test.cc +++ b/mediapipe/framework/packet_registration_test.cc @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "absl/strings/str_cat.h" +#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_test.pb.h" @@ -24,6 +28,9 @@ namespace mediapipe { namespace { +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; + namespace test_ns { constexpr char kOutTag[] = "OUT"; @@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator); } // namespace test_ns -TEST(PacketTest, InputTypeRegistration) { +TEST(PacketRegistrationTest, InputTypeRegistration) { using testing::Contains; ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(), "mediapipe.InputOnlyProto"); @@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) { Contains("mediapipe.InputOnlyProto")); } +TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) { + CalculatorGraphConfig config; + { + Graph graph; + Stream input = + graph.In(0).SetName("in").Cast(); + + auto& sink_node = graph.AddNode("TestSinkCalculator"); + input.ConnectTo(sink_node.In(test_ns::kInTag)); + Stream output = sink_node.Out(test_ns::kOutTag).Cast(); + + output.ConnectTo(graph.Out(0)).SetName("out"); + + config = graph.GetConfig(); + } + + CalculatorGraph calculator_graph; + MP_ASSERT_OK(calculator_graph.Initialize(std::move(config))); + MP_ASSERT_OK(calculator_graph.StartRun({})); + + int value = 10; + auto proto = std::make_unique(); + proto->set_x(value); + MP_ASSERT_OK(calculator_graph.AddPacketToInputStream( + "in", Adopt(proto.release()).At(Timestamp(0)))); + MP_ASSERT_OK(calculator_graph.WaitUntilIdle()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/packet_type.h b/mediapipe/framework/packet_type.h index 9b4bbd36c..10496f052 100644 --- a/mediapipe/framework/packet_type.h +++ b/mediapipe/framework/packet_type.h @@ -23,6 +23,8 @@ #include #include "absl/base/macros.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" @@ -162,15 +164,15 @@ class PacketTypeSetErrorHandler { if (!missing_) { missing_ = absl::make_unique(); } - CHECK(!missing_->initialized_errors); + ABSL_CHECK(!missing_->initialized_errors); std::string key = absl::StrCat(tag, ":", index); return missing_->entries[key]; } // In the const setting produce a FATAL error. 
const PacketType& GetFallback(const absl::string_view tag, int index) const { - LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index - << ". Unable to defer error due to const specifier."; + ABSL_LOG(FATAL) << "Failed to get tag \"" << tag << "\" index " << index + << ". Unable to defer error due to const specifier."; std::abort(); } @@ -181,9 +183,9 @@ class PacketTypeSetErrorHandler { // Get the error messages that have been deferred. // This function can only be called if HasError() is true. const std::vector& ErrorMessages() const { - CHECK(missing_) << "ErrorMessages() can only be called if errors have " - "occurred. Call HasError() before calling this " - "function."; + ABSL_CHECK(missing_) << "ErrorMessages() can only be called if errors have " + "occurred. Call HasError() before calling this " + "function."; if (!missing_->initialized_errors) { for (const auto& entry : missing_->entries) { // Optional entries that were missing are not considered errors. diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index cae439bc0..f8c95d68b 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -261,8 +261,8 @@ cc_library( ) cc_library( - name = "opencv_highgui", - hdrs = ["opencv_highgui_inc.h"], + name = "opencv_photo", + hdrs = ["opencv_photo_inc.h"], deps = [ ":opencv_core", "//third_party:opencv", @@ -297,6 +297,15 @@ cc_library( ], ) +cc_library( + name = "opencv_highgui", + hdrs = ["opencv_highgui_inc.h"], + deps = [ + ":opencv_core", + "//third_party:opencv", + ], +) + cc_library( name = "opencv_videoio", hdrs = ["opencv_videoio_inc.h"], @@ -317,6 +326,7 @@ cc_library( ":core_proto", ":logging", "//mediapipe/framework:port", + "@com_google_absl//absl/log:absl_check", ], ) diff --git a/mediapipe/framework/port/drishti_proto_alias_rules.bzl b/mediapipe/framework/port/drishti_proto_alias_rules.bzl new file mode 100644 index 000000000..7df141cbe --- /dev/null +++ b/mediapipe/framework/port/drishti_proto_alias_rules.bzl @@ -0,0 +1,31 @@ +"""Rules implementation for mediapipe_proto_alias.bzl, do not load directly.""" + +def _copy_header_impl(ctx): + source = ctx.attr.source.replace("//", "").replace(":", "/") + files = [] + for dep in ctx.attr.deps: + for header in dep[CcInfo].compilation_context.direct_headers: + if (header.short_path == source): + files.append(header) + if len(files) != 1: + fail("Expected exactly 1 source, got ", str(files)) + dest_file = ctx.actions.declare_file(ctx.attr.filename) + + # Use expand_template() with no substitutions as a simple copier. + ctx.actions.expand_template( + template = files[0], + output = dest_file, + substitutions = {}, + ) + return [DefaultInfo(files = depset([dest_file]))] + +copy_header = rule( + implementation = _copy_header_impl, + attrs = { + "filename": attr.string(), + "source": attr.string(), + "deps": attr.label_list(providers = [CcInfo]), + }, + output_to_genfiles = True, + outputs = {"out": "%{filename}"}, +) diff --git a/mediapipe/framework/port/opencv_highgui_inc.h b/mediapipe/framework/port/opencv_highgui_inc.h index c3ca4b7f0..c79804e1f 100644 --- a/mediapipe/framework/port/opencv_highgui_inc.h +++ b/mediapipe/framework/port/opencv_highgui_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ -#define MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ +#ifndef MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_ +#define MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_ #include @@ -25,4 +25,4 @@ #include #endif -#endif // MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ +#endif // MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_ diff --git a/mediapipe/framework/port/opencv_imgcodecs_inc.h b/mediapipe/framework/port/opencv_imgcodecs_inc.h index 60bcd49e9..4c867ed56 100644 --- a/mediapipe/framework/port/opencv_imgcodecs_inc.h +++ b/mediapipe/framework/port/opencv_imgcodecs_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/framework/port/opencv_photo_inc.h b/mediapipe/framework/port/opencv_photo_inc.h new file mode 100644 index 000000000..1416fda70 --- /dev/null +++ b/mediapipe/framework/port/opencv_photo_inc.h @@ -0,0 +1,20 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_ +#define MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_ + +#include "third_party/OpenCV/photo.hpp" + +#endif // MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_ diff --git a/mediapipe/framework/port/opencv_video_inc.h b/mediapipe/framework/port/opencv_video_inc.h index dc84bf59b..5f06d9233 100644 --- a/mediapipe/framework/port/opencv_video_inc.h +++ b/mediapipe/framework/port/opencv_video_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
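The new opencv_photo_inc.h above follows the established port-header convention: consumers depend on the `:opencv_photo` target added in mediapipe/framework/port/BUILD and include the wrapper rather than an OpenCV header directly. A hypothetical consumer, for illustration only (the denoising call is a standard OpenCV photo-module function, but this snippet is not part of the patch):

#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_photo_inc.h"

// Denoises a color image using the photo module pulled in via the wrapper.
void Denoise(const cv::Mat& input, cv::Mat* output) {
  cv::fastNlMeansDenoisingColored(input, *output);
}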
diff --git a/mediapipe/framework/port/parse_text_proto.h b/mediapipe/framework/port/parse_text_proto.h index c352d4f01..722ded6ea 100644 --- a/mediapipe/framework/port/parse_text_proto.h +++ b/mediapipe/framework/port/parse_text_proto.h @@ -15,6 +15,7 @@ #ifndef MEDIAPIPE_PORT_PARSE_TEXT_PROTO_H_ #define MEDIAPIPE_PORT_PARSE_TEXT_PROTO_H_ +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" @@ -29,7 +30,7 @@ bool ParseTextProto(const std::string& input, T* proto) { template T ParseTextProtoOrDie(const std::string& input) { T result; - CHECK(ParseTextProto(input, &result)); + ABSL_CHECK(ParseTextProto(input, &result)); return result; } diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 53aeb1eaf..99699f2cd 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -116,13 +116,14 @@ cc_library( "//mediapipe/framework/port:advanced_proto_lite", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:re2", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:name_util", "//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:validate_name", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -218,11 +219,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map_helper", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", ], ) @@ -257,6 +258,7 @@ cc_test( "//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:status_util", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/time", ], ) @@ -268,9 +270,9 @@ cc_test( ":sharded_map", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:threadpool", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], @@ -374,6 +376,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/profiler/reporter:reporter_lib", "//mediapipe/framework/tool:test_util", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], diff --git a/mediapipe/framework/profiler/gl_context_profiler.cc b/mediapipe/framework/profiler/gl_context_profiler.cc index 59c9f01ff..667d153da 100644 --- a/mediapipe/framework/profiler/gl_context_profiler.cc +++ b/mediapipe/framework/profiler/gl_context_profiler.cc @@ -14,6 +14,8 @@ #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/time/clock.h" #include "absl/time/time.h" diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index 
6aead5250..949955111 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -17,13 +17,14 @@ #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "mediapipe/framework/port/advanced_proto_lite_inc.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/file_helpers.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/re2.h" #include "mediapipe/framework/port/ret_check.h" @@ -158,7 +159,7 @@ void GraphProfiler::Initialize( const ValidatedGraphConfig& validated_graph_config) { absl::WriterMutexLock lock(&profiler_mutex_); validated_graph_ = &validated_graph_config; - CHECK(!is_initialized_) + ABSL_CHECK(!is_initialized_) << "Cannot initialize the profiler for the same graph multiple times."; profiler_config_ = validated_graph_config.Config().profiler_config(); int64 interval_size_usec = profiler_config_.histogram_interval_size_usec(); @@ -190,7 +191,7 @@ void GraphProfiler::Initialize( } auto iter = calculator_profiles_.insert({node_name, profile}); - CHECK(iter.second) << absl::Substitute( + ABSL_CHECK(iter.second) << absl::Substitute( "Calculator \"$0\" has already been added.", node_name); } profile_builder_ = std::make_unique(this); @@ -201,7 +202,7 @@ void GraphProfiler::Initialize( void GraphProfiler::SetClock(const std::shared_ptr& clock) { absl::WriterMutexLock lock(&profiler_mutex_); - CHECK(clock) << "GraphProfiler::SetClock() is called with a nullptr."; + ABSL_CHECK(clock) << "GraphProfiler::SetClock() is called with a nullptr."; clock_ = clock; } @@ -251,10 +252,10 @@ absl::Status GraphProfiler::Start(mediapipe::Executor* executor) { file::SetContents(absl::StrCat(trace_log_path, "trace_writing_check"), "can write trace logs to this location"); if (status.ok()) { - LOG(INFO) << "trace_log_path: " << trace_log_path; + ABSL_LOG(INFO) << "trace_log_path: " << trace_log_path; } else { - LOG(ERROR) << "cannot write to trace_log_path: " << trace_log_path << ": " - << status; + ABSL_LOG(ERROR) << "cannot write to trace_log_path: " << trace_log_path + << ": " << status; } is_running_ = true; @@ -315,7 +316,7 @@ void GraphProfiler::AddPacketInfo(const TraceEvent& packet_info) { return; } if (!packet_timestamp.IsRangeValue()) { - LOG(WARNING) << absl::Substitute( + ABSL_LOG(WARNING) << absl::Substitute( "Skipped adding packet info because the timestamp $0 for stream " "\"$1\" is not valid.", packet_timestamp.Value(), stream_name); @@ -386,7 +387,7 @@ std::set GraphProfiler::GetBackEdgeIds( tool::ParseTagIndex(input_stream_info.tag_index(), &tag, &index)) << absl::Substitute("Cannot parse TAG or index for the backedge \"$0\"", input_stream_info.tag_index()); - CHECK(0 <= index && index < input_tag_map.NumEntries(tag)) + ABSL_CHECK(0 <= index && index < input_tag_map.NumEntries(tag)) << absl::Substitute( "The input_stream_info for tag \"$0\" (index " "$1) does not match any input_stream.", @@ -445,7 +446,7 @@ void GraphProfiler::SetOpenRuntime(const CalculatorContext& calculator_context, const std::string& node_name = calculator_context.NodeName(); int64 time_usec = end_time_usec - start_time_usec; auto profile_iter = calculator_profiles_.find(node_name); - CHECK(profile_iter != calculator_profiles_.end()) << absl::Substitute( + ABSL_CHECK(profile_iter != 
calculator_profiles_.end()) << absl::Substitute( "Calculator \"$0\" has not been added during initialization.", calculator_context.NodeName()); CalculatorProfile* calculator_profile = &profile_iter->second; @@ -467,7 +468,7 @@ void GraphProfiler::SetCloseRuntime(const CalculatorContext& calculator_context, const std::string& node_name = calculator_context.NodeName(); int64 time_usec = end_time_usec - start_time_usec; auto profile_iter = calculator_profiles_.find(node_name); - CHECK(profile_iter != calculator_profiles_.end()) << absl::Substitute( + ABSL_CHECK(profile_iter != calculator_profiles_.end()) << absl::Substitute( "Calculator \"$0\" has not been added during initialization.", calculator_context.NodeName()); CalculatorProfile* calculator_profile = &profile_iter->second; @@ -482,7 +483,7 @@ void GraphProfiler::SetCloseRuntime(const CalculatorContext& calculator_context, void GraphProfiler::AddTimeSample(int64 start_time_usec, int64 end_time_usec, TimeHistogram* histogram) { if (end_time_usec < start_time_usec) { - LOG(ERROR) << absl::Substitute( + ABSL_LOG(ERROR) << absl::Substitute( "end_time_usec ($0) is < start_time_usec ($1)", end_time_usec, start_time_usec); return; @@ -519,8 +520,8 @@ int64 GraphProfiler::AddInputStreamTimeSamples( // This is a condition rather than a failure CHECK because // under certain conditions the consumer calculator's Process() // can start before the producer calculator's Process() is finished. - LOG_FIRST_N(WARNING, 10) << "Expected packet info is missing for: " - << PacketIdToString(packet_id); + ABSL_LOG_FIRST_N(WARNING, 10) << "Expected packet info is missing for: " + << PacketIdToString(packet_id); continue; } AddTimeSample( @@ -545,7 +546,7 @@ void GraphProfiler::AddProcessSample( const std::string& node_name = calculator_context.NodeName(); auto profile_iter = calculator_profiles_.find(node_name); - CHECK(profile_iter != calculator_profiles_.end()) << absl::Substitute( + ABSL_CHECK(profile_iter != calculator_profiles_.end()) << absl::Substitute( "Calculator \"$0\" has not been added during initialization.", calculator_context.NodeName()); CalculatorProfile* calculator_profile = &profile_iter->second; diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index e9badaa25..8a9bc141e 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/profiler/graph_profiler.h" +#include "absl/log/absl_log.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -59,7 +60,8 @@ CalculatorProfile GetProfileWithName( return p; } } - LOG(FATAL) << "Cannot find calulator profile with name " << calculator_name; + ABSL_LOG(FATAL) << "Cannot find calculator profile with name " + << calculator_name; return CalculatorProfile::default_instance(); } @@ -1227,7 +1229,7 @@ TEST(GraphProfilerTest, ParallelReads) { EXPECT_EQ(1003, profiles[0].process_runtime().count(0)); EXPECT_EQ(1000, profiles[1].process_runtime().count(0)); } else { - LOG(FATAL) << "Unexpected profile name " << profiles[0].name(); + ABSL_LOG(FATAL) << "Unexpected profile name " << profiles[0].name(); } EXPECT_EQ(1001, out_1_packets.size()); } diff --git a/mediapipe/framework/profiler/graph_tracer_test.cc b/mediapipe/framework/profiler/graph_tracer_test.cc index c1cc819c1..4fe9826c0 100644 --- a/mediapipe/framework/profiler/graph_tracer_test.cc +++
b/mediapipe/framework/profiler/graph_tracer_test.cc @@ -22,6 +22,7 @@ #include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "absl/time/time.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -332,7 +333,7 @@ TEST_F(GraphTracerTest, GraphTrace) { class GraphTracerE2ETest : public ::testing::Test { protected: void SetUpPassThroughGraph() { - CHECK(proto_ns::TextFormat::ParseFromString(R"( + ABSL_CHECK(proto_ns::TextFormat::ParseFromString(R"( input_stream: "input_0" node { calculator: "LambdaCalculator" @@ -346,11 +347,11 @@ class GraphTracerE2ETest : public ::testing::Test { trace_enabled: true } )", - &graph_config_)); + &graph_config_)); } void SetUpDemuxInFlightGraph() { - CHECK(proto_ns::TextFormat::ParseFromString(R"( + ABSL_CHECK(proto_ns::TextFormat::ParseFromString(R"( node { calculator: "LambdaCalculator" input_side_packet: 'callback_2' @@ -404,7 +405,7 @@ class GraphTracerE2ETest : public ::testing::Test { trace_enabled: true } )", - &graph_config_)); + &graph_config_)); } absl::Time ParseTime(const std::string& date_time_str) { @@ -1372,7 +1373,7 @@ TEST_F(GraphTracerE2ETest, GpuTaskTrace) { // Show that trace_enabled activates the GlContextProfiler. TEST_F(GraphTracerE2ETest, GpuTracing) { - CHECK(proto_ns::TextFormat::ParseFromString(R"( + ABSL_CHECK(proto_ns::TextFormat::ParseFromString(R"( input_stream: "input_buffer" input_stream: "render_data" output_stream: "annotated_buffer" @@ -1386,7 +1387,7 @@ TEST_F(GraphTracerE2ETest, GpuTracing) { trace_enabled: true } )", - &graph_config_)); + &graph_config_)); // Create the CalculatorGraph with only trace_enabled set. MP_ASSERT_OK(graph_.Initialize(graph_config_, {})); @@ -1423,5 +1424,13 @@ TEST_F(GraphTracerE2ETest, DestructGraph) { } } +TEST(TraceBuilderTest, EventDataIsExtracted) { + int value = 10; + Packet p = PointToForeign(&value); + TraceEvent event; + event.set_packet_data_id(&p); + EXPECT_EQ(event.event_data, reinterpret_cast(&value)); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/profiler/reporter_test.cc b/mediapipe/framework/profiler/reporter_test.cc index e5bc541a7..6ca6c6424 100644 --- a/mediapipe/framework/profiler/reporter_test.cc +++ b/mediapipe/framework/profiler/reporter_test.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" @@ -43,15 +44,15 @@ using ::testing::IsSupersetOf; void LoadGraphProfile(const std::string& path, GraphProfile* proto) { int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { - LOG(ERROR) << "could not open test graph: " << path - << ", error: " << strerror(errno); + ABSL_LOG(ERROR) << "could not open test graph: " << path + << ", error: " << strerror(errno); return; } proto_ns::io::FileInputStream input(fd); bool success = proto->ParseFromZeroCopyStream(&input); close(fd); if (!success) { - LOG(ERROR) << "could not parse test graph: " << path; + ABSL_LOG(ERROR) << "could not parse test graph: " << path; } } diff --git a/mediapipe/framework/profiler/sharded_map_test.cc b/mediapipe/framework/profiler/sharded_map_test.cc index e551b25c8..5a47b390b 100644 --- a/mediapipe/framework/profiler/sharded_map_test.cc +++ b/mediapipe/framework/profiler/sharded_map_test.cc @@ -17,13 +17,13 @@ #include #include "absl/container/node_hash_map.h" +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include 
"absl/time/time.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/threadpool.h" namespace { @@ -134,9 +134,9 @@ TEST(ShardedMapTest, TestParallelAccess) { ShardedMap sharded_map(4999); TestParallelAccess(sharded_map, 13); }); - LOG(INFO) << "Ellapsed time: simple_map: " << simple_time; - LOG(INFO) << "Ellapsed time: safe_map: " << safe_time; - LOG(INFO) << "Ellapsed time: sharded_map: " << sharded_time; + ABSL_LOG(INFO) << "Ellapsed time: simple_map: " << simple_time; + ABSL_LOG(INFO) << "Ellapsed time: safe_map: " << safe_time; + ABSL_LOG(INFO) << "Ellapsed time: sharded_map: " << sharded_time; } } // namespace diff --git a/mediapipe/framework/profiler/test_context_builder.h b/mediapipe/framework/profiler/test_context_builder.h index abf9ee749..4018a0349 100644 --- a/mediapipe/framework/profiler/test_context_builder.h +++ b/mediapipe/framework/profiler/test_context_builder.h @@ -21,11 +21,11 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/mediapipe_options.pb.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/tool/tag_map.h" @@ -92,7 +92,7 @@ class TestContextBuilder { spec.name = output_map_->Names()[id.value()]; spec.packet_type = packet_type; spec.error_callback = [](const absl::Status& status) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status; }; output_specs_[spec.name] = spec; } diff --git a/mediapipe/framework/profiler/testing/BUILD b/mediapipe/framework/profiler/testing/BUILD index 0b0d256e5..55b3613f9 100644 --- a/mediapipe/framework/profiler/testing/BUILD +++ b/mediapipe/framework/profiler/testing/BUILD @@ -15,9 +15,7 @@ licenses(["notice"]) -package( - default_visibility = ["//mediapipe/framework:__subpackages__"], -) +package(default_visibility = ["//mediapipe/framework:__subpackages__"]) cc_library( name = "simple_calculator", @@ -25,6 +23,7 @@ cc_library( deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) diff --git a/mediapipe/framework/profiler/testing/simple_calculator.cc b/mediapipe/framework/profiler/testing/simple_calculator.cc index 18ba67b9b..fa1123ee0 100644 --- a/mediapipe/framework/profiler/testing/simple_calculator.cc +++ b/mediapipe/framework/profiler/testing/simple_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/status.h" @@ -28,7 +29,7 @@ class SimpleCalculator : public CalculatorBase { } absl::Status Process(CalculatorContext* cc) final { - LOG(WARNING) << "Simple Calculator Process called, count_: " << count_; + ABSL_LOG(WARNING) << "Simple Calculator Process called, count_: " << count_; int max_count = 1; if (cc->InputSidePackets().HasTag("MAX_COUNT")) { max_count = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>(); diff --git a/mediapipe/framework/profiler/trace_buffer.h b/mediapipe/framework/profiler/trace_buffer.h index b5e2d9994..8dc09aef7 100644 --- a/mediapipe/framework/profiler/trace_buffer.h +++ b/mediapipe/framework/profiler/trace_buffer.h @@ -15,6 +15,9 @@ #ifndef MEDIAPIPE_FRAMEWORK_PROFILER_TRACE_BUFFER_H_ #define MEDIAPIPE_FRAMEWORK_PROFILER_TRACE_BUFFER_H_ +#include +#include + #include "absl/time/time.h" #include "mediapipe/framework/calculator_profile.pb.h" #include "mediapipe/framework/packet.h" @@ -23,17 +26,6 @@ namespace mediapipe { -namespace packet_internal { -// Returns a hash of the packet data address from a packet data holder. -inline const int64 GetPacketDataId(const HolderBase* holder) { - if (holder == nullptr) { - return 0; - } - const void* address = &(static_cast<const Holder<int>*>(holder)->data()); - return reinterpret_cast<int64>(address); -} -} // namespace packet_internal - // Packet trace log event. struct TraceEvent { using EventType = GraphTrace::EventType; @@ -75,8 +67,12 @@ struct TraceEvent { return *this; } inline TraceEvent& set_packet_data_id(const Packet* packet) { - this->event_data = - packet_internal::GetPacketDataId(packet_internal::GetHolder(*packet)); + const auto* holder = packet_internal::GetHolder(*packet); + int64_t data_id = 0; + if (holder != nullptr) { + data_id = holder->DebugDataId(); + } + this->event_data = data_id; return *this; } inline TraceEvent& set_thread_id(int thread_id) { diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc index 854c10fd5..36effe016 100644 --- a/mediapipe/framework/scheduler.cc +++ b/mediapipe/framework/scheduler.cc @@ -19,6 +19,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator_graph.h" @@ -77,7 +78,7 @@ void Scheduler::Reset() { void Scheduler::CloseAllSourceNodes() { shared_.stopping = true; } void Scheduler::SetExecutor(Executor* executor) { - CHECK_EQ(state_, STATE_NOT_STARTED) + ABSL_CHECK_EQ(state_, STATE_NOT_STARTED) << "SetExecutor must not be called after the scheduler has started"; default_queue_.SetExecutor(executor); } @@ -147,7 +148,7 @@ void Scheduler::HandleIdle() { // Note: TryToScheduleNextSourceLayer unlocks and locks state_mutex_ // internally. bool did_activate = TryToScheduleNextSourceLayer(); - CHECK(did_activate || active_sources_.empty()); + ABSL_CHECK(did_activate || active_sources_.empty()); continue; } @@ -183,7 +184,7 @@ void Scheduler::HandleIdle() { void Scheduler::Quit() { // All calls to Calculator::Process() have returned (even if we had an // error).
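// ---- Illustrative aside (not part of the patch) ----
// The trace_buffer.h hunk above replaces address arithmetic on a type-punned
// Holder with a virtual DebugDataId() query, so the profiler no longer needs
// to guess the holder's concrete payload type. A toy version of the idea;
// ToyHolderBase/ToyHolder are simplified stand-ins, not the real
// mediapipe::packet_internal types:

#include <cstdint>

class ToyHolderBase {
 public:
  virtual ~ToyHolderBase() = default;
  // Identifies the stored payload for profiling; stable for the holder's
  // lifetime.
  virtual int64_t DebugDataId() const = 0;
};

template <typename T>
class ToyHolder : public ToyHolderBase {
 public:
  explicit ToyHolder(const T* data) : data_(data) {}
  int64_t DebugDataId() const override {
    // The payload address doubles as a unique id, as in the old code, but
    // each holder now reports it for its own concrete type.
    return reinterpret_cast<int64_t>(data_);
  }

 private:
  const T* data_;
};
// ---- End aside ----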
- CHECK(state_ == STATE_RUNNING || state_ == STATE_CANCELLING); + ABSL_CHECK(state_ == STATE_RUNNING || state_ == STATE_CANCELLING); SetQueuesRunning(false); shared_.timer.EndRun(); @@ -198,7 +199,7 @@ void Scheduler::Start() { shared_.timer.StartRun(); { absl::MutexLock lock(&state_mutex_); - CHECK_EQ(state_, STATE_NOT_STARTED); + ABSL_CHECK_EQ(state_, STATE_NOT_STARTED); state_ = STATE_RUNNING; SetQueuesRunning(true); @@ -270,13 +271,6 @@ absl::Status Scheduler::WaitForObservedOutput() { return observed ? absl::OkStatus() : absl::OutOfRangeError("Graph is done."); } -// Idleness requires: -// 1. either the graph has no source nodes or all source nodes are closed, and -// 2. no packets are added to graph input streams. -// For simplicity, we only allow WaitUntilIdle() to be called on a graph with -// no source nodes. (This is enforced by CalculatorGraph::WaitUntilIdle().) -// The application must ensure no other threads are adding packets to graph -// input streams while a WaitUntilIdle() call is in progress. absl::Status Scheduler::WaitUntilIdle() { RET_CHECK_NE(state_, STATE_NOT_STARTED); ApplicationThreadAwait(std::bind(&Scheduler::IsIdle, this)); @@ -333,15 +327,15 @@ void Scheduler::ClosedAllGraphInputStreams() { // container. void Scheduler::ScheduleNodeIfNotThrottled( CalculatorNode* node, CalculatorContext* calculator_context) { - DCHECK(node); - DCHECK(calculator_context); + ABSL_DCHECK(node); + ABSL_DCHECK(calculator_context); if (!graph_->IsNodeThrottled(node->Id())) { node->GetSchedulerQueue()->AddNode(node, calculator_context); } } void Scheduler::ScheduleNodeForOpen(CalculatorNode* node) { - DCHECK(node); + ABSL_DCHECK(node); VLOG(1) << "Scheduling OpenNode of calculator " << node->DebugName(); node->GetSchedulerQueue()->AddNodeForOpen(node); } @@ -351,7 +345,7 @@ void Scheduler::ScheduleUnthrottledReadyNodes( for (CalculatorNode* node : nodes_to_schedule) { // Source nodes always reuse the default calculator context because they // can't be executed in parallel. 
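// ---- Illustrative aside (not part of the patch) ----
// The hunks throughout this patch mechanically replace MediaPipe's
// port/logging macros (LOG, CHECK, DCHECK, ...) with their Abseil
// counterparts. A minimal, self-contained sketch of the replacement API,
// using only macros that actually appear in this patch:

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"

void LoggingSketch() {
  int count = 3;
  ABSL_CHECK_EQ(count, 3) << "unexpected count";  // fatal on failure
  ABSL_DCHECK(count > 0);                         // checked in debug builds
  ABSL_LOG(WARNING) << "count: " << count;        // streams like LOG() did
}
// ---- End aside ----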
- CHECK(node->IsSource()); + ABSL_CHECK(node->IsSource()); CalculatorContext* default_context = node->GetDefaultCalculatorContext(); node->GetSchedulerQueue()->AddNode(node, default_context); } @@ -374,8 +368,8 @@ void Scheduler::CleanupActiveSources() { bool Scheduler::TryToScheduleNextSourceLayer() { VLOG(3) << "TryToScheduleNextSourceLayer"; - CHECK(active_sources_.empty()); - CHECK(!sources_queue_.empty()); + ABSL_CHECK(active_sources_.empty()); + ABSL_CHECK(!sources_queue_.empty()); if (!unopened_sources_.empty() && (*unopened_sources_.begin())->source_layer() < @@ -427,8 +421,9 @@ bool Scheduler::TryToScheduleNextSourceLayer() { } void Scheduler::AddUnopenedSourceNode(CalculatorNode* node) { - CHECK_EQ(state_, STATE_NOT_STARTED) << "AddUnopenedSourceNode can only be " - "called before starting the scheduler"; + ABSL_CHECK_EQ(state_, STATE_NOT_STARTED) + << "AddUnopenedSourceNode can only be " + "called before starting the scheduler"; unopened_sources_.insert(node); } @@ -445,7 +440,7 @@ void Scheduler::AssignNodeToSchedulerQueue(CalculatorNode* node) { SchedulerQueue* queue; if (!node->Executor().empty()) { auto iter = non_default_queues_.find(node->Executor()); - CHECK(iter != non_default_queues_.end()); + ABSL_CHECK(iter != non_default_queues_.end()); queue = iter->second.get(); } else { queue = &default_queue_; @@ -528,7 +523,7 @@ void Scheduler::CleanupAfterRun() { while (!sources_queue_.empty()) { sources_queue_.pop(); } - CHECK(app_thread_tasks_.empty()); + ABSL_CHECK(app_thread_tasks_.empty()); } for (auto queue : scheduler_queues_) { queue->CleanupAfterRun(); @@ -539,7 +534,7 @@ } internal::SchedulerTimes Scheduler::GetSchedulerTimes() { - CHECK_EQ(state_, STATE_TERMINATED); + ABSL_CHECK_EQ(state_, STATE_TERMINATED); return shared_.timer.GetSchedulerTimes(); } diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index b59467b9f..22d552c71 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -76,6 +76,16 @@ class Scheduler { // be scheduled and nothing is running in the worker threads. This function // can be called only after Start(). // Runs application thread tasks while waiting. + // + // Idleness requires: + // 1. either the graph has no source nodes or all source nodes are closed, and + // 2. no packets are added to graph input streams. + // + // For simplicity, we only fully support WaitUntilIdle() to be called on a + // graph with no source nodes. + // + // The application must ensure no other threads are adding packets to graph + // input streams while a WaitUntilIdle() call is in progress. absl::Status WaitUntilIdle() ABSL_LOCKS_EXCLUDED(state_mutex_); // Wait until any graph input stream has been unthrottled. @@ -310,7 +320,7 @@ class Scheduler { absl::Mutex state_mutex_; // Current state of the scheduler. - std::atomic<State> state_ = ATOMIC_VAR_INIT(STATE_NOT_STARTED); + std::atomic<State> state_ = STATE_NOT_STARTED; // True if all graph input streams are closed.
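// ---- Illustrative aside (not part of the patch) ----
// Usage implied by the WaitUntilIdle() contract documented in the scheduler.h
// hunk above: feed packets from a single thread, then wait. Sketch only;
// "in" is a hypothetical input stream name.

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"

absl::Status FeedAndDrain(mediapipe::CalculatorGraph& graph,
                          const mediapipe::Packet& packet) {
  absl::Status status = graph.AddPacketToInputStream("in", packet);
  if (!status.ok()) return status;
  // Only fully supported on graphs without source nodes, and no other thread
  // may be adding packets while this call is in progress.
  return graph.WaitUntilIdle();
}
// ---- End aside ----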
bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false; diff --git a/mediapipe/framework/scheduler_queue.cc b/mediapipe/framework/scheduler_queue.cc index 33214cf64..557d7e40e 100644 --- a/mediapipe/framework/scheduler_queue.cc +++ b/mediapipe/framework/scheduler_queue.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator_node.h" #include "mediapipe/framework/executor.h" @@ -36,8 +37,8 @@ namespace internal { SchedulerQueue::Item::Item(CalculatorNode* node, CalculatorContext* cc) : node_(node), cc_(cc) { - CHECK(node); - CHECK(cc); + ABSL_CHECK(node); + ABSL_CHECK(cc); is_source_ = node->IsSource(); id_ = node->Id(); if (is_source_) { @@ -48,7 +49,7 @@ SchedulerQueue::Item::Item(CalculatorNode* node, CalculatorContext* cc) SchedulerQueue::Item::Item(CalculatorNode* node) : node_(node), cc_(nullptr), is_open_node_(true) { - CHECK(node); + ABSL_CHECK(node); is_source_ = node->IsSource(); id_ = node->Id(); if (is_source_) { @@ -104,7 +105,7 @@ bool SchedulerQueue::IsIdle() { void SchedulerQueue::SetRunning(bool running) { absl::MutexLock lock(&mutex_); running_count_ += running ? 1 : -1; - DCHECK_LE(running_count_, 1); + ABSL_DCHECK_LE(running_count_, 1); } void SchedulerQueue::AddNode(CalculatorNode* node, CalculatorContext* cc) { @@ -117,7 +118,7 @@ void SchedulerQueue::AddNode(CalculatorNode* node, CalculatorContext* cc) { // Only happens when the framework tries to schedule an unthrottled source // node while it's running. For non-source nodes, if a calculator context is // prepared, it is committed to be scheduled. - CHECK(node->IsSource()) << node->DebugName(); + ABSL_CHECK(node->IsSource()) << node->DebugName(); return; } AddItemToQueue(Item(node, cc)); @@ -192,15 +193,16 @@ void SchedulerQueue::RunNextTask() { { absl::MutexLock lock(&mutex_); - CHECK(!queue_.empty()) << "Called RunNextTask when the queue is empty. " - "This should not happen."; + ABSL_CHECK(!queue_.empty()) + << "Called RunNextTask when the queue is empty. " + "This should not happen."; node = queue_.top().Node(); calculator_context = queue_.top().Context(); is_open_node = queue_.top().IsOpenNode(); queue_.pop(); - CHECK(!node->Closed()) + ABSL_CHECK(!node->Closed()) << "Scheduled a node that was closed. This should not happen."; } @@ -211,7 +213,7 @@ void SchedulerQueue::RunNextTask() { // do it here to ensure all executors are covered. AUTORELEASEPOOL { if (is_open_node) { - DCHECK(!calculator_context); + ABSL_DCHECK(!calculator_context); OpenCalculatorNode(node); } else { RunCalculatorNode(node, calculator_context); @@ -221,7 +223,7 @@ void SchedulerQueue::RunNextTask() { bool is_idle; { absl::MutexLock lock(&mutex_); - DCHECK_GT(num_pending_tasks_, 0); + ABSL_DCHECK_GT(num_pending_tasks_, 0); --num_pending_tasks_; is_idle = IsIdle(); } @@ -266,8 +268,8 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, // that all sources will be closed and no further sources should be // scheduled. The graph will be terminated as soon as its scheduler // queue becomes empty. - CHECK(!node->IsSource()); // ProcessNode takes care of StatusStop() - // from sources. + ABSL_CHECK(!node->IsSource()); // ProcessNode takes care of + // StatusStop() from sources. shared_->stopping = true; } else { // If we have an error in this calculator. 
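// ---- Illustrative aside (not part of the patch) ----
// The scheduler.h hunk above also drops ATOMIC_VAR_INIT, which C++20
// deprecated: a std::atomic member can simply be direct-initialized.
// Minimal sketch (State values borrowed from the scheduler code above):

#include <atomic>

enum State { STATE_NOT_STARTED, STATE_RUNNING, STATE_CANCELLING, STATE_TERMINATED };

struct SchedulerStateSketch {
  std::atomic<State> state = STATE_NOT_STARTED;  // no macro required
};
// ---- End aside ----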
@@ -299,8 +301,8 @@ void SchedulerQueue::CleanupAfterRun() { { absl::MutexLock lock(&mutex_); was_idle = IsIdle(); - CHECK_EQ(num_pending_tasks_, 0); - CHECK_EQ(num_tasks_to_add_, queue_.size()); + ABSL_CHECK_EQ(num_pending_tasks_, 0); + ABSL_CHECK_EQ(num_tasks_to_add_, queue_.size()); num_tasks_to_add_ = 0; while (!queue_.empty()) { queue_.pop(); diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 8b54ade8b..c3eb334fa 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -53,8 +53,16 @@ mediapipe_proto_library( cc_library( name = "barrier_input_stream_handler", srcs = ["barrier_input_stream_handler.cc"], + hdrs = ["barrier_input_stream_handler.h"], deps = [ + "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", ], alwayslink = 1, ) @@ -74,8 +82,15 @@ cc_library( cc_library( name = "early_close_input_stream_handler", srcs = ["early_close_input_stream_handler.cc"], + hdrs = ["early_close_input_stream_handler.h"], deps = [ + "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -84,10 +99,21 @@ cc_library( cc_library( name = "fixed_size_input_stream_handler", srcs = ["fixed_size_input_stream_handler.cc"], + hdrs = ["fixed_size_input_stream_handler.h"], deps = [ ":default_input_stream_handler", ":fixed_size_input_stream_handler_cc_proto", + "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework:packet", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/synchronization", ], alwayslink = 1, ) @@ -95,8 +121,18 @@ cc_library( cc_library( name = "immediate_input_stream_handler", srcs = ["immediate_input_stream_handler.cc"], + hdrs = ["immediate_input_stream_handler.h"], deps = [ + "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework:mediapipe_options_cc_proto", + "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", ], alwayslink = 1, ) @@ -115,6 +151,7 @@ cc_library( "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", "//mediapipe/framework/tool:tag_map", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) @@ -122,9 +159,13 @@ cc_library( cc_library( name = "mux_input_stream_handler", srcs = ["mux_input_stream_handler.cc"], + hdrs = ["mux_input_stream_handler.h"], deps = [ 
+ "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", - "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -134,16 +175,23 @@ cc_library( cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], + hdrs = ["sync_set_input_stream_handler.h"], deps = [ ":sync_set_input_stream_handler_cc_proto", - "//mediapipe/framework:collection", + "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:status", "//mediapipe/framework/tool:tag_map", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -152,12 +200,19 @@ cc_library( cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], + hdrs = ["timestamp_align_input_stream_handler.h"], deps = [ ":timestamp_align_input_stream_handler_cc_proto", + "//mediapipe/framework:calculator_context_manager", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", + "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:timestamp", "//mediapipe/framework/tool:validate_name", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -176,6 +231,7 @@ cc_test( "//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map_helper", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", ], ) @@ -194,6 +250,7 @@ cc_test( "//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map_helper", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", ], ) diff --git a/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc b/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc index ece873b1e..4150fafac 100644 --- a/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc @@ -11,84 +11,70 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+#include "mediapipe/framework/stream_handler/barrier_input_stream_handler.h" -#include -#include -#include +#include +#include +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/input_stream_handler.h" namespace mediapipe { -// Implementation of an input stream handler that considers a node as ready for -// Process() if all input streams have a packet available. This implies it must -// consider a node as ready for Close() if any input stream is done. -class BarrierInputStreamHandler : public InputStreamHandler { - public: - BarrierInputStreamHandler() = delete; - BarrierInputStreamHandler( - std::shared_ptr<tool::TagMap> tag_map, - CalculatorContextManager* calculator_context_manager, - const MediaPipeOptions& options, bool calculator_run_in_parallel) - : InputStreamHandler(std::move(tag_map), calculator_context_manager, - options, calculator_run_in_parallel) {} - - void PrepareForRun( - std::function<void()> headers_ready_callback, - std::function<void()> notification_callback, - std::function<void(CalculatorContext*)> schedule_callback, - std::function<void(absl::Status)> error_callback) override { - InputStreamHandler::PrepareForRun( - std::move(headers_ready_callback), std::move(notification_callback), - std::move(schedule_callback), std::move(error_callback)); - for (auto& stream : input_stream_managers_) { - stream->DisableTimestamps(); - } +void BarrierInputStreamHandler::PrepareForRun( + std::function<void()> headers_ready_callback, + std::function<void()> notification_callback, + std::function<void(CalculatorContext*)> schedule_callback, + std::function<void(absl::Status)> error_callback) { + InputStreamHandler::PrepareForRun( + std::move(headers_ready_callback), std::move(notification_callback), + std::move(schedule_callback), std::move(error_callback)); + for (auto& stream : input_stream_managers_) { + stream->DisableTimestamps(); } +} - protected: - // In BarrierInputStreamHandler, a node is "ready" if: - // - any stream is done (need to call Close() in this case), or - // - all streams have a packet available. - NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override { - DCHECK(min_stream_timestamp); - *min_stream_timestamp = Timestamp::Done(); - bool all_available = true; - for (const auto& stream : input_stream_managers_) { - bool empty; - Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); - if (empty) { - if (stream_timestamp == Timestamp::Done()) { - *min_stream_timestamp = Timestamp::Done(); - return NodeReadiness::kReadyForClose; - } - all_available = false; +NodeReadiness BarrierInputStreamHandler::GetNodeReadiness( + Timestamp* min_stream_timestamp) { + ABSL_DCHECK(min_stream_timestamp); + *min_stream_timestamp = Timestamp::Done(); + bool all_available = true; + for (const auto& stream : input_stream_managers_) { + bool empty; + Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); + if (empty) { + if (stream_timestamp == Timestamp::Done()) { + *min_stream_timestamp = Timestamp::Done(); + return NodeReadiness::kReadyForClose; } - *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp); + all_available = false; } - - CHECK_NE(*min_stream_timestamp, Timestamp::Done()); - if (all_available) { - return NodeReadiness::kReadyForProcess; - } - return NodeReadiness::kNotReady; + *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp); } - // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
- void FillInputSet(Timestamp input_timestamp, - InputStreamShardSet* input_set) override { - CHECK(input_timestamp.IsAllowedInStream()); - CHECK(input_set); - for (CollectionItemId id = input_stream_managers_.BeginId(); - id < input_stream_managers_.EndId(); ++id) { - auto& stream = input_stream_managers_.Get(id); - bool stream_is_done = false; - Packet current_packet = stream->PopQueueHead(&stream_is_done); - AddPacketToShard(&input_set->Get(id), std::move(current_packet), - stream_is_done); - } + ABSL_CHECK_NE(*min_stream_timestamp, Timestamp::Done()); + if (all_available) { + return NodeReadiness::kReadyForProcess; } -}; + return NodeReadiness::kNotReady; +} + +void BarrierInputStreamHandler::FillInputSet(Timestamp input_timestamp, + InputStreamShardSet* input_set) { + ABSL_CHECK(input_timestamp.IsAllowedInStream()); + ABSL_CHECK(input_set); + for (CollectionItemId id = input_stream_managers_.BeginId(); + id < input_stream_managers_.EndId(); ++id) { + auto& stream = input_stream_managers_.Get(id); + bool stream_is_done = false; + Packet current_packet = stream->PopQueueHead(&stream_is_done); + AddPacketToShard(&input_set->Get(id), std::move(current_packet), + stream_is_done); + } +} REGISTER_INPUT_STREAM_HANDLER(BarrierInputStreamHandler); diff --git a/mediapipe/framework/stream_handler/barrier_input_stream_handler.h b/mediapipe/framework/stream_handler/barrier_input_stream_handler.h new file mode 100644 index 000000000..55a21d332 --- /dev/null +++ b/mediapipe/framework/stream_handler/barrier_input_stream_handler.h @@ -0,0 +1,64 @@ + +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_BARRIER_INPUT_STREAM_HANDLER_H_ +#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_BARRIER_INPUT_STREAM_HANDLER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/mediapipe_options.pb.h" +#include "mediapipe/framework/tool/tag_map.h" + +namespace mediapipe { + +// Implementation of an input stream handler that considers a node as ready for +// Process() if all input streams have a packet available. This implies it must +// consider a node as ready for Close() if any input stream is done. 
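// ---- Illustrative aside (not part of the patch) ----
// The barrier readiness rule implemented above, reduced to a toy function:
// ready for Process() only when every queue has a packet; ready for Close()
// as soon as any stream is done. ToyStream is a stand-in for the real
// InputStreamManager.

#include <vector>

enum class ToyReadiness { kNotReady, kReadyForProcess, kReadyForClose };

struct ToyStream {
  bool empty;
  bool done;  // only meaningful while the queue is empty
};

ToyReadiness BarrierReadiness(const std::vector<ToyStream>& streams) {
  bool all_available = true;
  for (const ToyStream& s : streams) {
    if (s.empty) {
      if (s.done) return ToyReadiness::kReadyForClose;
      all_available = false;
    }
  }
  return all_available ? ToyReadiness::kReadyForProcess
                       : ToyReadiness::kNotReady;
}
// ---- End aside ----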
+class BarrierInputStreamHandler : public InputStreamHandler { + public: + BarrierInputStreamHandler() = delete; + BarrierInputStreamHandler( + std::shared_ptr<tool::TagMap> tag_map, + CalculatorContextManager* calculator_context_manager, + const mediapipe::MediaPipeOptions& options, + bool calculator_run_in_parallel) + : InputStreamHandler(std::move(tag_map), calculator_context_manager, + options, calculator_run_in_parallel) {} + + void PrepareForRun(std::function<void()> headers_ready_callback, + std::function<void()> notification_callback, + std::function<void(CalculatorContext*)> schedule_callback, + std::function<void(absl::Status)> error_callback) override; + + protected: + // In BarrierInputStreamHandler, a node is "ready" if: + // - any stream is done (need to call Close() in this case), or + // - all streams have a packet available. + NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override; + + // Only invoked when associated GetNodeReadiness() returned kReadyForProcess. + void FillInputSet(Timestamp input_timestamp, + InputStreamShardSet* input_set) override; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_BARRIER_INPUT_STREAM_HANDLER_H_ diff --git a/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc index 9f341ba54..deb04fc39 100644 --- a/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc @@ -18,6 +18,7 @@ #include #include "absl/base/macros.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context_manager.h" @@ -105,7 +106,7 @@ class BarrierInputStreamHandlerTest : public ::testing::Test { void NotifyNoOp() {} void Schedule(CalculatorContext* calculator_context) { - CHECK(calculator_context); + ABSL_CHECK(calculator_context); calculator_context_ = calculator_context; } diff --git a/mediapipe/framework/stream_handler/early_close_input_stream_handler.cc b/mediapipe/framework/stream_handler/early_close_input_stream_handler.cc index 983b986c3..3a7dd8678 100644 --- a/mediapipe/framework/stream_handler/early_close_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/early_close_input_stream_handler.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,81 +11,70 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/stream_handler/early_close_input_stream_handler.h" #include -#include -#include +#include "absl/log/absl_check.h" #include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/input_stream_handler.h" namespace mediapipe { -// Implementation of an input stream handler that considers a node as ready for -// Close() if any input stream is done.
-class EarlyCloseInputStreamHandler : public InputStreamHandler { - public: - EarlyCloseInputStreamHandler() = delete; - EarlyCloseInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map, - CalculatorContextManager* cc_manager, - const MediaPipeOptions& options, - bool calculator_run_in_parallel) - : InputStreamHandler(std::move(tag_map), cc_manager, options, - calculator_run_in_parallel) {} - - protected: - // In EarlyCloseInputStreamHandler, a node is "ready" if: - // - any stream is done (need to call Close() in this case), or - // - the minimum bound (over all empty streams) is greater than the smallest - // timestamp of any stream, which means we have received all the packets - // that will be available at the next timestamp. - NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override { - DCHECK(min_stream_timestamp); - *min_stream_timestamp = Timestamp::Done(); - Timestamp min_bound = Timestamp::Done(); - for (const auto& stream : input_stream_managers_) { - bool empty; - Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); - if (empty) { - if (stream_timestamp == Timestamp::Done()) { - *min_stream_timestamp = Timestamp::Done(); - return NodeReadiness::kReadyForClose; - } - min_bound = std::min(min_bound, stream_timestamp); +// In EarlyCloseInputStreamHandler, a node is "ready" if: +// - any stream is done (need to call Close() in this case), or +// - the minimum bound (over all empty streams) is greater than the smallest +// timestamp of any stream, which means we have received all the packets +// that will be available at the next timestamp. +NodeReadiness EarlyCloseInputStreamHandler::GetNodeReadiness( + Timestamp* min_stream_timestamp) { + ABSL_DCHECK(min_stream_timestamp); + *min_stream_timestamp = Timestamp::Done(); + Timestamp min_bound = Timestamp::Done(); + for (const auto& stream : input_stream_managers_) { + bool empty; + Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); + if (empty) { + if (stream_timestamp == Timestamp::Done()) { + *min_stream_timestamp = Timestamp::Done(); + return NodeReadiness::kReadyForClose; } - *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp); + min_bound = std::min(min_bound, stream_timestamp); } - - CHECK_NE(*min_stream_timestamp, Timestamp::Done()); - - if (min_bound > *min_stream_timestamp) { - return NodeReadiness::kReadyForProcess; - } - - CHECK_EQ(min_bound, *min_stream_timestamp); - return NodeReadiness::kNotReady; + *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp); } - // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
- void FillInputSet(Timestamp input_timestamp, - InputStreamShardSet* input_set) override { - CHECK(input_timestamp.IsAllowedInStream()); - CHECK(input_set); - for (CollectionItemId id = input_stream_managers_.BeginId(); - id < input_stream_managers_.EndId(); ++id) { - auto& stream = input_stream_managers_.Get(id); - int num_packets_dropped = 0; - bool stream_is_done = false; - Packet current_packet = stream->PopPacketAtTimestamp( - input_timestamp, &num_packets_dropped, &stream_is_done); - CHECK_EQ(num_packets_dropped, 0) - << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", - num_packets_dropped, stream->Name()); - AddPacketToShard(&input_set->Get(id), std::move(current_packet), - stream_is_done); - } + ABSL_CHECK_NE(*min_stream_timestamp, Timestamp::Done()); + + if (min_bound > *min_stream_timestamp) { + return NodeReadiness::kReadyForProcess; } -}; + + ABSL_CHECK_EQ(min_bound, *min_stream_timestamp); + return NodeReadiness::kNotReady; +} + +// Only invoked when associated GetNodeReadiness() returned kReadyForProcess. +void EarlyCloseInputStreamHandler::FillInputSet( + Timestamp input_timestamp, InputStreamShardSet* input_set) { + ABSL_CHECK(input_timestamp.IsAllowedInStream()); + ABSL_CHECK(input_set); + for (CollectionItemId id = input_stream_managers_.BeginId(); + id < input_stream_managers_.EndId(); ++id) { + auto& stream = input_stream_managers_.Get(id); + int num_packets_dropped = 0; + bool stream_is_done = false; + Packet current_packet = stream->PopPacketAtTimestamp( + input_timestamp, &num_packets_dropped, &stream_is_done); + ABSL_CHECK_EQ(num_packets_dropped, 0) + << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", + num_packets_dropped, stream->Name()); + AddPacketToShard(&input_set->Get(id), std::move(current_packet), + stream_is_done); + } +} REGISTER_INPUT_STREAM_HANDLER(EarlyCloseInputStreamHandler); diff --git a/mediapipe/framework/stream_handler/early_close_input_stream_handler.h b/mediapipe/framework/stream_handler/early_close_input_stream_handler.h new file mode 100644 index 000000000..081954ef2 --- /dev/null +++ b/mediapipe/framework/stream_handler/early_close_input_stream_handler.h @@ -0,0 +1,56 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_EARLY_CLOSE_INPUT_STREAM_HANDLER_H_ +#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_EARLY_CLOSE_INPUT_STREAM_HANDLER_H_ + +#include +#include + +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/mediapipe_options.pb.h" +#include "mediapipe/framework/tool/tag_map.h" + +namespace mediapipe { + +// Implementation of an input stream handler that considers a node as ready for +// Close() if any input stream is done. 
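// ---- Illustrative aside (not part of the patch) ----
// The early-close readiness rule above, as a toy with integer timestamps
// (INT_MAX standing in for Timestamp::Done()): a node may Process() once the
// smallest bound over empty streams exceeds the smallest timestamp seen,
// i.e. no packet can still arrive at that timestamp.

#include <algorithm>
#include <climits>
#include <vector>

struct ToyInput {
  bool empty;
  int ts_or_bound;  // packet timestamp if non-empty, else timestamp bound
};

enum class ToyReady { kNotReady, kProcess, kClose };

ToyReady EarlyCloseReadiness(const std::vector<ToyInput>& streams) {
  int min_ts = INT_MAX;
  int min_bound = INT_MAX;
  for (const ToyInput& s : streams) {
    if (s.empty) {
      if (s.ts_or_bound == INT_MAX) return ToyReady::kClose;  // stream done
      min_bound = std::min(min_bound, s.ts_or_bound);
    }
    min_ts = std::min(min_ts, s.ts_or_bound);
  }
  return min_bound > min_ts ? ToyReady::kProcess : ToyReady::kNotReady;
}
// ---- End aside ----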
+class EarlyCloseInputStreamHandler : public InputStreamHandler { + public: + EarlyCloseInputStreamHandler() = delete; + EarlyCloseInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map, + CalculatorContextManager* cc_manager, + const mediapipe::MediaPipeOptions& options, + bool calculator_run_in_parallel) + : InputStreamHandler(std::move(tag_map), cc_manager, options, + calculator_run_in_parallel) {} + + protected: + // In EarlyCloseInputStreamHandler, a node is "ready" if: + // - any stream is done (need to call Close() in this case), or + // - the minimum bound (over all empty streams) is greater than the smallest + // timestamp of any stream, which means we have received all the packets + // that will be available at the next timestamp. + NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override; + + // Only invoked when associated GetNodeReadiness() returned kReadyForProcess. + void FillInputSet(Timestamp input_timestamp, + InputStreamShardSet* input_set) override; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_EARLY_CLOSE_INPUT_STREAM_HANDLER_H_ diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc index fd51a7383..cb4e0fafa 100644 --- a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,219 +11,185 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/stream_handler/fixed_size_input_stream_handler.h" +#include +#include #include +#include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/mediapipe_options.pb.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/stream_handler/default_input_stream_handler.h" -// TODO: Move protos in another CL after the C++ code migration. #include "mediapipe/framework/stream_handler/fixed_size_input_stream_handler.pb.h" +#include "mediapipe/framework/tool/tag_map.h" namespace mediapipe { -// Input stream handler that limits each input queue to a maximum of -// target_queue_size packets, discarding older packets as needed. When a -// timestamp is dropped from a stream, it is dropped from all others as well. -// -// For example, a calculator node with one input stream and the following input -// stream handler specs: -// -// node { -// calculator: "CalculatorRunningAtOneFps" -// input_stream: "packets_streaming_in_at_ten_fps" -// input_stream_handler { -// input_stream_handler: "FixedSizeInputStreamHandler" -// } -// } -// -// will always try to keep the newest packet in the input stream. -// -// A few details: FixedSizeInputStreamHandler takes action when any stream grows -// to trigger_queue_size or larger. It then keeps at most target_queue_size -// packets in every InputStreamImpl.
Every stream is truncated at the same -// timestamp, so that each included timestamp delivers the same packets as -// DefaultInputStreamHandler includes. -// -class FixedSizeInputStreamHandler : public DefaultInputStreamHandler { - public: - FixedSizeInputStreamHandler() = delete; - FixedSizeInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map, - CalculatorContextManager* cc_manager, - const MediaPipeOptions& options, - bool calculator_run_in_parallel) - : DefaultInputStreamHandler(std::move(tag_map), cc_manager, options, - calculator_run_in_parallel) { - const auto& ext = - options.GetExtension(FixedSizeInputStreamHandlerOptions::ext); - trigger_queue_size_ = ext.trigger_queue_size(); - target_queue_size_ = ext.target_queue_size(); - fixed_min_size_ = ext.fixed_min_size(); - pending_ = false; - kept_timestamp_ = Timestamp::Unset(); - // TODO: Either re-enable SetLatePreparation(true) with - // CalculatorContext::InputTimestamp set correctly, or remove the - // implementation of SetLatePreparation. - } +FixedSizeInputStreamHandler::FixedSizeInputStreamHandler( + std::shared_ptr<tool::TagMap> tag_map, CalculatorContextManager* cc_manager, + const mediapipe::MediaPipeOptions& options, bool calculator_run_in_parallel) + : DefaultInputStreamHandler(std::move(tag_map), cc_manager, options, + calculator_run_in_parallel) { + const auto& ext = + options.GetExtension(mediapipe::FixedSizeInputStreamHandlerOptions::ext); + trigger_queue_size_ = ext.trigger_queue_size(); + target_queue_size_ = ext.target_queue_size(); + fixed_min_size_ = ext.fixed_min_size(); + pending_ = false; + kept_timestamp_ = Timestamp::Unset(); + // TODO: Either re-enable SetLatePreparation(true) with + // CalculatorContext::InputTimestamp set correctly, or remove the + // implementation of SetLatePreparation. +} - private: - // Drops packets if all input streams exceed trigger_queue_size. - void EraseAllSurplus() ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) { - Timestamp min_timestamp_all_streams = Timestamp::Max(); - for (const auto& stream : input_stream_managers_) { - // Check whether every InputStreamImpl grew beyond trigger_queue_size. - if (stream->QueueSize() < trigger_queue_size_) { - return; - } - Timestamp min_timestamp = - stream->GetMinTimestampAmongNLatest(target_queue_size_); - - // Record the min timestamp among the newest target_queue_size_ packets - // across all InputStreamImpls. - min_timestamp_all_streams = - std::min(min_timestamp_all_streams, min_timestamp); +void FixedSizeInputStreamHandler::EraseAllSurplus() { + Timestamp min_timestamp_all_streams = Timestamp::Max(); + for (const auto& stream : input_stream_managers_) { + // Check whether every InputStreamImpl grew beyond trigger_queue_size. + if (stream->QueueSize() < trigger_queue_size_) { + return; } - for (auto& stream : input_stream_managers_) { - stream->ErasePacketsEarlierThan(min_timestamp_all_streams); + Timestamp min_timestamp = + stream->GetMinTimestampAmongNLatest(target_queue_size_); + + // Record the min timestamp among the newest target_queue_size_ packets + // across all InputStreamImpls. + min_timestamp_all_streams = + std::min(min_timestamp_all_streams, min_timestamp); + } + for (auto& stream : input_stream_managers_) { + stream->ErasePacketsEarlierThan(min_timestamp_all_streams); + } +} + +Timestamp FixedSizeInputStreamHandler::PreviousAllowedInStream( + Timestamp bound) { + return bound.IsRangeValue() ?
bound - 1 : bound; +} + +Timestamp FixedSizeInputStreamHandler::MinStreamBound() { + Timestamp min_bound = Timestamp::Done(); + for (const auto& stream : input_stream_managers_) { + Timestamp stream_bound = stream->GetMinTimestampAmongNLatest(1); + if (stream_bound > Timestamp::Unset()) { + stream_bound = stream_bound.NextAllowedInStream(); + } else { + stream_bound = stream->MinTimestampOrBound(nullptr); + } + min_bound = std::min(min_bound, stream_bound); + } + return min_bound; +} + +Timestamp FixedSizeInputStreamHandler::MinTimestampToProcess() { + Timestamp min_bound = Timestamp::Done(); + for (const auto& stream : input_stream_managers_) { + bool empty; + Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); + // If we're using the stream's *bound*, we only want to process up to the + // packet *before* the bound, because a packet may still arrive at that + // time. + if (empty) { + stream_timestamp = PreviousAllowedInStream(stream_timestamp); + } + min_bound = std::min(min_bound, stream_timestamp); + } + return min_bound; +} + +void FixedSizeInputStreamHandler::EraseAnySurplus(bool keep_one) { + // Record the most recent first kept timestamp on any stream. + for (const auto& stream : input_stream_managers_) { + int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_) + ? target_queue_size_ + : trigger_queue_size_ - 1; + if (stream->QueueSize() > queue_size) { + kept_timestamp_ = std::max( + kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1) + .NextAllowedInStream()); } } - - // Returns the latest timestamp allowed before a bound. - Timestamp PreviousAllowedInStream(Timestamp bound) { - return bound.IsRangeValue() ? bound - 1 : bound; + if (keep_one) { + // In order to preserve one viable timestamp, do not truncate past + // the timestamp bound of the least current stream. + kept_timestamp_ = + std::min(kept_timestamp_, PreviousAllowedInStream(MinStreamBound())); } - - // Returns the lowest timestamp at which a packet may arrive at any stream. - Timestamp MinStreamBound() { - Timestamp min_bound = Timestamp::Done(); - for (const auto& stream : input_stream_managers_) { - Timestamp stream_bound = stream->GetMinTimestampAmongNLatest(1); - if (stream_bound > Timestamp::Unset()) { - stream_bound = stream_bound.NextAllowedInStream(); - } else { - stream_bound = stream->MinTimestampOrBound(nullptr); - } - min_bound = std::min(min_bound, stream_bound); - } - return min_bound; + for (auto& stream : input_stream_managers_) { + stream->ErasePacketsEarlierThan(kept_timestamp_); } +} - // Returns the lowest timestamp of a packet ready to process. - Timestamp MinTimestampToProcess() { - Timestamp min_bound = Timestamp::Done(); - for (const auto& stream : input_stream_managers_) { - bool empty; - Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty); - // If we're using the stream's *bound*, we only want to process up to the - // packet *before* the bound, because a packet may still arrive at that - // time. - if (empty) { - stream_timestamp = PreviousAllowedInStream(stream_timestamp); - } - min_bound = std::min(min_bound, stream_timestamp); - } - return min_bound; +void FixedSizeInputStreamHandler::EraseSurplusPackets(bool keep_one) { + return (fixed_min_size_) ? 
EraseAllSurplus() : EraseAnySurplus(keep_one); +} + +NodeReadiness FixedSizeInputStreamHandler::GetNodeReadiness( + Timestamp* min_stream_timestamp) { + ABSL_DCHECK(min_stream_timestamp); + absl::MutexLock lock(&erase_mutex_); + // kReadyForProcess is returned only once until FillInputSet completes. + // In late_preparation mode, GetNodeReadiness must return kReadyForProcess + // exactly once for each input-set produced. Here, GetNodeReadiness + // releases just one input-set at a time and then disables input queue + // truncation until that promised input-set is consumed. + if (pending_) { + return NodeReadiness::kNotReady; } + EraseSurplusPackets(false); + NodeReadiness result = + DefaultInputStreamHandler::GetNodeReadiness(min_stream_timestamp); - // Keeps only the most recent target_queue_size packets in each stream - // exceeding trigger_queue_size. Also, discards all packets older than the - // first kept timestamp on any stream. - void EraseAnySurplus(bool keep_one) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) { - // Record the most recent first kept timestamp on any stream. - for (const auto& stream : input_stream_managers_) { - int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_) - ? target_queue_size_ - : trigger_queue_size_ - 1; - if (stream->QueueSize() > queue_size) { - kept_timestamp_ = std::max( - kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1) - .NextAllowedInStream()); - } - } - if (keep_one) { - // In order to preserve one viable timestamp, do not truncate past - // the timestamp bound of the least current stream. - kept_timestamp_ = - std::min(kept_timestamp_, PreviousAllowedInStream(MinStreamBound())); - } - for (auto& stream : input_stream_managers_) { - stream->ErasePacketsEarlierThan(kept_timestamp_); - } - } - - void EraseSurplusPackets(bool keep_one) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) { - return (fixed_min_size_) ? EraseAllSurplus() : EraseAnySurplus(keep_one); - } - - NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override { - DCHECK(min_stream_timestamp); - absl::MutexLock lock(&erase_mutex_); - // kReadyForProcess is returned only once until FillInputSet completes. - // In late_preparation mode, GetNodeReadiness must return kReadyForProcess - // exactly once for each input-set produced. Here, GetNodeReadiness - // releases just one input-set at a time and then disables input queue - // truncation until that promised input-set is consumed. - if (pending_) { - return NodeReadiness::kNotReady; - } + // If a packet has arrived below kept_timestamp_, recalculate. + while (*min_stream_timestamp < kept_timestamp_ && + result == NodeReadiness::kReadyForProcess) { EraseSurplusPackets(false); - NodeReadiness result = - DefaultInputStreamHandler::GetNodeReadiness(min_stream_timestamp); - - // If a packet has arrived below kept_timestamp_, recalculate. 
- while (*min_stream_timestamp < kept_timestamp_ && - result == NodeReadiness::kReadyForProcess) { - EraseSurplusPackets(false); - result = - DefaultInputStreamHandler::GetNodeReadiness(min_stream_timestamp); - } - pending_ = (result == NodeReadiness::kReadyForProcess); - return result; + while (*min_stream_timestamp < kept_timestamp_ && + result == NodeReadiness::kReadyForProcess) { + EraseSurplusPackets(false); + result = DefaultInputStreamHandler::GetNodeReadiness(min_stream_timestamp); } + pending_ = (result == NodeReadiness::kReadyForProcess); + return result; +} - void AddPackets(CollectionItemId id, - const std::list<Packet>& packets) override { - InputStreamHandler::AddPackets(id, packets); - absl::MutexLock lock(&erase_mutex_); - if (!pending_) { - EraseSurplusPackets(false); - } +void FixedSizeInputStreamHandler::AddPackets(CollectionItemId id, + const std::list<Packet>& packets) { + InputStreamHandler::AddPackets(id, packets); + absl::MutexLock lock(&erase_mutex_); + if (!pending_) { + EraseSurplusPackets(false); } +} - void MovePackets(CollectionItemId id, std::list<Packet>* packets) override { - InputStreamHandler::MovePackets(id, packets); - absl::MutexLock lock(&erase_mutex_); - if (!pending_) { - EraseSurplusPackets(false); - } +void FixedSizeInputStreamHandler::MovePackets(CollectionItemId id, + std::list<Packet>* packets) { + InputStreamHandler::MovePackets(id, packets); + absl::MutexLock lock(&erase_mutex_); + if (!pending_) { + EraseSurplusPackets(false); } +} - void FillInputSet(Timestamp input_timestamp, - InputStreamShardSet* input_set) override { - CHECK(input_set); - absl::MutexLock lock(&erase_mutex_); - if (!pending_) { - LOG(ERROR) << "FillInputSet called without GetNodeReadiness."; - } - // input_timestamp is recalculated here to process the most recent packets. - EraseSurplusPackets(true); - input_timestamp = MinTimestampToProcess(); - DefaultInputStreamHandler::FillInputSet(input_timestamp, input_set); - pending_ = false; +void FixedSizeInputStreamHandler::FillInputSet(Timestamp input_timestamp, + InputStreamShardSet* input_set) { + ABSL_CHECK(input_set); + absl::MutexLock lock(&erase_mutex_); + if (!pending_) { + ABSL_LOG(ERROR) << "FillInputSet called without GetNodeReadiness."; } - - private: - int32_t trigger_queue_size_; - int32_t target_queue_size_; - bool fixed_min_size_; - // Indicates that GetNodeReadiness has returned kReadyForProcess once, and - // the corresponding call to FillInputSet has not yet completed. - bool pending_ ABSL_GUARDED_BY(erase_mutex_); - // The timestamp used to truncate all input streams. - Timestamp kept_timestamp_ ABSL_GUARDED_BY(erase_mutex_); - absl::Mutex erase_mutex_; -}; + // input_timestamp is recalculated here to process the most recent packets. + EraseSurplusPackets(true); + input_timestamp = MinTimestampToProcess(); + DefaultInputStreamHandler::FillInputSet(input_timestamp, input_set); + pending_ = false; +} REGISTER_INPUT_STREAM_HANDLER(FixedSizeInputStreamHandler); diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.h b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.h new file mode 100644 index 000000000..a00bdda55 --- /dev/null +++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.h @@ -0,0 +1,108 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_FIXED_SIZE_INPUT_STREAM_HANDLER_H_ +#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_FIXED_SIZE_INPUT_STREAM_HANDLER_H_ + +#include <cstdint> +#include <list> +#include <memory> + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/stream_handler/default_input_stream_handler.h" + +namespace mediapipe { + +// Input stream handler that limits each input queue to a maximum of +// target_queue_size packets, discarding older packets as needed. When a +// timestamp is dropped from a stream, it is dropped from all others as well. +// +// For example, a calculator node with one input stream and the following input +// stream handler specs: +// +// node { +// calculator: "CalculatorRunningAtOneFps" +// input_stream: "packets_streaming_in_at_ten_fps" +// input_stream_handler { +// input_stream_handler: "FixedSizeInputStreamHandler" +// } +// } +// +// will always try to keep the newest packet in the input stream. +// +// A few details: FixedSizeInputStreamHandler takes action when any stream grows +// to trigger_queue_size or larger. It then keeps at most target_queue_size +// packets in every InputStreamImpl. Every stream is truncated at the same +// timestamp, so that each included timestamp delivers the same packets as +// DefaultInputStreamHandler includes. +class FixedSizeInputStreamHandler : public DefaultInputStreamHandler { + public: + FixedSizeInputStreamHandler() = delete; + FixedSizeInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map, + CalculatorContextManager* cc_manager, + const MediaPipeOptions& options, + bool calculator_run_in_parallel); + + private: + // Drops packets if all input streams exceed trigger_queue_size. + void EraseAllSurplus() ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_); + + // Returns the latest timestamp allowed before a bound. + Timestamp PreviousAllowedInStream(Timestamp bound); + + // Returns the lowest timestamp at which a packet may arrive at any stream. + Timestamp MinStreamBound(); + + // Returns the lowest timestamp of a packet ready to process. + Timestamp MinTimestampToProcess(); + + // Keeps only the most recent target_queue_size packets in each stream + // exceeding trigger_queue_size. Also, discards all packets older than the + // first kept timestamp on any stream.
+ void EraseAnySurplus(bool keep_one) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_); + + void EraseSurplusPackets(bool keep_one) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_); + + NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override; + + void AddPackets(CollectionItemId id, + const std::list<Packet>& packets) override; + + void MovePackets(CollectionItemId id, std::list<Packet>* packets) override; + + void FillInputSet(Timestamp input_timestamp, + InputStreamShardSet* input_set) override; + + private: + int32_t trigger_queue_size_; + int32_t target_queue_size_; + bool fixed_min_size_; + // Indicates that GetNodeReadiness has returned kReadyForProcess once, and + // the corresponding call to FillInputSet has not yet completed. + bool pending_ ABSL_GUARDED_BY(erase_mutex_); + // The timestamp used to truncate all input streams. + Timestamp kept_timestamp_ ABSL_GUARDED_BY(erase_mutex_); + absl::Mutex erase_mutex_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_FIXED_SIZE_INPUT_STREAM_HANDLER_H_ diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc index 4f1367a9a..186d59dfe 100644 --- a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc @@ -30,15 +30,15 @@ namespace mediapipe { namespace { -const int64 kMaxPacketId = 100; -const int64 kSlowCalculatorRate = 10; +const int64_t kMaxPacketId = 100; +const int64_t kSlowCalculatorRate = 10; // Rate limiter for TestSlowCalculator. ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit); -int64 g_source_counter ABSL_GUARDED_BY(g_source_mutex); +int64_t g_source_counter ABSL_GUARDED_BY(g_source_mutex); // Rate limiter for TestSourceCalculator. -int64 g_slow_counter ABSL_GUARDED_BY(g_source_mutex); +int64_t g_slow_counter ABSL_GUARDED_BY(g_source_mutex); // Flag that indicates that the source is done.
bool g_source_done ABSL_GUARDED_BY(g_source_mutex); @@ -47,7 +47,7 @@ class TestSourceCalculator : public CalculatorBase { public: TestSourceCalculator() : current_packet_id_(0) {} static absl::Status GetContract(CalculatorContract* cc) { - cc->Outputs().Index(0).Set<int64>(); + cc->Outputs().Index(0).Set<int64_t>(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { @@ -62,7 +62,7 @@ g_source_done = true; return tool::StatusStop(); } - cc->Outputs().Index(0).Add(new int64(0), Timestamp(current_packet_id_)); + cc->Outputs().Index(0).Add(new int64_t(0), Timestamp(current_packet_id_)); ++current_packet_id_; { absl::MutexLock lock(&g_source_mutex); @@ -78,7 +78,7 @@ return g_source_counter <= kSlowCalculatorRate * g_slow_counter || g_source_counter <= 1; } - int64 current_packet_id_; + int64_t current_packet_id_; }; REGISTER_CALCULATOR(TestSourceCalculator); @@ -87,8 +87,8 @@ class TestSlowCalculator : public CalculatorBase { public: TestSlowCalculator() = default; static absl::Status GetContract(CalculatorContract* cc) { - cc->Inputs().Index(0).Set<int64>(); - cc->Outputs().Index(0).Set<int64>(); + cc->Inputs().Index(0).Set<int64_t>(); + cc->Outputs().Index(0).Set<int64_t>(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { @@ -97,7 +97,7 @@ return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { - cc->Outputs().Index(0).Add(new int64(0), + cc->Outputs().Index(0).Add(new int64_t(0), cc->Inputs().Index(0).Value().Timestamp()); { absl::MutexLock lock(&g_source_mutex); @@ -118,8 +118,9 @@ REGISTER_CALCULATOR(TestSlowCalculator); // Return the values of the timestamps of a vector of Packets. -static std::vector<int64> TimestampValues(const std::vector<Packet>& packets) { - std::vector<int64> result; +static std::vector<int64_t> TimestampValues( + const std::vector<Packet>& packets) { + std::vector<int64_t> result; for (const Packet& p : packets) { result.push_back(p.Timestamp().Value()); } @@ -174,7 +175,7 @@ TEST_P(FixedSizeInputStreamHandlerTest, DropsPackets) { // consumed. In this way, the TestSlowCalculator consumes and outputs only // every tenth packet.
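// ---- Illustrative aside (not part of the patch) ----
// The DropsPackets expectations in this test reflect
// FixedSizeInputStreamHandler's truncation rule: once a queue reaches
// trigger_queue_size, keep only the newest target_queue_size packets.
// Toy version over a deque of integer timestamps:

#include <cstddef>
#include <deque>

void ToyTruncate(std::deque<int>& queue, std::size_t trigger_queue_size,
                 std::size_t target_queue_size) {
  if (queue.size() >= trigger_queue_size) {
    while (queue.size() > target_queue_size) {
      queue.pop_front();  // drop the oldest timestamps first
    }
  }
}
// ---- End aside ----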
  EXPECT_EQ(output_packets.size(), 11);
-  std::vector<int64> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
+  std::vector<int64_t> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
  EXPECT_THAT(TimestampValues(output_packets),
              testing::ContainerEq(expected_ts));
}
@@ -344,18 +345,18 @@ TEST_P(FixedSizeInputStreamHandlerTest, LateArrivalDrop) {
  if (GetParam()) {
    EXPECT_THAT(TimestampValues(output_packets[0]),
-                testing::ContainerEq(std::vector<int64>{1, 2, 3, 4, 5, 6}));
+                testing::ContainerEq(std::vector<int64_t>{1, 2, 3, 4, 5, 6}));
    EXPECT_THAT(TimestampValues(output_packets[1]),
-                testing::ContainerEq(std::vector<int64>{3, 4, 5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{3, 4, 5, 6, 7}));
    EXPECT_THAT(TimestampValues(output_packets[2]),
-                testing::ContainerEq(std::vector<int64>{4, 5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{4, 5, 6, 7}));
  } else {
    EXPECT_THAT(TimestampValues(output_packets[0]),
-                testing::ContainerEq(std::vector<int64>{5, 6}));
+                testing::ContainerEq(std::vector<int64_t>{5, 6}));
    EXPECT_THAT(TimestampValues(output_packets[1]),
-                testing::ContainerEq(std::vector<int64>{5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
    EXPECT_THAT(TimestampValues(output_packets[2]),
-                testing::ContainerEq(std::vector<int64>{5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
  }
}
diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc
index c34fc96b3..b2fc1aa8d 100644
--- a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc
+++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc
@@ -11,65 +11,33 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mediapipe/framework/stream_handler/immediate_input_stream_handler.h"
+#include
+#include
#include
#include
+#include "absl/log/absl_check.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/calculator_context_manager.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/input_stream_handler.h"
+#include "mediapipe/framework/mediapipe_options.pb.h"
+#include "mediapipe/framework/tool/tag_map.h"
namespace mediapipe {
using SyncSet = InputStreamHandler::SyncSet;
-// An input stream handler that delivers input packets to the Calculator
-// immediately, with no dependency between input streams. It also invokes
-// Calculator::Process when any input stream becomes done.
-//
-// NOTE: If packets arrive successively on different input streams with
-// identical or decreasing timestamps, this input stream handler will
-// invoke its Calculator with a sequence of InputTimestamps that is
-// non-increasing. Its Calculator is responsible for accumulating packets
-// with the required timetamps before processing and delivering output.
-//
-class ImmediateInputStreamHandler : public InputStreamHandler {
- public:
-  ImmediateInputStreamHandler() = delete;
-  ImmediateInputStreamHandler(
-      std::shared_ptr<tool::TagMap> tag_map,
-      CalculatorContextManager* calculator_context_manager,
-      const MediaPipeOptions& options, bool calculator_run_in_parallel);
-
- protected:
-  // Reinitializes this InputStreamHandler before each CalculatorGraph run.
-  void PrepareForRun(std::function<void()> headers_ready_callback,
-                     std::function<void()> notification_callback,
-                     std::function<void(CalculatorContext*)> schedule_callback,
-                     std::function<void(absl::Status)> error_callback) override;
-
-  // Returns kReadyForProcess whenever a Packet is available at any of
-  // the input streams, or any input stream becomes done.
-  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
-
-  // Selects a packet on each stream with an available packet with the
-  // specified timestamp, leaving other input streams unaffected.
-  void FillInputSet(Timestamp input_timestamp,
-                    InputStreamShardSet* input_set) override;
-
-  // Returns the number of sync-sets maintained by this input-handler.
-  int SyncSetCount() override;
-
-  absl::Mutex mutex_;
-  // The packet-set builder for each input stream.
-  std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
-  // The input timestamp for each kReadyForProcess input stream.
-  std::vector<Timestamp> ready_timestamps_ ABSL_GUARDED_BY(mutex_);
-};
REGISTER_INPUT_STREAM_HANDLER(ImmediateInputStreamHandler);
ImmediateInputStreamHandler::ImmediateInputStreamHandler(
    std::shared_ptr<tool::TagMap> tag_map,
    CalculatorContextManager* calculator_context_manager,
-    const MediaPipeOptions& options, bool calculator_run_in_parallel)
+    const mediapipe::MediaPipeOptions& options, bool calculator_run_in_parallel)
    : InputStreamHandler(tag_map, calculator_context_manager, options,
                         calculator_run_in_parallel) {
  for (auto id = tag_map->BeginId(); id < tag_map->EndId(); ++id) {
@@ -115,7 +83,7 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
      ready_timestamps_[i] = stream_ts;
      input_timestamp = std::min(input_timestamp, stream_ts);
    } else if (readiness == NodeReadiness::kReadyForClose) {
-      CHECK_EQ(stream_ts, Timestamp::Done());
+      ABSL_CHECK_EQ(stream_ts, Timestamp::Done());
      if (ProcessTimestampBounds()) {
        // With kReadyForClose, the timestamp-bound Done is returned.
        // TODO: Make all InputStreamHandlers process Done() like this.
diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler.h b/mediapipe/framework/stream_handler/immediate_input_stream_handler.h
new file mode 100644
index 000000000..dd15ad997
--- /dev/null
+++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler.h
@@ -0,0 +1,77 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_IMMEDIATE_INPUT_STREAM_HANDLER_H_
+#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_IMMEDIATE_INPUT_STREAM_HANDLER_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/calculator_context_manager.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/input_stream_handler.h"
+#include "mediapipe/framework/tool/tag_map.h"
+
+namespace mediapipe {
+
+// An input stream handler that delivers input packets to the Calculator
+// immediately, with no dependency between input streams. It also invokes
+// Calculator::Process when any input stream becomes done.
+//
+// NOTE: If packets arrive successively on different input streams with
+// identical or decreasing timestamps, this input stream handler will
+// invoke its Calculator with a sequence of InputTimestamps that is
+// non-increasing. Its Calculator is responsible for accumulating packets
+// with the required timestamps before processing and delivering output.
+class ImmediateInputStreamHandler : public InputStreamHandler {
+ public:
+  ImmediateInputStreamHandler() = delete;
+  ImmediateInputStreamHandler(
+      std::shared_ptr<tool::TagMap> tag_map,
+      CalculatorContextManager* calculator_context_manager,
+      const MediaPipeOptions& options, bool calculator_run_in_parallel);
+
+ protected:
+  // Reinitializes this InputStreamHandler before each CalculatorGraph run.
+  void PrepareForRun(std::function<void()> headers_ready_callback,
+                     std::function<void()> notification_callback,
+                     std::function<void(CalculatorContext*)> schedule_callback,
+                     std::function<void(absl::Status)> error_callback) override;
+
+  // Returns kReadyForProcess whenever a Packet is available at any of
+  // the input streams, or any input stream becomes done.
+  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
+
+  // Selects a packet on each stream with an available packet with the
+  // specified timestamp, leaving other input streams unaffected.
+  void FillInputSet(Timestamp input_timestamp,
+                    InputStreamShardSet* input_set) override;
+
+  // Returns the number of sync-sets maintained by this input-handler.
+  int SyncSetCount() override;
+
+  absl::Mutex mutex_;
+  // The packet-set builder for each input stream.
+  std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
+  // The input timestamp for each kReadyForProcess input stream.
+  std::vector<Timestamp> ready_timestamps_ ABSL_GUARDED_BY(mutex_);
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_IMMEDIATE_INPUT_STREAM_HANDLER_H_
diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc
index e5de7f0c9..04b1c490b 100644
--- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc
+++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc
@@ -18,6 +18,7 @@
#include
#include "absl/base/macros.h"
+#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_context_manager.h"
@@ -104,7 +105,7 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test {
  void NotifyNoOp() {}
  void Schedule(CalculatorContext* cc) {
-    CHECK(cc);
+    ABSL_CHECK(cc);
    cc_ = cc;
  }
@@ -132,7 +133,7 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test {
  }
  const InputStream& Input(const CollectionItemId& id) {
-    CHECK(cc_);
+    ABSL_CHECK(cc_);
    return cc_->Inputs().Get(id);
  }
diff --git a/mediapipe/framework/stream_handler/in_order_output_stream_handler.cc b/mediapipe/framework/stream_handler/in_order_output_stream_handler.cc
index 9af38ecdd..8faaacebe 100644
--- a/mediapipe/framework/stream_handler/in_order_output_stream_handler.cc
+++ b/mediapipe/framework/stream_handler/in_order_output_stream_handler.cc
@@ -14,6 +14,7 @@
#include "mediapipe/framework/stream_handler/in_order_output_stream_handler.h"
+#include "absl/log/absl_check.h"
#include "mediapipe/framework/collection.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/output_stream_shard.h"
@@ -23,7 +24,7 @@ namespace mediapipe {
REGISTER_OUTPUT_STREAM_HANDLER(InOrderOutputStreamHandler); void InOrderOutputStreamHandler::PropagationLoop() { - CHECK_EQ(propagation_state_, kIdle); + ABSL_CHECK_EQ(propagation_state_, kIdle); Timestamp context_timestamp; CalculatorContext* calculator_context; if (!calculator_context_manager_->HasActiveContexts()) { @@ -34,7 +35,7 @@ void InOrderOutputStreamHandler::PropagationLoop() { if (!completed_input_timestamps_.empty()) { Timestamp completed_timestamp = *completed_input_timestamps_.begin(); if (context_timestamp != completed_timestamp) { - CHECK_LT(context_timestamp, completed_timestamp); + ABSL_CHECK_LT(context_timestamp, completed_timestamp); return; } propagation_state_ = kPropagatingPackets; @@ -45,7 +46,7 @@ void InOrderOutputStreamHandler::PropagationLoop() { if (propagation_state_ == kPropagatingPackets) { PropagatePackets(&calculator_context, &context_timestamp); } else { - CHECK_EQ(kPropagatingBound, propagation_state_); + ABSL_CHECK_EQ(kPropagatingBound, propagation_state_); PropagationBound(&calculator_context, &context_timestamp); } } @@ -105,12 +106,12 @@ void InOrderOutputStreamHandler::PropagationBound( } // Some recent changes require the propagation thread to recheck if any // new packets can be propagated. - CHECK_EQ(propagation_state_, kPropagationPending); + ABSL_CHECK_EQ(propagation_state_, kPropagationPending); // task_timestamp_bound_ was updated while the propagation thread was // doing timestamp propagation. This thread will redo timestamp // propagation for the new task_timestamp_bound_. if (!calculator_context_manager_->HasActiveContexts()) { - CHECK_LT(bound_to_propagate, task_timestamp_bound_); + ABSL_CHECK_LT(bound_to_propagate, task_timestamp_bound_); propagation_state_ = kPropagatingBound; return; } diff --git a/mediapipe/framework/stream_handler/mux_input_stream_handler.cc b/mediapipe/framework/stream_handler/mux_input_stream_handler.cc index 0303a5778..a0253b9cd 100644 --- a/mediapipe/framework/stream_handler/mux_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/mux_input_stream_handler.cc @@ -11,151 +11,124 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/stream_handler/mux_input_stream_handler.h" +#include + +#include "absl/log/absl_check.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/input_stream_handler.h" -#include "mediapipe/framework/port/logging.h" namespace mediapipe { -// Implementation of the input stream handler for the MuxCalculator. -// -// One of the input streams is the control stream; all the other input streams -// are data streams. To make MuxInputStreamHandler work properly, the tag of the -// input streams must obey the following rules: -// Let N be the number of input streams. Data streams must use tag "INPUT" with -// index 0, ..., N - 2; the control stream must use tag "SELECT". -// -// The control stream carries packets of type 'int'. The 'int' value in a -// control stream packet must be a valid index in the range 0, ..., N - 2 and -// select the data stream at that index. The selected data stream must have a -// packet with the same timestamp as the control stream packet. -// -// When the control stream is done, GetNodeReadiness() returns -// NodeReadiness::kReadyForClose. 
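// For illustration, the tag rules above correspond to a graph node configured
// along these lines (stream names are illustrative; the test added later in
// this change uses the same wiring):
//
//   node {
//     calculator: "MuxCalculator"
//     input_stream: "INPUT:0:input0"
//     input_stream: "INPUT:1:input1"
//     input_stream: "SELECT:select"
//     output_stream: "OUTPUT:output"
//     input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
//   }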
-//
-// TODO: pass the input stream tags to the MuxInputStreamHandler
-// constructor so that it can refer to input streams by tag. See b/30125118.
-class MuxInputStreamHandler : public InputStreamHandler {
- public:
-  MuxInputStreamHandler() = delete;
-  MuxInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map,
-                        CalculatorContextManager* cc_manager,
-                        const MediaPipeOptions& options,
-                        bool calculator_run_in_parallel)
-      : InputStreamHandler(std::move(tag_map), cc_manager, options,
-                           calculator_run_in_parallel) {}
+CollectionItemId MuxInputStreamHandler::GetControlStreamId() const {
+  return input_stream_managers_.EndId() - 1;
+}
+void MuxInputStreamHandler::RemoveOutdatedDataPackets(Timestamp timestamp) {
+  const CollectionItemId control_stream_id = GetControlStreamId();
+  for (CollectionItemId id = input_stream_managers_.BeginId();
+       id < control_stream_id; ++id) {
+    input_stream_managers_.Get(id)->ErasePacketsEarlierThan(timestamp);
+  }
+}
- protected:
-  // In MuxInputStreamHandler, a node is "ready" if:
-  // - the control stream is done (need to call Close() in this case), or
-  // - we have received the packets on the control stream and the selected data
-  //   stream at the next timestamp.
-  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override {
-    DCHECK(min_stream_timestamp);
-    absl::MutexLock lock(&input_streams_mutex_);
+// In MuxInputStreamHandler, a node is "ready" if:
+// - the control stream is done (need to call Close() in this case), or
+// - we have received the packets on the control stream and the selected data
+//   stream at the next timestamp.
+NodeReadiness MuxInputStreamHandler::GetNodeReadiness(
+    Timestamp* min_stream_timestamp) {
+  ABSL_DCHECK(min_stream_timestamp);
+  absl::MutexLock lock(&input_streams_mutex_);
-    const auto& control_stream =
-        input_stream_managers_.Get(input_stream_managers_.EndId() - 1);
-    bool empty;
-    *min_stream_timestamp = control_stream->MinTimestampOrBound(&empty);
-    if (empty) {
-      if (*min_stream_timestamp == Timestamp::Done()) {
-        // Calculator is done if the control input stream is done.
-        return NodeReadiness::kReadyForClose;
-      }
-      // Calculator is not ready to run if the control input stream is empty.
+  const auto& control_stream = input_stream_managers_.Get(GetControlStreamId());
+  bool empty;
+  *min_stream_timestamp = control_stream->MinTimestampOrBound(&empty);
+
+  // Data streams may contain some outdated packets which failed to be popped
+  // out during "FillInputSet". (This handler doesn't sync input streams,
+  // hence "FillInputSet" can be triggered before every input stream is
+  // filled with packets corresponding to the same timestamp.)
+  RemoveOutdatedDataPackets(*min_stream_timestamp);
+  if (empty) {
+    if (*min_stream_timestamp == Timestamp::Done()) {
+      // Calculator is done if the control input stream is done.
+      return NodeReadiness::kReadyForClose;
+    }
+    // Calculator is not ready to run if the control input stream is empty.
+    return NodeReadiness::kNotReady;
+  }
+
+  Packet control_packet = control_stream->QueueHead();
+  ABSL_CHECK(!control_packet.IsEmpty());
+  int control_value = control_packet.Get<int>();
+  ABSL_CHECK_LE(0, control_value);
+  ABSL_CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
+  const auto& data_stream = input_stream_managers_.Get(
+      input_stream_managers_.BeginId() + control_value);
+
+  Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
+  if (empty) {
+    if (stream_timestamp <= *min_stream_timestamp) {
+      // "data_stream" didn't receive a packet corresponding to the current
+      // "control_stream" packet yet.
      return NodeReadiness::kNotReady;
    }
-
-    Packet control_packet = control_stream->QueueHead();
-    CHECK(!control_packet.IsEmpty());
-    int control_value = control_packet.Get<int>();
-    CHECK_LE(0, control_value);
-    CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
-    const auto& data_stream = input_stream_managers_.Get(
-        input_stream_managers_.BeginId() + control_value);
-
-    // Data stream may contain some outdated packets which failed to be popped
-    // out during "FillInputSet". (This handler doesn't sync input streams,
-    // hence "FillInputSet" can be triggerred before every input stream is
-    // filled with packets corresponding to the same timestamp.)
-    data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
-    Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
-    if (empty) {
-      if (stream_timestamp <= *min_stream_timestamp) {
-        // "data_stream" didn't receive a packet corresponding to the current
-        // "control_stream" packet yet.
-        return NodeReadiness::kNotReady;
-      }
-      // "data_stream" timestamp bound update detected.
-      return NodeReadiness::kReadyForProcess;
-    }
-    if (stream_timestamp > *min_stream_timestamp) {
-      // The earliest packet "data_stream" holds corresponds to a control packet
-      // yet to arrive, which means there won't be a "data_stream" packet
-      // corresponding to the current "control_stream" packet, which should be
-      // indicated as timestamp boun update.
-      return NodeReadiness::kReadyForProcess;
-    }
-    CHECK_EQ(stream_timestamp, *min_stream_timestamp);
+    // "data_stream" timestamp bound update detected.
    return NodeReadiness::kReadyForProcess;
  }
-  void FillInputSet(Timestamp input_timestamp,
-                    InputStreamShardSet* input_set) override {
-    CHECK(input_timestamp.IsAllowedInStream());
-    CHECK(input_set);
-    absl::MutexLock lock(&input_streams_mutex_);
-
-    const CollectionItemId control_stream_id =
-        input_stream_managers_.EndId() - 1;
-    auto& control_stream = input_stream_managers_.Get(control_stream_id);
-    int num_packets_dropped = 0;
-    bool stream_is_done = false;
-    Packet control_packet = control_stream->PopPacketAtTimestamp(
-        input_timestamp, &num_packets_dropped, &stream_is_done);
-    CHECK_EQ(num_packets_dropped, 0)
-        << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
-                            num_packets_dropped, control_stream->Name());
-    CHECK(!control_packet.IsEmpty());
-    int control_value = control_packet.Get<int>();
-    AddPacketToShard(&input_set->Get(control_stream_id),
-                     std::move(control_packet), stream_is_done);
-
-    const CollectionItemId data_stream_id =
-        input_stream_managers_.BeginId() + control_value;
-    CHECK_LE(input_stream_managers_.BeginId(), data_stream_id);
-    CHECK_LT(data_stream_id, control_stream_id);
-    auto& data_stream = input_stream_managers_.Get(data_stream_id);
-    stream_is_done = false;
-    Packet data_packet = data_stream->PopPacketAtTimestamp(
-        input_timestamp, &num_packets_dropped, &stream_is_done);
-    CHECK_EQ(num_packets_dropped, 0)
-        << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
-                            num_packets_dropped, data_stream->Name());
-    AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet),
-                     stream_is_done);
-
-    // Discard old packets on other streams.
-    // Note that control_stream_id is the last valid id.
-    auto next_timestamp = input_timestamp.NextAllowedInStream();
-    for (CollectionItemId id = input_stream_managers_.BeginId();
-         id < control_stream_id; ++id) {
-      if (id == data_stream_id) continue;
-      auto& other_stream = input_stream_managers_.Get(id);
-      other_stream->ErasePacketsEarlierThan(next_timestamp);
-    }
+  if (stream_timestamp > *min_stream_timestamp) {
+    // The earliest packet "data_stream" holds corresponds to a control packet
+    // yet to arrive, which means there won't be a "data_stream" packet
+    // corresponding to the current "control_stream" packet, which should be
+    // indicated as timestamp bound update.
+    return NodeReadiness::kReadyForProcess;
  }
+  ABSL_CHECK_EQ(stream_timestamp, *min_stream_timestamp);
+  return NodeReadiness::kReadyForProcess;
+}
- private:
-  // Must be acquired when manipulating the control and data streams to ensure
-  // we have a consistent view of the two streams.
-  absl::Mutex input_streams_mutex_;
-};
+// Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
+void MuxInputStreamHandler::FillInputSet(Timestamp input_timestamp,
+                                         InputStreamShardSet* input_set) {
+  ABSL_CHECK(input_timestamp.IsAllowedInStream());
+  ABSL_CHECK(input_set);
+  absl::MutexLock lock(&input_streams_mutex_);
+
+  const CollectionItemId control_stream_id = GetControlStreamId();
+  auto& control_stream = input_stream_managers_.Get(control_stream_id);
+  int num_packets_dropped = 0;
+  bool stream_is_done = false;
+  Packet control_packet = control_stream->PopPacketAtTimestamp(
+      input_timestamp, &num_packets_dropped, &stream_is_done);
+  ABSL_CHECK_EQ(num_packets_dropped, 0)
+      << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
+                          num_packets_dropped, control_stream->Name());
+  ABSL_CHECK(!control_packet.IsEmpty());
+  int control_value = control_packet.Get<int>();
+  AddPacketToShard(&input_set->Get(control_stream_id),
+                   std::move(control_packet), stream_is_done);
+
+  const CollectionItemId data_stream_id =
+      input_stream_managers_.BeginId() + control_value;
+  ABSL_CHECK_LE(input_stream_managers_.BeginId(), data_stream_id);
+  ABSL_CHECK_LT(data_stream_id, control_stream_id);
+  auto& data_stream = input_stream_managers_.Get(data_stream_id);
+  stream_is_done = false;
+  Packet data_packet = data_stream->PopPacketAtTimestamp(
+      input_timestamp, &num_packets_dropped, &stream_is_done);
+  ABSL_CHECK_EQ(num_packets_dropped, 0)
+      << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
+                          num_packets_dropped, data_stream->Name());
+  AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet),
+                   stream_is_done);
+
+  // Discard old packets on data streams.
+  RemoveOutdatedDataPackets(input_timestamp.NextAllowedInStream());
+}
REGISTER_INPUT_STREAM_HANDLER(MuxInputStreamHandler);
diff --git a/mediapipe/framework/stream_handler/mux_input_stream_handler.h b/mediapipe/framework/stream_handler/mux_input_stream_handler.h
new file mode 100644
index 000000000..63fdde0e6
--- /dev/null
+++ b/mediapipe/framework/stream_handler/mux_input_stream_handler.h
@@ -0,0 +1,80 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_MUX_INPUT_STREAM_HANDLER_H_
+#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_MUX_INPUT_STREAM_HANDLER_H_
+
+#include <memory>
+#include <utility>
+
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/calculator_context_manager.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection_item_id.h"
+#include "mediapipe/framework/input_stream_handler.h"
+
+namespace mediapipe {
+
+// Implementation of the input stream handler for the MuxCalculator.
+//
+// One of the input streams is the control stream; all the other input streams
+// are data streams. To make MuxInputStreamHandler work properly, the tag of the
+// input streams must obey the following rules:
+// Let N be the number of input streams. Data streams must use tag "INPUT" with
+// index 0, ..., N - 2; the control stream must use tag "SELECT".
+//
+// The control stream carries packets of type 'int'. The 'int' value in a
+// control stream packet must be a valid index in the range 0, ..., N - 2 and
+// select the data stream at that index. The selected data stream must have a
+// packet with the same timestamp as the control stream packet.
+//
+// When the control stream is done, GetNodeReadiness() returns
+// NodeReadiness::kReadyForClose.
+//
+// TODO: pass the input stream tags to the MuxInputStreamHandler
+// constructor so that it can refer to input streams by tag. See b/30125118.
+class MuxInputStreamHandler : public InputStreamHandler {
+ public:
+  MuxInputStreamHandler() = delete;
+  MuxInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map,
+                        CalculatorContextManager* cc_manager,
+                        const MediaPipeOptions& options,
+                        bool calculator_run_in_parallel)
+      : InputStreamHandler(std::move(tag_map), cc_manager, options,
+                           calculator_run_in_parallel) {}
+
+ private:
+  CollectionItemId GetControlStreamId() const;
+  void RemoveOutdatedDataPackets(Timestamp timestamp);
+
+ protected:
+  // In MuxInputStreamHandler, a node is "ready" if:
+  // - the control stream is done (need to call Close() in this case), or
+  // - we have received the packets on the control stream and the selected data
+  //   stream at the next timestamp.
+  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
+
+  // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
+  void FillInputSet(Timestamp input_timestamp,
+                    InputStreamShardSet* input_set) override;
+
+ private:
+  // Must be acquired when manipulating the control and data streams to ensure
+  // we have a consistent view of the two streams.
+  absl::Mutex input_streams_mutex_;
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_MUX_INPUT_STREAM_HANDLER_H_
diff --git a/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc
index f19a3ddec..78b2bb3f7 100644
--- a/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc
+++ b/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc
@@ -645,5 +645,41 @@ TEST(MuxInputStreamHandlerTest,
  MP_ASSERT_OK(graph.WaitUntilDone());
}
+TEST(MuxInputStreamHandlerTest, RemovesUnusedDataStreamPackets) {
+  CalculatorGraphConfig config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: "input0"
+        input_stream: "input1"
+        input_stream: "select"
+        node {
+          calculator: "MuxCalculator"
+          input_stream: "INPUT:0:input0"
+          input_stream: "INPUT:1:input1"
+          input_stream: "SELECT:select"
+          output_stream: "OUTPUT:output"
+          input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
+        }
+      )pb");
+  config.set_max_queue_size(1);
+  config.set_report_deadlock(true);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "select", MakePacket<int>(0).At(Timestamp(2))));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input0", MakePacket<int>(1000).At(Timestamp(2))));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Add two delayed packets to the deselected input. They should be discarded
+  // instead of triggering the deadlock detection (max_queue_size = 1).
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input1", MakePacket<int>(900).At(Timestamp(1))));
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input1", MakePacket<int>(900).At(Timestamp(2))));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+}
+
}  // namespace
}  // namespace mediapipe
diff --git a/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc b/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc
index 1001d64f7..f6356c17e 100644
--- a/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc
+++ b/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc
@@ -11,105 +11,51 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mediapipe/framework/stream_handler/sync_set_input_stream_handler.h"
-#include
+#include
+#include
+#include
+#include
+#include
-// TODO: Move protos in another CL after the C++ code migration.
-#include "absl/strings/substitute.h"
+#include "absl/log/absl_check.h"
#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/input_stream_handler.h"
-#include "mediapipe/framework/mediapipe_options.pb.h"
#include "mediapipe/framework/packet_set.h"
+#include "mediapipe/framework/port/map_util.h"
+#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/stream_handler/sync_set_input_stream_handler.pb.h"
#include "mediapipe/framework/timestamp.h"
-#include "mediapipe/framework/tool/tag_map.h"
namespace mediapipe {
-// An input stream handler which separates the inputs into sets which
-// are each independently synchronized. For example, if 5 inputs are
-// present, then the first three can be grouped (and will be synchronized
-// as if they were in a calculator with only those three streams) and the
-// remaining 2 streams can be independently grouped. The calculator will
-// always be called with all the available packets from a single sync set
-// (never more than one). The input timestamps seen by the calculator
-// will be ordered sequentially for each sync set but may jump around
-// between sync sets.
-class SyncSetInputStreamHandler : public InputStreamHandler {
- public:
-  SyncSetInputStreamHandler() = delete;
-  SyncSetInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map,
-                            CalculatorContextManager* cc_manager,
-                            const MediaPipeOptions& extendable_options,
-                            bool calculator_run_in_parallel);
-
-  void PrepareForRun(std::function<void()> headers_ready_callback,
-                     std::function<void()> notification_callback,
-                     std::function<void(CalculatorContext*)> schedule_callback,
-                     std::function<void(absl::Status)> error_callback) override;
-
- protected:
-  // In SyncSetInputStreamHandler, a node is "ready" if any
-  // of its sync sets are ready in the traditional sense (See
-  // DefaultInputStreamHandler).
-  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
-
-  // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
-  // Populates packets for the ready sync-set, and populates timestamp bounds
-  // for all sync-sets.
-  void FillInputSet(Timestamp input_timestamp,
-                    InputStreamShardSet* input_set) override;
-
-  // Populates timestamp bounds for streams outside the ready sync-set.
-  void FillInputBounds(Timestamp input_timestamp,
-                       InputStreamShardSet* input_set)
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
-  // Returns the number of sync-sets maintained by this input-handler.
-  int SyncSetCount() override;
-
- private:
-  absl::Mutex mutex_;
-  // The ids of each set of inputs.
-  std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
-  // The index of the ready sync set. A value of -1 indicates that no
-  // sync sets are ready.
-  int ready_sync_set_index_ ABSL_GUARDED_BY(mutex_) = -1;
-  // The timestamp at which the sync set is ready. If no sync set is
-  // ready then this variable should be Timestamp::Done() .
-  Timestamp ready_timestamp_ ABSL_GUARDED_BY(mutex_);
-};
-
REGISTER_INPUT_STREAM_HANDLER(SyncSetInputStreamHandler);
-SyncSetInputStreamHandler::SyncSetInputStreamHandler(
-    std::shared_ptr<tool::TagMap> tag_map, CalculatorContextManager* cc_manager,
-    const MediaPipeOptions& extendable_options, bool calculator_run_in_parallel)
-    : InputStreamHandler(std::move(tag_map), cc_manager, extendable_options,
-                         calculator_run_in_parallel) {}
-
void SyncSetInputStreamHandler::PrepareForRun(
    std::function<void()> headers_ready_callback,
    std::function<void()> notification_callback,
    std::function<void(CalculatorContext*)> schedule_callback,
    std::function<void(absl::Status)> error_callback) {
  const auto& handler_options =
-      options_.GetExtension(SyncSetInputStreamHandlerOptions::ext);
+      options_.GetExtension(mediapipe::SyncSetInputStreamHandlerOptions::ext);
  {
    absl::MutexLock lock(&mutex_);
    sync_sets_.clear();
    std::set<CollectionItemId> used_ids;
    for (const auto& sync_set : handler_options.sync_set()) {
      std::vector<CollectionItemId> stream_ids;
-      CHECK_LT(0, sync_set.tag_index_size());
+      ABSL_CHECK_LT(0, sync_set.tag_index_size());
      for (const auto& tag_index : sync_set.tag_index()) {
        std::string tag;
        int index;
        MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(tag_index, &tag, &index));
        CollectionItemId id = input_stream_managers_.GetId(tag, index);
-        CHECK(id.IsValid()) << "stream \"" << tag_index << "\" is not found.";
-        CHECK(!mediapipe::ContainsKey(used_ids, id))
+        ABSL_CHECK(id.IsValid())
+            << "stream \"" << tag_index << "\" is not found.";
+        ABSL_CHECK(!mediapipe::ContainsKey(used_ids, id))
            << "stream \"" << tag_index << "\" is in more than one sync set.";
        used_ids.insert(id);
        stream_ids.push_back(id);
@@ -137,7 +83,7 @@ void SyncSetInputStreamHandler::PrepareForRun(
NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
    Timestamp* min_stream_timestamp) {
-  DCHECK(min_stream_timestamp);
+  ABSL_DCHECK(min_stream_timestamp);
  absl::MutexLock lock(&mutex_);
  if (ready_sync_set_index_ >= 0) {
    *min_stream_timestamp = ready_timestamp_;
@@ -185,7 +131,7 @@ void SyncSetInputStreamHandler::FillInputSet(Timestamp input_timestamp,
                                             InputStreamShardSet* input_set) {
  // Assume that all current packets are already cleared.
  absl::MutexLock lock(&mutex_);
-  CHECK_LE(0, ready_sync_set_index_);
+  ABSL_CHECK_LE(0, ready_sync_set_index_);
  sync_sets_[ready_sync_set_index_].FillInputSet(input_timestamp, input_set);
  for (int i = 0; i < sync_sets_.size(); ++i) {
    if (i != ready_sync_set_index_) {
diff --git a/mediapipe/framework/stream_handler/sync_set_input_stream_handler.h b/mediapipe/framework/stream_handler/sync_set_input_stream_handler.h
new file mode 100644
index 000000000..67f1e49a1
--- /dev/null
+++ b/mediapipe/framework/stream_handler/sync_set_input_stream_handler.h
@@ -0,0 +1,97 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_SYNC_SET_INPUT_STREAM_HANDLER_H_
+#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_SYNC_SET_INPUT_STREAM_HANDLER_H_
+
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/status/status.h"
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/calculator_context_manager.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection_item_id.h"
+#include "mediapipe/framework/input_stream_handler.h"
+#include "mediapipe/framework/mediapipe_options.pb.h"
+#include "mediapipe/framework/packet_set.h"
+#include "mediapipe/framework/stream_handler/sync_set_input_stream_handler.pb.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/framework/tool/tag_map.h"
+
+namespace mediapipe {
+
+// An input stream handler which separates the inputs into sets which
+// are each independently synchronized. For example, if 5 inputs are
+// present, then the first three can be grouped (and will be synchronized
+// as if they were in a calculator with only those three streams) and the
+// remaining 2 streams can be independently grouped. The calculator will
+// always be called with all the available packets from a single sync set
+// (never more than one). The input timestamps seen by the calculator
+// will be ordered sequentially for each sync set but may jump around
+// between sync sets.
+class SyncSetInputStreamHandler : public InputStreamHandler {
+ public:
+  SyncSetInputStreamHandler() = delete;
+  SyncSetInputStreamHandler(
+      std::shared_ptr<tool::TagMap> tag_map,
+      CalculatorContextManager* cc_manager,
+      const mediapipe::MediaPipeOptions& extendable_options,
+      bool calculator_run_in_parallel)
+      : InputStreamHandler(std::move(tag_map), cc_manager, extendable_options,
+                           calculator_run_in_parallel) {}
+
+  void PrepareForRun(std::function<void()> headers_ready_callback,
+                     std::function<void()> notification_callback,
+                     std::function<void(CalculatorContext*)> schedule_callback,
+                     std::function<void(absl::Status)> error_callback) override;
+
+ protected:
+  // In SyncSetInputStreamHandler, a node is "ready" if any
+  // of its sync sets are ready in the traditional sense (See
+  // DefaultInputStreamHandler).
+  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
+
+  // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
+  // Populates packets for the ready sync-set, and populates timestamp bounds
+  // for all sync-sets.
+  void FillInputSet(Timestamp input_timestamp,
+                    InputStreamShardSet* input_set) override;
+
+  // Populates timestamp bounds for streams outside the ready sync-set.
+  void FillInputBounds(Timestamp input_timestamp,
+                       InputStreamShardSet* input_set)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+  // Returns the number of sync-sets maintained by this input-handler.
+  int SyncSetCount() override;
+
+ private:
+  absl::Mutex mutex_;
+  // The ids of each set of inputs.
+  std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
+  // The index of the ready sync set. A value of -1 indicates that no
+  // sync sets are ready.
+ int ready_sync_set_index_ ABSL_GUARDED_BY(mutex_) = -1; + // The timestamp at which the sync set is ready. If no sync set is + // ready then this variable should be Timestamp::Done() . + Timestamp ready_timestamp_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_SYNC_SET_INPUT_STREAM_HANDLER_H_ diff --git a/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc index e93f806be..c8cc6a171 100644 --- a/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/calculator_framework.h" // TODO: Move protos in another CL after the C++ code migration. @@ -215,7 +216,7 @@ TEST(SyncSetInputStreamHandlerTest, OrdinaryOperation) { RandomEngine rng(testing::UnitTest::GetInstance()->random_seed()); for (int iter = 0; iter < 1000; ++iter) { - LOG(INFO) << "Starting command shuffling iteration " << iter; + ABSL_LOG(INFO) << "Starting command shuffling iteration " << iter; // Merge the commands for each sync set together into a serial list. // This is done by randomly choosing which list to grab from next. diff --git a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc index ae075d788..1ab5e4e75 100644 --- a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc @@ -12,91 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.h" + #include +#include +#include #include #include #include +#include "absl/log/absl_check.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/mediapipe_options.pb.h" #include "mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.pb.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/tool/validate_name.h" namespace mediapipe { -// The input streams must have the same time unit but may have different time -// origins (also called epochs). The timestamp_base_tag_index option -// designates an input stream as the timestamp base. -// -// TimestampAlignInputStreamHandler operates in two phases: -// -// 1. Pre-initialization: In this phase, the input stream handler passes -// through input packets in the timestamp base input stream, but buffers the -// input packets in all other input streams. This phase ends when the input -// stream handler has an input packet in every input stream. It uses the -// the timestamps of these input packets to calculate the timestamp offset of -// each input stream with respect to the timestamp base input stream. The -// timestamp offsets are saved for use in the next phase. -// -// 2. 
Post-initialization: In this phase, the input stream handler behaves
-// like the DefaultInputStreamHandler, except that timestamp offsets are
-// applied to the packet timestamps.
-class TimestampAlignInputStreamHandler : public InputStreamHandler {
- public:
-  TimestampAlignInputStreamHandler() = delete;
-  TimestampAlignInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map,
-                                   CalculatorContextManager* cc_manager,
-                                   const MediaPipeOptions& options,
-                                   bool calculator_run_in_parallel);
-
-  void PrepareForRun(std::function<void()> headers_ready_callback,
-                     std::function<void()> notification_callback,
-                     std::function<void(CalculatorContext*)> schedule_callback,
-                     std::function<void(absl::Status)> error_callback) override;
-
- protected:
-  // In TimestampAlignInputStreamHandler, a node is "ready" if:
-  // - before the timestamp offsets are initialized: we have received a packet
-  //   in the timestamp base input stream, or
-  // - after the timestamp offsets are initialized: the minimum bound (over
-  //   all empty streams) is greater than the smallest timestamp of any
-  //   stream, which means we have received all the packets that will be
-  //   available at the next timestamp, or
-  // - all streams are done (need to call Close() in this case).
-  // Note that all packet timestamps and timestamp bounds are aligned with the
-  // timestamp base.
-  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
-
-  // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
-  void FillInputSet(Timestamp input_timestamp,
-                    InputStreamShardSet* input_set) override;
-
- private:
-  CollectionItemId timestamp_base_stream_id_;
-
-  absl::Mutex mutex_;
-  bool offsets_initialized_ ABSL_GUARDED_BY(mutex_) = false;
-  std::vector<TimestampDiff> timestamp_offsets_;
-};
REGISTER_INPUT_STREAM_HANDLER(TimestampAlignInputStreamHandler);
TimestampAlignInputStreamHandler::TimestampAlignInputStreamHandler(
    std::shared_ptr<tool::TagMap> tag_map, CalculatorContextManager* cc_manager,
-    const MediaPipeOptions& options, bool calculator_run_in_parallel)
+    const mediapipe::MediaPipeOptions& options, bool calculator_run_in_parallel)
    : InputStreamHandler(std::move(tag_map), cc_manager, options,
                         calculator_run_in_parallel),
      timestamp_offsets_(input_stream_managers_.NumEntries()) {
-  const auto& handler_options =
-      options.GetExtension(TimestampAlignInputStreamHandlerOptions::ext);
+  const auto& handler_options = options.GetExtension(
+      mediapipe::TimestampAlignInputStreamHandlerOptions::ext);
  std::string tag;
  int index;
  MEDIAPIPE_CHECK_OK(tool::ParseTagIndex(
      handler_options.timestamp_base_tag_index(), &tag, &index));
  timestamp_base_stream_id_ = input_stream_managers_.GetId(tag, index);
-  CHECK(timestamp_base_stream_id_.IsValid())
+  ABSL_CHECK(timestamp_base_stream_id_.IsValid())
      << "stream \"" << handler_options.timestamp_base_tag_index()
      << "\" is not found.";
  timestamp_offsets_[timestamp_base_stream_id_.value()] = 0;
@@ -119,7 +73,7 @@ void TimestampAlignInputStreamHandler::PrepareForRun(
NodeReadiness TimestampAlignInputStreamHandler::GetNodeReadiness(
    Timestamp* min_stream_timestamp) {
-  DCHECK(min_stream_timestamp);
+  ABSL_DCHECK(min_stream_timestamp);
  *min_stream_timestamp = Timestamp::Done();
  Timestamp min_bound = Timestamp::Done();
@@ -178,14 +132,14 @@ NodeReadiness TimestampAlignInputStreamHandler::GetNodeReadiness(
    return NodeReadiness::kReadyForProcess;
  }
-  CHECK_EQ(min_bound, *min_stream_timestamp);
+  ABSL_CHECK_EQ(min_bound, *min_stream_timestamp);
  return NodeReadiness::kNotReady;
}
void TimestampAlignInputStreamHandler::FillInputSet(
    Timestamp input_timestamp, InputStreamShardSet*
input_set) { - CHECK(input_timestamp.IsAllowedInStream()); - CHECK(input_set); + ABSL_CHECK(input_timestamp.IsAllowedInStream()); + ABSL_CHECK(input_set); { absl::MutexLock lock(&mutex_); if (!offsets_initialized_) { @@ -198,7 +152,7 @@ void TimestampAlignInputStreamHandler::FillInputSet( if (id == timestamp_base_stream_id_) { current_packet = stream->PopPacketAtTimestamp( input_timestamp, &num_packets_dropped, &stream_is_done); - CHECK_EQ(num_packets_dropped, 0) << absl::Substitute( + ABSL_CHECK_EQ(num_packets_dropped, 0) << absl::Substitute( "Dropped $0 packet(s) on input stream \"$1\".", num_packets_dropped, stream->Name()); } @@ -218,10 +172,10 @@ void TimestampAlignInputStreamHandler::FillInputSet( Packet current_packet = stream->PopPacketAtTimestamp( stream_timestamp, &num_packets_dropped, &stream_is_done); if (!current_packet.IsEmpty()) { - CHECK_EQ(current_packet.Timestamp(), stream_timestamp); + ABSL_CHECK_EQ(current_packet.Timestamp(), stream_timestamp); current_packet = current_packet.At(input_timestamp); } - CHECK_EQ(num_packets_dropped, 0) + ABSL_CHECK_EQ(num_packets_dropped, 0) << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", num_packets_dropped, stream->Name()); AddPacketToShard(&input_set->Get(id), std::move(current_packet), diff --git a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.h b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.h new file mode 100644 index 000000000..dce8fad9b --- /dev/null +++ b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.h @@ -0,0 +1,91 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_TIMESTAMP_ALIGN_INPUT_STREAM_HANDLER_H_ +#define MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_TIMESTAMP_ALIGN_INPUT_STREAM_HANDLER_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/calculator_context_manager.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/input_stream_handler.h" +#include "mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.pb.h" +#include "mediapipe/framework/timestamp.h" + +namespace mediapipe { + +// The input streams must have the same time unit but may have different time +// origins (also called epochs). The timestamp_base_tag_index option +// designates an input stream as the timestamp base. +// +// TimestampAlignInputStreamHandler operates in two phases: +// +// 1. Pre-initialization: In this phase, the input stream handler passes +// through input packets in the timestamp base input stream, but buffers the +// input packets in all other input streams. This phase ends when the input +// stream handler has an input packet in every input stream. 
It uses
+// the timestamps of these input packets to calculate the timestamp offset of
+// each input stream with respect to the timestamp base input stream. The
+// timestamp offsets are saved for use in the next phase.
+//
+// 2. Post-initialization: In this phase, the input stream handler behaves
+// like the DefaultInputStreamHandler, except that timestamp offsets are
+// applied to the packet timestamps.
+class TimestampAlignInputStreamHandler : public InputStreamHandler {
+ public:
+  TimestampAlignInputStreamHandler() = delete;
+  TimestampAlignInputStreamHandler(std::shared_ptr<tool::TagMap> tag_map,
+                                   CalculatorContextManager* cc_manager,
+                                   const mediapipe::MediaPipeOptions& options,
+                                   bool calculator_run_in_parallel);
+
+  void PrepareForRun(std::function<void()> headers_ready_callback,
+                     std::function<void()> notification_callback,
+                     std::function<void(CalculatorContext*)> schedule_callback,
+                     std::function<void(absl::Status)> error_callback) override;
+
+ protected:
+  // In TimestampAlignInputStreamHandler, a node is "ready" if:
+  // - before the timestamp offsets are initialized: we have received a packet
+  //   in the timestamp base input stream, or
+  // - after the timestamp offsets are initialized: the minimum bound (over
+  //   all empty streams) is greater than the smallest timestamp of any
+  //   stream, which means we have received all the packets that will be
+  //   available at the next timestamp, or
+  // - all streams are done (need to call Close() in this case).
+  // Note that all packet timestamps and timestamp bounds are aligned with the
+  // timestamp base.
+  NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override;
+
+  // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
+  void FillInputSet(Timestamp input_timestamp,
+                    InputStreamShardSet* input_set) override;
+
+ private:
+  CollectionItemId timestamp_base_stream_id_;
+
+  absl::Mutex mutex_;
+  bool offsets_initialized_ ABSL_GUARDED_BY(mutex_) = false;
+  std::vector<TimestampDiff> timestamp_offsets_;
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_FRAMEWORK_STREAM_HANDLER_TIMESTAMP_ALIGN_INPUT_STREAM_HANDLER_H_
diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc
index 6c18c9cac..7cbde28bf 100644
--- a/mediapipe/framework/subgraph.cc
+++ b/mediapipe/framework/subgraph.cc
@@ -64,13 +64,13 @@ GraphRegistry::GraphRegistry(
void GraphRegistry::Register(
    const std::string& type_name,
    std::function<std::unique_ptr<Subgraph>()> factory) {
-  local_factories_.Register(type_name, factory, __FILE__, __LINE__);
+  local_factories_.Register(type_name, factory);
}
// TODO: Remove this convenience function.
void GraphRegistry::Register(const std::string& type_name,
                             const CalculatorGraphConfig& config) {
-  Register(type_name, [config] {
+  local_factories_.Register(type_name, [config] {
    auto result = absl::make_unique<ProtoSubgraph>(config);
    return std::unique_ptr<Subgraph>(result.release());
  });
@@ -79,7 +79,7 @@ void GraphRegistry::Register(const std::string& type_name,
// TODO: Remove this convenience function.
void GraphRegistry::Register(const std::string& type_name,
                             const CalculatorGraphTemplate& templ) {
-  Register(type_name, [templ] {
+  local_factories_.Register(type_name, [templ] {
    auto result = absl::make_unique<TemplateSubgraph>(templ);
    return std::unique_ptr<Subgraph>(result.release());
  });
diff --git a/mediapipe/framework/test_calculators.cc b/mediapipe/framework/test_calculators.cc
index 6cb300855..1ed1e61b1 100644
--- a/mediapipe/framework/test_calculators.cc
+++ b/mediapipe/framework/test_calculators.cc
@@ -20,6 +20,7 @@
#include
#include "Eigen/Core"
+#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h"
@@ -203,7 +204,7 @@ class RangeCalculator : public CalculatorBase {
  // Initializes this object.
  void Initialize(CalculatorContext* cc) {
-    CHECK(!initialized_);
+    ABSL_CHECK(!initialized_);
    cc->Options();  // Ensure Options() can be called here.
    std::tie(n_, k_) =
@@ -380,10 +381,10 @@ class RandomMatrixCalculator : public CalculatorBase {
  absl::Status Open(CalculatorContext* cc) override {
    auto& options = cc->Options<RandomMatrixCalculatorOptions>();
-    CHECK_LT(0, options.timestamp_step());
-    CHECK_LT(0, options.rows());
-    CHECK_LT(0, options.cols());
-    CHECK_LT(options.start_timestamp(), options.limit_timestamp());
+    ABSL_CHECK_LT(0, options.timestamp_step());
+    ABSL_CHECK_LT(0, options.rows());
+    ABSL_CHECK_LT(0, options.cols());
+    ABSL_CHECK_LT(options.start_timestamp(), options.limit_timestamp());
    current_timestamp_ = Timestamp(options.start_timestamp());
    cc->Outputs().Index(0).SetNextTimestampBound(current_timestamp_);
@@ -447,13 +448,13 @@ class MeanAndCovarianceCalculator : public CalculatorBase {
  absl::Status Process(CalculatorContext* cc) override {
    const Eigen::MatrixXd sample =
        cc->Inputs().Index(0).Get<Matrix>().cast<double>();
-    CHECK_EQ(1, sample.cols());
+    ABSL_CHECK_EQ(1, sample.cols());
    if (num_samples_ == 0) {
      rows_ = sample.rows();
      sum_vector_ = Eigen::VectorXd::Zero(rows_);
      outer_product_sum_ = Eigen::MatrixXd::Zero(rows_, rows_);
    } else {
-      CHECK_EQ(sample.rows(), rows_);
+      ABSL_CHECK_EQ(sample.rows(), rows_);
    }
    sum_vector_ += sample;
    outer_product_sum_ += sample * sample.transpose();
diff --git a/mediapipe/framework/test_service.cc b/mediapipe/framework/test_service.cc
index 4bafaf28c..e7233ebf9 100644
--- a/mediapipe/framework/test_service.cc
+++ b/mediapipe/framework/test_service.cc
@@ -16,15 +16,6 @@
namespace mediapipe {
-const GraphService<TestServiceObject> kTestService(
-    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
-const GraphService<TestServiceObject> kAnotherService(
-    "another_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NoDefaultConstructor> kNoDefaultService(
-    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
-const GraphService<NeedsCreateMethod> kNeedsCreateService(
-    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
-
absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
  cc->Inputs().Index(0).Set<int>();
  cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
diff --git a/mediapipe/framework/test_service.h b/mediapipe/framework/test_service.h
index 2ff5a384a..42ebd8df8 100644
--- a/mediapipe/framework/test_service.h
+++ b/mediapipe/framework/test_service.h
@@ -22,14 +22,17 @@ namespace mediapipe {
using TestServiceObject = std::map<std::string, int>;
-extern const GraphService<TestServiceObject> kTestService;
-extern const GraphService<TestServiceObject> kAnotherService;
+inline constexpr GraphService<TestServiceObject> kTestService(
+    "test_service", GraphServiceBase::kDisallowDefaultInitialization);
+inline constexpr GraphService<TestServiceObject> kAnotherService(
    "another_service", GraphServiceBase::kAllowDefaultInitialization);
class NoDefaultConstructor {
 public:
  NoDefaultConstructor() = delete;
};
-extern const GraphService<NoDefaultConstructor> kNoDefaultService;
+inline constexpr GraphService<NoDefaultConstructor> kNoDefaultService(
+    "no_default_service", GraphServiceBase::kAllowDefaultInitialization);
class NeedsCreateMethod {
 public:
@@ -40,7 +43,8 @@ class NeedsCreateMethod {
 private:
  NeedsCreateMethod() = default;
};
-extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
+inline constexpr GraphService<NeedsCreateMethod> kNeedsCreateService(
+    "needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
// Use a service.
class TestServiceCalculator : public CalculatorBase {
diff --git a/mediapipe/framework/testdata/BUILD b/mediapipe/framework/testdata/BUILD
index 8720e39ee..93e416eaa 100644
--- a/mediapipe/framework/testdata/BUILD
+++ b/mediapipe/framework/testdata/BUILD
@@ -35,6 +35,12 @@ mediapipe_proto_library(
    ],
)
+mediapipe_proto_library(
+    name = "proto3_options_proto",
+    srcs = ["proto3_options.proto"],
+    visibility = ["//visibility:public"],
+)
+
mediapipe_proto_library(
    name = "zoo_mutator_proto",
    srcs = ["zoo_mutator.proto"],
diff --git a/mediapipe/framework/testdata/proto3_options.proto b/mediapipe/framework/testdata/proto3_options.proto
new file mode 100644
index 000000000..c76894819
--- /dev/null
+++ b/mediapipe/framework/testdata/proto3_options.proto
@@ -0,0 +1,25 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Forked from mediapipe/framework/tool/source.proto.
+// The forked proto must remain identical to the original proto and should be
+// ONLY used by mediapipe open source project.
+
+syntax = "proto3";
+
+package mediapipe;
+
+message Proto3Options {
+  double test_value = 1;
+}
diff --git a/mediapipe/framework/timestamp.cc b/mediapipe/framework/timestamp.cc
index 05b69747f..9183b3c81 100644
--- a/mediapipe/framework/timestamp.cc
+++ b/mediapipe/framework/timestamp.cc
@@ -16,6 +16,8 @@
#include
+#include "absl/log/absl_check.h"
+#include "absl/log/absl_log.h"
#include "absl/strings/str_cat.h"
namespace mediapipe {
@@ -26,7 +28,7 @@ constexpr double Timestamp::kTimestampUnitsPerSecond;
// - The safe int type will check for overflow/underflow and other errors.
// - The CHECK in the constructor will disallow special values.
TimestampDiff Timestamp::operator-(const Timestamp other) const {
-  CHECK(IsRangeValue() && other.IsRangeValue())
+  ABSL_CHECK(IsRangeValue() && other.IsRangeValue())
      << "This timestamp is " << DebugString() << " and other was "
      << other.DebugString();
  TimestampBaseType tmp_base = timestamp_ - other.timestamp_;
@@ -43,7 +45,7 @@ TimestampDiff TimestampDiff::operator-(const TimestampDiff other) const {
// Clamp the addition to the range [Timestamp::Min(), Timestamp::Max()].
Timestamp Timestamp::operator+(const TimestampDiff offset) const { - CHECK(IsRangeValue()) << "Timestamp is: " << DebugString(); + ABSL_CHECK(IsRangeValue()) << "Timestamp is: " << DebugString(); TimestampBaseType offset_base(offset.Value()); if (offset_base >= TimestampBaseType(0)) { if (timestamp_.value() >= Timestamp::Max().Value() - offset_base.value()) { @@ -112,7 +114,7 @@ std::string Timestamp::DebugString() const { } else if (*this == Timestamp::Done()) { return "Timestamp::Done()"; } else { - LOG(FATAL) << "Unknown special type."; + ABSL_LOG(FATAL) << "Unknown special type."; } } return absl::StrCat(timestamp_.value()); @@ -131,6 +133,13 @@ Timestamp Timestamp::NextAllowedInStream() const { return *this + 1; } +bool Timestamp::HasNextAllowedInStream() const { + if (*this >= Max() || *this == PreStream()) { + return false; + } + return true; +} + Timestamp Timestamp::PreviousAllowedInStream() const { if (*this <= Min() || *this == PostStream()) { // Indicates that no previous timestamps may occur. diff --git a/mediapipe/framework/timestamp.h b/mediapipe/framework/timestamp.h index b8c3a69a2..8949dcc80 100644 --- a/mediapipe/framework/timestamp.h +++ b/mediapipe/framework/timestamp.h @@ -47,6 +47,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/deps/safe_int.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/logging.h" @@ -57,7 +58,7 @@ namespace mediapipe { // have underflow/overflow etc. This type is used internally by Timestamp // and TimestampDiff. MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64, - mediapipe::intops::LogFatalOnError); + mediapipe::intops::LogFatalOnError) class TimestampDiff; @@ -186,6 +187,10 @@ class Timestamp { // CHECKs that this->IsAllowedInStream(). Timestamp NextAllowedInStream() const; + // Returns true if there's a next timestamp in the range [Min .. Max] after + // this one. + bool HasNextAllowedInStream() const; + // Returns the previous timestamp in the range [Min .. Max], or // Unstarted() if no Packets may precede one with this timestamp.
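
To see the two pieces of timestamp arithmetic above in isolation — the saturating addition in operator+ and the new HasNextAllowedInStream() — here is a minimal standalone sketch. It uses a plain int64_t in place of the real Timestamp class; kMin/kMax are illustrative stand-ins for Timestamp::Min()/Max(), and the PreStream() special case is omitted:

    #include <cstdint>
    #include <iostream>

    // Illustrative stand-ins for Timestamp::Min()/Max(); the real class
    // reserves values outside this range for special timestamps.
    constexpr int64_t kMin = INT64_MIN + 2;
    constexpr int64_t kMax = INT64_MAX - 2;

    // Mirrors the saturating operator+ above: the sum is clamped to
    // [kMin, kMax] instead of overflowing.
    int64_t SaturatingAdd(int64_t t, int64_t offset) {
      if (offset >= 0) {
        if (t >= kMax - offset) return kMax;
      } else {
        if (t <= kMin - offset) return kMin;
      }
      return t + offset;
    }

    // Mirrors HasNextAllowedInStream(): a later in-range timestamp exists
    // unless we are already at or beyond the maximum range value.
    bool HasNext(int64_t t) { return t < kMax; }

    int main() {
      std::cout << (SaturatingAdd(kMax - 1, 5) == kMax) << "\n";  // 1: clamped
      std::cout << HasNext(kMax) << HasNext(0) << "\n";           // 01
    }
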
Timestamp PreviousAllowedInStream() const; @@ -266,14 +271,14 @@ std::ostream& operator<<(std::ostream& os, TimestampDiff arg); inline Timestamp::Timestamp() : timestamp_(kint64min) {} inline Timestamp::Timestamp(int64 timestamp) : timestamp_(timestamp) { - CHECK(!IsSpecialValue()) + ABSL_CHECK(!IsSpecialValue()) << "Cannot directly create a Timestamp with a special value: " << CreateNoErrorChecking(timestamp); } inline Timestamp::Timestamp(TimestampBaseType timestamp) : timestamp_(timestamp) { - CHECK(!IsSpecialValue()) + ABSL_CHECK(!IsSpecialValue()) << "Cannot directly create a Timestamp with a special value: " << CreateNoErrorChecking(timestamp.value()); } diff --git a/mediapipe/framework/timestamp_test.cc b/mediapipe/framework/timestamp_test.cc index 5f5cc3428..3ba0b5c36 100644 --- a/mediapipe/framework/timestamp_test.cc +++ b/mediapipe/framework/timestamp_test.cc @@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) { Timestamp::PostStream().NextAllowedInStream()); } +TEST(TimestampTest, HasNextAllowedInStream) { + EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream()); + EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream()); + EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream()); + EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream()); + EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream()); + EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream()); + EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream()); + + EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream()); +} + TEST(TimestampTest, SpecialValueDifferences) { { // Lower range const std::vector timestamps = { diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index fbdcf8c9e..77e3ab16d 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -66,10 +66,12 @@ cc_library( deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:advanced_proto", + "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_log", ], ) @@ -140,6 +142,7 @@ cc_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:map_util", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -165,6 +168,7 @@ cc_test( ":executor_util", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/log:absl_check", ], ) @@ -188,9 +192,8 @@ cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_cc_proto", - "//mediapipe/framework/testdata:night_light_calculator_options_lib", + "//mediapipe/framework/testdata:proto3_options_cc_proto", ], ) @@ -281,6 +284,7 @@ cc_binary( "//mediapipe/framework/port:logging", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ 
-335,6 +339,7 @@ mediapipe_cc_test( cc_library( name = "packet_generator_wrapper_calculator", srcs = ["packet_generator_wrapper_calculator.cc"], + hdrs = ["packet_generator_wrapper_calculator.h"], visibility = ["//mediapipe/framework:__subpackages__"], deps = [ ":packet_generator_wrapper_calculator_cc_proto", @@ -342,6 +347,9 @@ cc_library( "//mediapipe/framework:calculator_registry", "//mediapipe/framework:output_side_packet", "//mediapipe/framework:packet_generator", + "//mediapipe/framework:packet_set", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", ], alwayslink = 1, ) @@ -360,6 +368,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -386,21 +395,23 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":name_util", + ":status_util", "//mediapipe/calculators/internal:callback_packet_calculator", "//mediapipe/calculators/internal:callback_packet_calculator_cc_proto", "//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_graph", "//mediapipe/framework:calculator_registry", - "//mediapipe/framework:input_stream", "//mediapipe/framework:packet", "//mediapipe/framework:packet_type", - "//mediapipe/framework/port:logging", + "//mediapipe/framework:timestamp", "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", ], alwayslink = 1, ) @@ -452,6 +463,7 @@ cc_library( deps = [ "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], ) @@ -501,10 +513,11 @@ cc_library( ":calculator_graph_template_cc_proto", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) @@ -525,11 +538,13 @@ cc_library( "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -625,6 +640,7 @@ cc_test( ":tag_map_helper", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:map_util", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) @@ -663,6 +679,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:threadpool", "//mediapipe/util:cpu_util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -782,15 +799,17 @@ cc_library( "//mediapipe/framework/formats:image_frame", 
"//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:file_helpers", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@stblib//:stb_image", "@stblib//:stb_image_write", ], @@ -919,6 +938,7 @@ cc_library( "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -939,11 +959,11 @@ cc_test( "//mediapipe/framework:subgraph", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/tool/executor_util.h b/mediapipe/framework/tool/executor_util.h index 3167cdd04..5fb25da74 100644 --- a/mediapipe/framework/tool/executor_util.h +++ b/mediapipe/framework/tool/executor_util.h @@ -22,6 +22,10 @@ namespace mediapipe { namespace tool { // Ensures the default executor's stack size is at least min_stack_size. +// +// Note that this will also initialize the default executor; any configuration +// changes, such as num_threads, should be done to the config before calling +// this. void EnsureMinimumDefaultExecutorStackSize(int32 min_stack_size, CalculatorGraphConfig* config); } // namespace tool diff --git a/mediapipe/framework/tool/ios.bzl b/mediapipe/framework/tool/ios.bzl new file mode 100644 index 000000000..a0fe0be55 --- /dev/null +++ b/mediapipe/framework/tool/ios.bzl @@ -0,0 +1,53 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MediaPipe Task Library Helper Rules for iOS""" + +MPP_TASK_MINIMUM_OS_VERSION = "12.0" + +# When the static framework is built with bazel, the all header files are moved +# to the "Headers" directory with no header path prefixes. This auxiliary rule +# is used for stripping the path prefix to the C/iOS API header files included by +# other C/iOS API header files. +# In case of C header files, includes start with a keyword of "#include'. +# Imports in iOS header files start with a keyword of '#import'. +def strip_api_include_path_prefix(name, hdr_labels, prefix = ""): + """Create modified header files with the import path stripped out. + + Args: + name: The name to be used as a prefix to the generated genrules. 
+ hdr_labels: List of header labels to strip out the include path. Each + label must end with a colon followed by the header file name. + prefix: Optional prefix path to prepend to the header inclusion path. + """ + for hdr_label in hdr_labels: + hdr_filename = hdr_label.split(":")[-1] + + # The last path component of iOS header files is sources/some_file.h. + # Hence it will contain a '/'. So the string can be split at '/' to get + # the header file name. + if "/" in hdr_filename: + hdr_filename = hdr_filename.split("/")[-1] + + hdr_basename = hdr_filename.split(".")[0] + native.genrule( + name = "{}_{}".format(name, hdr_basename), + srcs = [hdr_label], + outs = [hdr_filename], + cmd = """ + sed 's|#\\([a-z]*\\) ".*/\\([^/]\\{{1,\\}}\\.h\\)"|#\\1 "{}\\2"|'\ + "$(location {})"\ + > "$@" + """.format(prefix, hdr_label), + ) diff --git a/mediapipe/framework/tool/mediapipe_proto.bzl b/mediapipe/framework/tool/mediapipe_proto.bzl index 527774ff3..142560ce5 100644 --- a/mediapipe/framework/tool/mediapipe_proto.bzl +++ b/mediapipe/framework/tool/mediapipe_proto.bzl @@ -50,10 +50,12 @@ def mediapipe_proto_library_impl( def_cc_proto = True, def_py_proto = True, def_java_lite_proto = True, + def_kt_lite_proto = True, def_objc_proto = True, def_java_proto = True, def_jspb_proto = True, def_go_proto = True, + def_dart_proto = True, def_options_lib = True): """Defines the proto_library targets needed for all mediapipe platforms. @@ -71,10 +73,12 @@ def mediapipe_proto_library_impl( def_cc_proto: define the cc_proto_library target def_py_proto: define the py_proto_library target def_java_lite_proto: define the java_lite_proto_library target + def_kt_lite_proto: define the kt_lite_proto_library target def_objc_proto: define the objc_proto_library target def_java_proto: define the java_proto_library target def_jspb_proto: define the jspb_proto_library target def_go_proto: define the go_proto_library target + def_dart_proto: define the dart_proto_library target def_options_lib: define the mediapipe_options_library target """ @@ -253,11 +257,13 @@ def mediapipe_proto_library( def_cc_proto = True, def_py_proto = True, def_java_lite_proto = True, + def_kt_lite_proto = True, def_portable_proto = True, # @unused def_objc_proto = True, def_java_proto = True, def_jspb_proto = True, def_go_proto = True, + def_dart_proto = True, def_options_lib = True, def_rewrite = True, portable_deps = None): # @unused @@ -278,11 +284,13 @@ def mediapipe_proto_library( def_cc_proto: define the cc_proto_library target def_py_proto: define the py_proto_library target def_java_lite_proto: define the java_lite_proto_library target + def_kt_lite_proto: define the kt_lite_proto_library target def_portable_proto: ignored since portable protos are gone def_objc_proto: define the objc_proto_library target def_java_proto: define the java_proto_library target def_jspb_proto: define the jspb_proto_library target def_go_proto: define the go_proto_library target + def_dart_proto: define the dart_proto_library target def_options_lib: define the mediapipe_options_library target def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe" """ @@ -300,10 +308,12 @@ def mediapipe_proto_library( def_cc_proto = def_cc_proto, def_py_proto = def_py_proto, def_java_lite_proto = def_java_lite_proto, + def_kt_lite_proto = def_kt_lite_proto, def_objc_proto = def_objc_proto, def_java_proto = def_java_proto, def_jspb_proto = def_jspb_proto, def_go_proto = def_go_proto, + def_dart_proto = def_dart_proto, def_options_lib = def_options_lib, )
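
The sed expression in the genrule above rewrites #include/#import directives so stripped framework headers are referenced by bare filename plus an optional prefix. A rough C++ analogue of that rewrite, for illustration only — the real build step runs sed, the regex here simplifies the escaped one above, and the function name is made up:

    #include <iostream>
    #include <regex>
    #include <string>

    // Keep the directive keyword, drop the directory part of the quoted
    // header path, and prepend an optional prefix.
    std::string StripIncludePath(const std::string& line,
                                 const std::string& prefix) {
      static const std::regex re(R"rx(#([a-z]+) ".*/([^/]+\.h)")rx");
      return std::regex_replace(line, re, "#$1 \"" + prefix + "$2\"");
    }

    int main() {
      std::cout << StripIncludePath(
                       "#import \"mediapipe/tasks/ios/core/sources/Foo.h\"", "")
                << "\n";  // #import "Foo.h"
    }
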
@@ -329,10 +339,12 @@ def mediapipe_proto_library( def_cc_proto = def_cc_proto, def_py_proto = def_py_proto, def_java_lite_proto = def_java_lite_proto, + def_kt_lite_proto = def_kt_lite_proto, def_objc_proto = def_objc_proto, def_java_proto = def_java_proto, def_jspb_proto = def_jspb_proto, def_go_proto = def_go_proto, + def_dart_proto = def_dart_proto, # A clone of mediapipe_options_library() will redefine some classes. def_options_lib = False, ) diff --git a/mediapipe/framework/tool/message_type_util.cc b/mediapipe/framework/tool/message_type_util.cc index fe505ee0f..3bc5ea8d3 100644 --- a/mediapipe/framework/tool/message_type_util.cc +++ b/mediapipe/framework/tool/message_type_util.cc @@ -4,6 +4,7 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" @@ -118,14 +119,14 @@ class DescriptorReader { static FileDescriptorSet ReadFileDescriptorSet(const std::string& path) { std::string contents; - CHECK_OK(file::GetContents(path, &contents)); + ABSL_CHECK_OK(file::GetContents(path, &contents)); proto_ns::FileDescriptorSet result; result.ParseFromString(contents); return result; } static void WriteFile(const std::string& path, const std::string& contents) { - CHECK_OK(file::SetContents(path, contents)); + ABSL_CHECK_OK(file::SetContents(path, contents)); } static void WriteMessageTypeName(const std::string& path, diff --git a/mediapipe/framework/tool/name_util.cc b/mediapipe/framework/tool/name_util.cc index 4784441d7..d42cbb622 100644 --- a/mediapipe/framework/tool/name_util.cc +++ b/mediapipe/framework/tool/name_util.cc @@ -60,7 +60,7 @@ std::string GetUnusedSidePacketName( } std::string candidate = input_side_packet_name_base; int iter = 2; - while (mediapipe::ContainsKey(input_side_packets, candidate)) { + while (input_side_packets.contains(candidate)) { candidate = absl::StrCat(input_side_packet_name_base, "_", absl::StrFormat("%02d", iter)); ++iter; diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc index 308932d4f..248028c25 100644 --- a/mediapipe/framework/tool/options_field_util.cc +++ b/mediapipe/framework/tool/options_field_util.cc @@ -27,10 +27,6 @@ namespace options_field_util { using ::mediapipe::proto_ns::internal::WireFormatLite; using FieldType = WireFormatLite::FieldType; -using ::mediapipe::proto_ns::io::ArrayInputStream; -using ::mediapipe::proto_ns::io::CodedInputStream; -using ::mediapipe::proto_ns::io::CodedOutputStream; -using ::mediapipe::proto_ns::io::StringOutputStream; // Utility functions for OptionsFieldUtil. 
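
The GetUnusedSidePacketName hunk above only swaps mediapipe::ContainsKey for the container's own contains(); the surrounding uniquifying loop is the interesting part. A standalone sketch of the same scheme, with std::set and the hypothetical GetUnusedName standing in for the real side-packet map and function:

    #include <cstdio>
    #include <iostream>
    #include <set>
    #include <string>

    // Append "_02", "_03", ... until the candidate no longer collides with
    // an already-taken name.
    std::string GetUnusedName(const std::string& base,
                              const std::set<std::string>& taken) {
      std::string candidate = base;
      int iter = 2;
      while (taken.count(candidate) > 0) {
        char suffix[8];
        std::snprintf(suffix, sizeof(suffix), "_%02d", iter++);
        candidate = base + suffix;
      }
      return candidate;
    }

    int main() {
      std::cout << GetUnusedName("out", {"out", "out_02"}) << "\n";  // out_03
    }
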
namespace { diff --git a/mediapipe/framework/tool/options_map.h b/mediapipe/framework/tool/options_map.h index 2b69f4fb6..4950669c6 100644 --- a/mediapipe/framework/tool/options_map.h +++ b/mediapipe/framework/tool/options_map.h @@ -128,7 +128,8 @@ class OptionsMap { return *options_.Get(); } T* result = options_.Get(); - if (node_config_->has_options()) { + if (node_config_->has_options() && + HasExtension(node_config_->options())) { GetExtension(node_config_->options(), result); } else { GetNodeOptions(*node_config_, result); @@ -141,8 +142,9 @@ class OptionsMap { if (options_.Has()) { return true; } - if (node_config_->has_options()) { - return HasExtension(node_config_->options()); + if (node_config_->has_options() && + HasExtension(node_config_->options())) { + return true; } #if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY) // protobuf::Any is unavailable with third_party/protobuf:protobuf-lite. @@ -170,7 +172,8 @@ class MutableOptionsMap : public OptionsMap { template void Set(const T& value) const { *options_.Get() = value; - if (node_config_->has_options()) { + if (node_config_->has_options() && + HasExtension(node_config_->options())) { *GetExtension(*node_config_->mutable_options()) = value; } else { SetNodeOptions(*node_config_, value); @@ -182,7 +185,8 @@ class MutableOptionsMap : public OptionsMap { if (options_.Has()) { return options_.Get(); } - if (node_config_->has_options()) { + if (node_config_->has_options() && + HasExtension(node_config_->options())) { return GetExtension(*node_config_->mutable_options()); } T* result = options_.Get(); diff --git a/mediapipe/framework/tool/options_map_test.cc b/mediapipe/framework/tool/options_map_test.cc index 529fd5770..8efd1cb94 100644 --- a/mediapipe/framework/tool/options_map_test.cc +++ b/mediapipe/framework/tool/options_map_test.cc @@ -17,14 +17,11 @@ #include -#include - #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" -#include "mediapipe/framework/port/status.h" -#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/testdata/night_light_calculator.pb.h" +#include "mediapipe/framework/testdata/proto3_options.pb.h" namespace mediapipe { namespace tool { @@ -40,9 +37,10 @@ TEST(OptionsMapTest, QueryNotFound) { OptionsMap options; options.Initialize(node); EXPECT_FALSE(options.Has()); + EXPECT_FALSE(options.Has()); } -TEST(OptionsMapTest, QueryFound) { +TEST(OptionsMapTest, Proto2QueryFound) { CalculatorGraphConfig::Node node = ParseTextProtoOrDie(R"pb( calculator: "NightLightCalculator" @@ -64,7 +62,7 @@ TEST(OptionsMapTest, QueryFound) { 123); } -TEST(MutableOptionsMapTest, InsertAndQueryFound) { +TEST(MutableOptionsMapTest, InsertProto2AndQueryFound) { CalculatorGraphConfig::Node node = ParseTextProtoOrDie(R"pb( calculator: "NightLightCalculator" @@ -83,6 +81,83 @@ TEST(MutableOptionsMapTest, InsertAndQueryFound) { 123); } +TEST(OptionsMapTest, Proto3QueryFound) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + node_options { + [type.googleapis.com/mediapipe.Proto3Options] { test_value: 123 } + } + )pb"); + OptionsMap options; + options.Initialize(node); + EXPECT_TRUE(options.Has()); + EXPECT_EQ(options.Get().test_value(), 123); +} + +TEST(MutableOptionsMapTest, InsertProto3AndQueryFound) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + 
calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + )pb"); + MutableOptionsMap options; + options.Initialize(node); + EXPECT_FALSE(options.Has()); + mediapipe::Proto3Options proto3_options; + proto3_options.set_test_value(123); + options.Set(proto3_options); + EXPECT_TRUE(options.Has()); + EXPECT_EQ(options.Get().test_value(), 123); +} + +TEST(OptionsMapTest, BothProto2AndProto3QueriesFound) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + options { + [mediapipe.NightLightCalculatorOptions.ext] { jitter: 321 } + } + node_options { + [type.googleapis.com/mediapipe.Proto3Options] { test_value: 123 } + } + )pb"); + OptionsMap options; + options.Initialize(node); + EXPECT_TRUE(options.Has()); + EXPECT_EQ(options.Get().test_value(), 123); + EXPECT_TRUE(options.Has()); + EXPECT_EQ(options.Get().jitter(), + 321); +} + +TEST(OptionsMapTest, PrefersOptionsOverNodeOptions) { + CalculatorGraphConfig::Node node = + ParseTextProtoOrDie(R"pb( + calculator: "NightLightCalculator" + input_side_packet: "input_value" + output_stream: "values" + options { + [mediapipe.NightLightCalculatorOptions.ext] { jitter: 111 } + } + node_options { + [type.googleapis.com/mediapipe.NightLightCalculatorOptions] { + jitter: 222 + } + } + )pb"); + OptionsMap options; + options.Initialize(node); + EXPECT_TRUE(options.Has()); + EXPECT_EQ(options.Get().jitter(), + 111); +} + } // namespace } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/packet_generator_wrapper_calculator.cc b/mediapipe/framework/tool/packet_generator_wrapper_calculator.cc index 831918dfa..07eae6f26 100644 --- a/mediapipe/framework/tool/packet_generator_wrapper_calculator.cc +++ b/mediapipe/framework/tool/packet_generator_wrapper_calculator.cc @@ -1,52 +1,55 @@ +#include "mediapipe/framework/tool/packet_generator_wrapper_calculator.h" + +#include "absl/status/status.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_registry.h" #include "mediapipe/framework/output_side_packet.h" #include "mediapipe/framework/packet_generator.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/tool/packet_generator_wrapper_calculator.pb.h" namespace mediapipe { -class PacketGeneratorWrapperCalculator : public CalculatorBase { - public: - static absl::Status GetContract(CalculatorContract* cc) { - const auto& options = - cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>(); - ASSIGN_OR_RETURN(auto static_access, - mediapipe::internal::StaticAccessToGeneratorRegistry:: - CreateByNameInNamespace(options.package(), - options.packet_generator())); - MP_RETURN_IF_ERROR(static_access->FillExpectations( - options.options(), &cc->InputSidePackets(), - &cc->OutputSidePackets())) - .SetPrepend() - << options.packet_generator() << "::FillExpectations() failed: "; - return absl::OkStatus(); - } +absl::Status PacketGeneratorWrapperCalculator::GetContract( + CalculatorContract* cc) { + const auto& options = + cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>(); + ASSIGN_OR_RETURN(auto static_access, + mediapipe::internal::StaticAccessToGeneratorRegistry:: + CreateByNameInNamespace(options.package(), + options.packet_generator())); + MP_RETURN_IF_ERROR(static_access->FillExpectations(options.options(), + &cc->InputSidePackets(), + 
&cc->OutputSidePackets())) + .SetPrepend() + << options.packet_generator() << "::FillExpectations() failed: "; + return absl::OkStatus(); +} - absl::Status Open(CalculatorContext* cc) override { - const auto& options = - cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>(); - ASSIGN_OR_RETURN(auto static_access, - mediapipe::internal::StaticAccessToGeneratorRegistry:: - CreateByNameInNamespace(options.package(), - options.packet_generator())); - mediapipe::PacketSet output_packets(cc->OutputSidePackets().TagMap()); - MP_RETURN_IF_ERROR(static_access->Generate(options.options(), - cc->InputSidePackets(), - &output_packets)) - .SetPrepend() - << options.packet_generator() << "::Generate() failed: "; - for (auto id = output_packets.BeginId(); id < output_packets.EndId(); - ++id) { - cc->OutputSidePackets().Get(id).Set(output_packets.Get(id)); - } - return absl::OkStatus(); +absl::Status PacketGeneratorWrapperCalculator::Open(CalculatorContext* cc) { + const auto& options = + cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>(); + ASSIGN_OR_RETURN(auto static_access, + mediapipe::internal::StaticAccessToGeneratorRegistry:: + CreateByNameInNamespace(options.package(), + options.packet_generator())); + mediapipe::PacketSet output_packets(cc->OutputSidePackets().TagMap()); + MP_RETURN_IF_ERROR(static_access->Generate(options.options(), + cc->InputSidePackets(), + &output_packets)) + .SetPrepend() + << options.packet_generator() << "::Generate() failed: "; + for (auto id = output_packets.BeginId(); id < output_packets.EndId(); ++id) { + cc->OutputSidePackets().Get(id).Set(output_packets.Get(id)); } + return absl::OkStatus(); +} + +absl::Status PacketGeneratorWrapperCalculator::Process(CalculatorContext* cc) { + return absl::OkStatus(); +} - absl::Status Process(CalculatorContext* cc) override { - return absl::OkStatus(); - } -}; REGISTER_CALCULATOR(PacketGeneratorWrapperCalculator); } // namespace mediapipe diff --git a/mediapipe/framework/tool/packet_generator_wrapper_calculator.h b/mediapipe/framework/tool/packet_generator_wrapper_calculator.h new file mode 100644 index 000000000..012281ca0 --- /dev/null +++ b/mediapipe/framework/tool/packet_generator_wrapper_calculator.h @@ -0,0 +1,32 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
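
Returning to the OptionsMap changes above: proto2 options travel in the options extension block, while proto3 messages cannot extend CalculatorOptions and therefore ride in node_options as a google.protobuf.Any keyed by type URL; the getters now verify the requested extension is actually present before preferring options over node_options. The two encodings side by side, as textproto constants mirroring the tests above:

    #include <iostream>

    // Proto2: extension inside the `options` block, keyed by extension symbol.
    const char kProto2Node[] = R"(
      calculator: "NightLightCalculator"
      options {
        [mediapipe.NightLightCalculatorOptions.ext] { jitter: 321 }
      }
    )";

    // Proto3: `node_options` carries a google.protobuf.Any keyed by type URL.
    const char kProto3Node[] = R"(
      calculator: "NightLightCalculator"
      node_options {
        [type.googleapis.com/mediapipe.Proto3Options] { test_value: 123 }
      }
    )";

    int main() { std::cout << kProto2Node << "\n" << kProto3Node << "\n"; }
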
+ +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_PACKET_GENERATOR_WRAPPER_CALCULATOR_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_PACKET_GENERATOR_WRAPPER_CALCULATOR_H_ + +#include "absl/status/status.h" +#include "mediapipe/framework/calculator_base.h" + +namespace mediapipe { + +class PacketGeneratorWrapperCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_PACKET_GENERATOR_WRAPPER_CALCULATOR_H_ diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index 745f4a13b..285aa2205 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_check.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -411,7 +412,7 @@ static absl::Status DeserializeValue(const FieldValue& bytes, } case W::TYPE_GROUP: case W::TYPE_MESSAGE: - CHECK(false) << "DeserializeValue cannot deserialize a Message."; + ABSL_CHECK(false) << "DeserializeValue cannot deserialize a Message."; case W::TYPE_UINT32: return ReadPrimitive(&input, result); case W::TYPE_ENUM: diff --git a/mediapipe/framework/tool/sink.cc b/mediapipe/framework/tool/sink.cc index f8abf4925..b97d27ea7 100644 --- a/mediapipe/framework/tool/sink.cc +++ b/mediapipe/framework/tool/sink.cc @@ -18,60 +18,65 @@ #include "mediapipe/framework/tool/sink.h" +#include + +#include +#include #include +#include #include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_graph.h" #include "mediapipe/framework/calculator_registry.h" -#include "mediapipe/framework/input_stream.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/tool/name_util.h" +#include "mediapipe/framework/tool/status_util.h" namespace mediapipe { namespace tool { -namespace { -// Produces an output packet with the PostStream timestamp containing the -// input side packet. 
-class MediaPipeInternalSidePacketToPacketStreamCalculator - : public CalculatorBase { - public: - static absl::Status GetContract(CalculatorContract* cc) { - cc->InputSidePackets().Index(0).SetAny(); - cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); - return absl::OkStatus(); - } - absl::Status Open(CalculatorContext* cc) final { - cc->Outputs().Index(0).AddPacket( - cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); - cc->Outputs().Index(0).Close(); - return absl::OkStatus(); - } +absl::Status MediaPipeInternalSidePacketToPacketStreamCalculator::GetContract( + CalculatorContract* cc) { + cc->InputSidePackets().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); + return absl::OkStatus(); +} + +absl::Status MediaPipeInternalSidePacketToPacketStreamCalculator::Open( + CalculatorContext* cc) { + cc->Outputs().Index(0).AddPacket( + cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); + cc->Outputs().Index(0).Close(); + return absl::OkStatus(); +} + +absl::Status MediaPipeInternalSidePacketToPacketStreamCalculator::Process( + CalculatorContext* cc) { + // The framework treats this calculator as a source calculator. + return mediapipe::tool::StatusStop(); +} - absl::Status Process(CalculatorContext* cc) final { - // The framework treats this calculator as a source calculator. - return mediapipe::tool::StatusStop(); - } -}; REGISTER_CALCULATOR(MediaPipeInternalSidePacketToPacketStreamCalculator); -} // namespace void AddVectorSink(const std::string& stream_name, // CalculatorGraphConfig* config, // std::vector* dumped_data) { - CHECK(config); - CHECK(dumped_data); + ABSL_CHECK(config); + ABSL_CHECK(dumped_data); std::string input_side_packet_name; tool::AddCallbackCalculator(stream_name, config, &input_side_packet_name, @@ -90,15 +95,15 @@ void AddVectorSink(const std::string& stream_name, // // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. char address[19]; int written = snprintf(address, sizeof(address), "%p", dumped_data); - CHECK(written > 0 && written < sizeof(address)); + ABSL_CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); } void AddPostStreamPacketSink(const std::string& stream_name, CalculatorGraphConfig* config, Packet* post_stream_packet) { - CHECK(config); - CHECK(post_stream_packet); + ABSL_CHECK(config); + ABSL_CHECK(post_stream_packet); std::string input_side_packet_name; tool::AddCallbackCalculator(stream_name, config, &input_side_packet_name, @@ -116,14 +121,14 @@ void AddPostStreamPacketSink(const std::string& stream_name, // Up to 64-bit pointer in hex (16 characters) and an optional "0x" prepended. 
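
AddVectorSink above smuggles the address of the caller's output vector through a string side packet: for a 64-bit pointer, %p prints at most 16 hex digits plus an optional "0x", hence the 19-byte buffer including the terminating NUL. A standalone sketch of that round trip in plain C++ (sscanf with %p recovers a pointer printed with %p within the same execution):

    #include <cstdio>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int> sink;

      // Encode: up to 16 hex digits, an optional "0x", and a NUL => 19 bytes.
      char address[19];
      int written = std::snprintf(address, sizeof(address), "%p",
                                  static_cast<void*>(&sink));
      if (written <= 0 || written >= static_cast<int>(sizeof(address))) return 1;

      // Decode: %p round-trips a pointer printed with %p in the same run.
      void* decoded = nullptr;
      std::sscanf(address, "%p", &decoded);
      std::cout << (decoded == &sink) << "\n";  // 1
    }
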
char address[19]; int written = snprintf(address, sizeof(address), "%p", post_stream_packet); - CHECK(written > 0 && written < sizeof(address)); + ABSL_CHECK(written > 0 && written < sizeof(address)); options->set_pointer(address); } void AddSidePacketSink(const std::string& side_packet_name, CalculatorGraphConfig* config, Packet* dumped_packet) { - CHECK(config); - CHECK(dumped_packet); + ABSL_CHECK(config); + ABSL_CHECK(dumped_packet); CalculatorGraphConfig::Node* conversion_node = config->add_node(); const std::string node_name = GetUnusedNodeName( @@ -145,8 +150,8 @@ void AddCallbackCalculator(const std::string& stream_name, CalculatorGraphConfig* config, std::string* callback_side_packet_name, bool use_std_function) { - CHECK(config); - CHECK(callback_side_packet_name); + ABSL_CHECK(config); + ABSL_CHECK(callback_side_packet_name); CalculatorGraphConfig::Node* sink_node = config->add_node(); sink_node->set_name(GetUnusedNodeName( *config, @@ -162,7 +167,7 @@ void AddCallbackCalculator(const std::string& stream_name, sink_node->add_input_side_packet( absl::StrCat("CALLBACK:", input_side_packet_name)); } else { - LOG(FATAL) << "AddCallbackCalculator must use std::function"; + ABSL_LOG(FATAL) << "AddCallbackCalculator must use std::function"; } } @@ -182,8 +187,8 @@ void AddMultiStreamCallback( std::function&)> callback, CalculatorGraphConfig* config, std::map* side_packets, bool observe_timestamp_bounds) { - CHECK(config); - CHECK(side_packets); + ABSL_CHECK(config); + ABSL_CHECK(side_packets); CalculatorGraphConfig::Node* sink_node = config->add_node(); const std::string name = GetUnusedNodeName( *config, absl::StrCat("multi_callback_", absl::StrJoin(streams, "_"))); @@ -217,8 +222,8 @@ void AddCallbackWithHeaderCalculator(const std::string& stream_name, CalculatorGraphConfig* config, std::string* callback_side_packet_name, bool use_std_function) { - CHECK(config); - CHECK(callback_side_packet_name); + ABSL_CHECK(config); + ABSL_CHECK(callback_side_packet_name); CalculatorGraphConfig::Node* sink_node = config->add_node(); sink_node->set_name(GetUnusedNodeName( *config, @@ -237,7 +242,7 @@ void AddCallbackWithHeaderCalculator(const std::string& stream_name, sink_node->add_input_side_packet( absl::StrCat("CALLBACK:", input_side_packet_name)); } else { - LOG(FATAL) << "AddCallbackWithHeaderCalculator must use std::function"; + ABSL_LOG(FATAL) << "AddCallbackWithHeaderCalculator must use std::function"; } } @@ -286,7 +291,7 @@ absl::Status CallbackCalculator::Open(CalculatorContext* cc) { .Tag("VECTOR_CALLBACK") .Get&)>>(); } else { - LOG(FATAL) << "InputSidePackets must use tags."; + ABSL_LOG(FATAL) << "InputSidePackets must use tags."; } if (callback_ == nullptr && vector_callback_ == nullptr) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) @@ -326,7 +331,7 @@ absl::Status CallbackWithHeaderCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag("HEADER").SetAny(); if (cc->InputSidePackets().UsesTags()) { - CHECK(cc->InputSidePackets().HasTag("CALLBACK")); + ABSL_CHECK(cc->InputSidePackets().HasTag("CALLBACK")); cc->InputSidePackets() .Tag("CALLBACK") .Set>(); @@ -343,7 +348,7 @@ absl::Status CallbackWithHeaderCalculator::Open(CalculatorContext* cc) { .Tag("CALLBACK") .Get>(); } else { - LOG(FATAL) << "InputSidePackets must use tags."; + ABSL_LOG(FATAL) << "InputSidePackets must use tags."; } if (callback_ == nullptr) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) diff --git a/mediapipe/framework/tool/sink.h b/mediapipe/framework/tool/sink.h 
index f786e60a7..4d00b6e6d 100644 --- a/mediapipe/framework/tool/sink.h +++ b/mediapipe/framework/tool/sink.h @@ -28,10 +28,12 @@ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_SINK_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_SINK_H_ +#include #include #include #include "absl/base/macros.h" +#include "absl/status/status.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/status.h" @@ -66,9 +68,9 @@ namespace tool { // // Call tool::AddVectorSink() more times if you wish. Note that each stream // // needs to get its own packet vector. // CalculatorGraph graph; -// CHECK_OK(graph.Initialize(config)); +// ABSL_CHECK_OK(graph.Initialize(config)); // // Set other input side packets. -// CHECK_OK(graph.Run()); +// ABSL_CHECK_OK(graph.Run()); // for (const Packet& packet : packet_dump) { // // Do something. // } @@ -158,7 +160,7 @@ void AddCallbackWithHeaderCalculator(const std::string& stream_name, // tool::AddCallbackCalculator("the_output_stream", &config, // &input_side_packet_name, true); // CalculatorGraph graph(config); -// CHECK_OK(graph.Run( +// ABSL_CHECK_OK(graph.Run( // {{input_side_packet_name, // MakePacket>( // std::bind(&MyClass::MyFunction, this, std::placeholders::_1))}} @@ -205,6 +207,16 @@ class CallbackWithHeaderCalculator : public CalculatorBase { Packet header_packet_; }; +// Produces an output packet with the PostStream timestamp containing the +// input side packet. +class MediaPipeInternalSidePacketToPacketStreamCalculator + : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; +}; + } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/status_util.cc b/mediapipe/framework/tool/status_util.cc index 401a1b63c..19f3fc6b7 100644 --- a/mediapipe/framework/tool/status_util.cc +++ b/mediapipe/framework/tool/status_util.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -59,8 +60,8 @@ absl::Status CombinedStatus(absl::string_view general_comment, } } } - if (error_code == StatusCode::kOk) return OkStatus(); - Status combined; + if (error_code == absl::StatusCode::kOk) return absl::OkStatus(); + absl::Status combined; combined = absl::Status( error_code, absl::StrCat(general_comment, "\n", absl::StrJoin(errors, "\n"))); diff --git a/mediapipe/framework/tool/switch_container.cc b/mediapipe/framework/tool/switch_container.cc index daa129928..29307c4f9 100644 --- a/mediapipe/framework/tool/switch_container.cc +++ b/mediapipe/framework/tool/switch_container.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" @@ -148,7 +149,7 @@ void ClearContainerOptions(CalculatorGraphConfig::Node* dest) { // Returns an unused name similar to a specified name. std::string UniqueName(std::string name, std::set* names) { - CHECK(names != nullptr); + ABSL_CHECK(names != nullptr); std::string result = name; int suffix = 2; while (names->count(result) > 0) { @@ -161,7 +162,7 @@ std::string UniqueName(std::string name, std::set* names) { // Parses tag, index, and name from a list of stream identifiers. 
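
ParseTags below walks stream identifiers of the form TAG:index:name (note how EraseTag appends a dummy ":u" name before parsing). A simplified standalone parser for that syntax — the real code delegates to ParseTagIndexFromStream and validates each part, which this hypothetical ParseStream does not:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    struct TagIndexName {
      std::string tag;   // empty for untagged streams
      int index = 0;
      std::string name;
    };

    TagIndexName ParseStream(const std::string& stream) {
      std::vector<std::string> parts;
      std::stringstream ss(stream);
      for (std::string item; std::getline(ss, item, ':');) parts.push_back(item);
      TagIndexName out;
      if (parts.size() == 1) {         // "name"
        out.name = parts[0];
      } else if (parts.size() == 2) {  // "TAG:name"
        out.tag = parts[0];
        out.name = parts[1];
      } else {                         // "TAG:index:name"
        out.tag = parts[0];
        out.index = std::stoi(parts[1]);
        out.name = parts[2];
      }
      return out;
    }

    int main() {
      TagIndexName t = ParseStream("VIDEO:2:input_frames");
      std::cout << t.tag << " " << t.index << " " << t.name << "\n";
    }
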
void ParseTags(const proto_ns::RepeatedPtrField& streams, std::map* result) { - CHECK(result != nullptr); + ABSL_CHECK(result != nullptr); std::set used_names; int used_index = -1; for (const std::string& stream : streams) { @@ -177,14 +178,14 @@ void ParseTags(const proto_ns::RepeatedPtrField& streams, // Removes the entry for a tag and index from a map. void EraseTag(const std::string& stream, std::map* streams) { - CHECK(streams != nullptr); + ABSL_CHECK(streams != nullptr); streams->erase(ParseTagIndexFromStream(absl::StrCat(stream, ":u"))); } // Removes the entry for a tag and index from a list. void EraseTag(const std::string& stream, proto_ns::RepeatedPtrField* streams) { - CHECK(streams != nullptr); + ABSL_CHECK(streams != nullptr); TagIndex stream_tag = ParseTagIndexFromStream(absl::StrCat(stream, ":u")); for (int i = streams->size() - 1; i >= 0; --i) { TagIndex tag = ParseTagIndexFromStream(streams->at(i)); @@ -197,7 +198,7 @@ void EraseTag(const std::string& stream, // Returns the stream names for the container node. void GetContainerNodeStreams(const CalculatorGraphConfig::Node& node, CalculatorGraphConfig::Node* result) { - CHECK(result != nullptr); + ABSL_CHECK(result != nullptr); *result->mutable_input_stream() = node.input_stream(); *result->mutable_output_stream() = node.output_stream(); *result->mutable_input_side_packet() = node.input_side_packet(); diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index 08cc4ab5a..5ffd26e03 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -17,13 +17,13 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/ret_check.h" @@ -385,7 +385,7 @@ TEST(SwitchContainerTest, RunsWithInputStreamHandler) { CalculatorGraphConfig supergraph = SubnodeContainerExample(R"pb(synchronize_io: true)pb"); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); - LOG(INFO) << supergraph.DebugString(); + ABSL_LOG(INFO) << supergraph.DebugString(); RunTestContainer(supergraph, true); } diff --git a/mediapipe/framework/tool/tag_map_test.cc b/mediapipe/framework/tool/tag_map_test.cc index 20a9be966..68ee94ae7 100644 --- a/mediapipe/framework/tool/tag_map_test.cc +++ b/mediapipe/framework/tool/tag_map_test.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/tool/tag_map.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -101,7 +102,7 @@ void TestSuccessTagMap(const std::vector& tag_index_names, EXPECT_EQ(tags.size(), tag_map->Mapping().size()) << "Parameters: in " << tag_map->DebugString(); for (int i = 0; i < tags.size(); ++i) { - EXPECT_TRUE(mediapipe::ContainsKey(tag_map->Mapping(), tags[i])) + EXPECT_TRUE(tag_map->Mapping().contains(tags[i])) << "Parameters: Trying to find \"" << tags[i] << "\" in\n" << tag_map->DebugString(); } @@ -329,8 +330,8 @@ void TestDebugString( tool::TagMap& tag_map = *statusor_tag_map.value(); std::string debug_string = tag_map.DebugString(); std::string 
short_string = tag_map.ShortDebugString(); - LOG(INFO) << "ShortDebugString:\n" << short_string << "\n"; - LOG(INFO) << "DebugString:\n" << debug_string << "\n\n"; + ABSL_LOG(INFO) << "ShortDebugString:\n" << short_string << "\n"; + ABSL_LOG(INFO) << "DebugString:\n" << debug_string << "\n\n"; std::vector actual_entries; for (const auto& field : tag_map.CanonicalEntries()) { diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index a91ea5adc..8f9ef6866 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -15,20 +15,16 @@ #include "mediapipe/framework/tool/template_expander.h" #include -#include #include #include -#include #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" #include "mediapipe/framework/calculator.pb.h" -#include "mediapipe/framework/port/canonical_errors.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/numbers.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -183,8 +179,8 @@ FieldType GetFieldType(const TemplateExpression& rule) { int FieldCount(const FieldValue& base, ProtoPath field_path, FieldType field_type) { int result = 0; - CHECK( - ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result).ok()); + ABSL_CHECK_OK( + ProtoUtilLite::GetFieldCount(base, field_path, field_type, &result)); return result; } @@ -647,7 +643,7 @@ class TemplateExpanderImpl { for (int i = 0; i < args.size(); ++i) { if (args[i].has_dict()) { FieldValue dict_bytes; - CHECK(args[i].dict().SerializePartialToString(&dict_bytes)); + ABSL_CHECK(args[i].dict().SerializePartialToString(&dict_bytes)); result->push_back(dict_bytes); } else if (args[i].has_num() || args[i].has_str()) { std::string text_value = args[i].has_num() @@ -694,7 +690,7 @@ absl::Status TemplateExpander::ExpandTemplates( } absl::Status status; for (const absl::Status& error : errors_) { - LOG(ERROR) << error; + ABSL_LOG(ERROR) << error; status.Update(error); } return status; diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index f012ac418..d97ec0c2c 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -20,6 +20,9 @@ #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" @@ -30,7 +33,6 @@ #include "mediapipe/framework/deps/proto_descriptor.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -180,11 +182,11 @@ void CheckFieldIndex(const FieldDescriptor* field, int index) { } if (field->is_repeated() && index == -1) { - LOG(DFATAL) << "Index must be in range of repeated field values. " - << "Field: " << field->name(); + ABSL_LOG(ERROR) << "Index must be in range of repeated field values. 
" + << "Field: " << field->name(); } else if (!field->is_repeated() && index != -1) { - LOG(DFATAL) << "Index must be -1 for singular fields." - << "Field: " << field->name(); + ABSL_LOG(ERROR) << "Index must be -1 for singular fields." + << "Field: " << field->name(); } } @@ -304,7 +306,7 @@ class TemplateParser::Parser::ParserImpl { // Parses the ASCII representation specified in input and saves the // information into the output pointer (a Message). Returns // false if an error occurs (an error will also be logged to - // LOG(ERROR)). + // ABSL_LOG(ERROR)). virtual bool Parse(Message* output) { // Consume fields until we cannot do so anymore. while (true) { @@ -334,12 +336,12 @@ class TemplateParser::Parser::ParserImpl { had_errors_ = true; if (error_collector_ == NULL) { if (line >= 0) { - LOG(ERROR) << "Error parsing text-format " - << root_message_type_->full_name() << ": " << (line + 1) - << ":" << (col + 1) << ": " << message; + ABSL_LOG(ERROR) << "Error parsing text-format " + << root_message_type_->full_name() << ": " << (line + 1) + << ":" << (col + 1) << ": " << message; } else { - LOG(ERROR) << "Error parsing text-format " - << root_message_type_->full_name() << ": " << message; + ABSL_LOG(ERROR) << "Error parsing text-format " + << root_message_type_->full_name() << ": " << message; } } else { error_collector_->AddError(line, col, std::string(message)); @@ -349,12 +351,12 @@ class TemplateParser::Parser::ParserImpl { void ReportWarning(int line, int col, absl::string_view message) { if (error_collector_ == NULL) { if (line >= 0) { - LOG(WARNING) << "Warning parsing text-format " - << root_message_type_->full_name() << ": " << (line + 1) - << ":" << (col + 1) << ": " << message; + ABSL_LOG(WARNING) << "Warning parsing text-format " + << root_message_type_->full_name() << ": " + << (line + 1) << ":" << (col + 1) << ": " << message; } else { - LOG(WARNING) << "Warning parsing text-format " - << root_message_type_->full_name() << ": " << message; + ABSL_LOG(WARNING) << "Warning parsing text-format " + << root_message_type_->full_name() << ": " << message; } } else { error_collector_->AddWarning(line, col, std::string(message)); @@ -470,7 +472,7 @@ class TemplateParser::Parser::ParserImpl { "\" stored in google.protobuf.Any."); return false; } - DO(ConsumeAnyValue(value_descriptor, &serialized_value)); + DO(ConsumeAnyValue(any_value_field, value_descriptor, &serialized_value)); if (singular_overwrite_policy_ == FORBID_SINGULAR_OVERWRITES) { // Fail if any_type_url_field has already been specified. if ((!any_type_url_field->is_repeated() && @@ -564,7 +566,8 @@ class TemplateParser::Parser::ParserImpl { // Skips unknown or reserved fields. if (field == NULL) { - CHECK(allow_unknown_field_ || allow_unknown_extension_ || reserved_field); + ABSL_CHECK(allow_unknown_field_ || allow_unknown_extension_ || + reserved_field); // Try to guess the type of this field. // If this field is not a message, there should be a ":" between the @@ -708,7 +711,7 @@ class TemplateParser::Parser::ParserImpl { // If the parse information tree is not NULL, create a nested one // for the nested message. ParseInfoTree* parent = parse_info_tree_; - if (parent != NULL) { + if (parent) { parse_info_tree_ = parent->CreateNested(field); } @@ -883,7 +886,7 @@ class TemplateParser::Parser::ParserImpl { case FieldDescriptor::CPPTYPE_MESSAGE: { // We should never get here. Put here instead of a default // so that if new types are added, we get a nice compiler warning. 
- LOG(FATAL) << "Reached an unintended state: CPPTYPE_MESSAGE"; + ABSL_LOG(FATAL) << "Reached an unintended state: CPPTYPE_MESSAGE"; break; } } @@ -974,7 +977,7 @@ class TemplateParser::Parser::ParserImpl { } // Consumes an identifier and saves its value in the identifier parameter. - // Returns false if the token is not of type IDENTFIER. + // Returns false if the token is not of type IDENTIFIER. bool ConsumeIdentifier(std::string* identifier) { if (LookingAtType(io::Tokenizer::TYPE_IDENTIFIER)) { *identifier = tokenizer_.current().text; @@ -1190,8 +1193,20 @@ class TemplateParser::Parser::ParserImpl { // A helper function for reconstructing Any::value. Consumes a text of // full_type_name, then serializes it into serialized_value. - bool ConsumeAnyValue(const Descriptor* value_descriptor, + bool ConsumeAnyValue(const FieldDescriptor* field, + const Descriptor* value_descriptor, std::string* serialized_value) { + if (--recursion_limit_ < 0) { + ReportError("Message is too deep"); + return false; + } + // If the parse information tree is not NULL, create a nested one + // for the nested message. + ParseInfoTree* parent = parse_info_tree_; + if (parent) { + parse_info_tree_ = parent->CreateNested(field); + } + DynamicMessageFactory factory; const Message* value_prototype = factory.GetPrototype(value_descriptor); if (value_prototype == NULL) { @@ -1213,6 +1228,11 @@ class TemplateParser::Parser::ParserImpl { } value->AppendToString(serialized_value); } + + ++recursion_limit_; + + // Reset the parse information tree. + parse_info_tree_ = parent; return true; } @@ -1379,7 +1399,7 @@ bool DeterministicallySerialize(const Message& proto, std::string* result) { void SerializeField(const Message* message, const FieldDescriptor* field, std::vector* result) { ProtoUtilLite::FieldValue message_bytes; - CHECK(DeterministicallySerialize(*message, &message_bytes)); + ABSL_CHECK(DeterministicallySerialize(*message, &message_bytes)); ProtoUtilLite::FieldAccess access( field->number(), static_cast(field->type())); MEDIAPIPE_CHECK_OK(access.SetMessage(message_bytes)); @@ -1430,10 +1450,10 @@ std::vector GetFields(const Message* src) { // Orders map entries in dst to match src. 
void OrderMapEntries(const Message* src, Message* dst, - std::set* seen = nullptr) { - std::unique_ptr> seen_owner; + absl::flat_hash_set* seen = nullptr) { + std::unique_ptr> seen_owner; if (!seen) { - seen_owner = std::make_unique>(); + seen_owner = std::make_unique>(); seen = seen_owner.get(); } if (seen->count(src) > 0) { @@ -1672,7 +1692,9 @@ class TemplateParser::Parser::MediaPipeParserImpl if (field_type == ProtoUtilLite::FieldType::TYPE_MESSAGE) { *args = {""}; } else { - MEDIAPIPE_CHECK_OK(ProtoUtilLite::Serialize({"1"}, field_type, args)); + constexpr char kPlaceholderValue[] = "1"; + MEDIAPIPE_CHECK_OK( + ProtoUtilLite::Serialize({kPlaceholderValue}, field_type, args)); } } @@ -1682,13 +1704,13 @@ class TemplateParser::Parser::MediaPipeParserImpl const std::vector& args) { auto field_type = static_cast(field->type()); ProtoUtilLite::FieldValue message_bytes; - CHECK(message->SerializePartialToString(&message_bytes)); + ABSL_CHECK(message->SerializePartialToString(&message_bytes)); int count; MEDIAPIPE_CHECK_OK(ProtoUtilLite::GetFieldCount( message_bytes, {{field->number(), 0}}, field_type, &count)); MEDIAPIPE_CHECK_OK(ProtoUtilLite::ReplaceFieldRange( &message_bytes, {{field->number(), count}}, 0, field_type, args)); - CHECK(message->ParsePartialFromString(message_bytes)); + ABSL_CHECK(message->ParsePartialFromString(message_bytes)); } // Parse and record a template definition for the current field path. diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 5642941e9..e5fac11ae 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -22,10 +22,13 @@ #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "mediapipe/framework/calculator.pb.h" @@ -34,7 +37,6 @@ #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/file_helpers.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status_macros.h" @@ -58,8 +60,8 @@ absl::Status CompareDiff(const ImageFrame& image1, const ImageFrame& image2, const float max_avg_diff, std::unique_ptr& diff_image) { // Verify image byte depth matches expected byte depth. - CHECK_EQ(sizeof(T), image1.ByteDepth()); - CHECK_EQ(sizeof(T), image2.ByteDepth()); + ABSL_CHECK_EQ(sizeof(T), image1.ByteDepth()); + ABSL_CHECK_EQ(sizeof(T), image2.ByteDepth()); const int width = image1.Width(); const int height = image1.Height(); @@ -70,8 +72,8 @@ absl::Status CompareDiff(const ImageFrame& image1, const ImageFrame& image2, const int num_channels = std::min(channels1, channels2); // Verify the width steps are multiples of byte depth. 
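
The padding arithmetic that follows these checks, worked through on concrete numbers: WidthStep() is the row stride in bytes and can exceed width * channels * ByteDepth() when rows are aligned, so the comparison loops must skip the leftover elements at the end of each row. The stride value below is an example, not from the diff:

    #include <cassert>
    #include <iostream>

    int main() {
      const int width = 100, channels = 3, byte_depth = 1;
      const int width_step = 304;  // example stride: 300 payload bytes rounded
                                   // up to a 16-byte boundary
      assert(width_step % byte_depth == 0);  // the ABSL_CHECK_EQ above
      const int width_padding = width_step / byte_depth - width * channels;
      std::cout << "skip " << width_padding << " elements per row\n";  // 4
    }
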
- CHECK_EQ(image1.WidthStep() % image1.ByteDepth(), 0); - CHECK_EQ(image2.WidthStep() % image2.ByteDepth(), 0); + ABSL_CHECK_EQ(image1.WidthStep() % image1.ByteDepth(), 0); + ABSL_CHECK_EQ(image2.WidthStep() % image2.ByteDepth(), 0); const int width_padding1 = image1.WidthStep() / image1.ByteDepth() - width * channels1; const int width_padding2 = @@ -142,7 +144,7 @@ absl::Status CompareDiff(const ImageFrame& image1, const ImageFrame& image2, std::string GetBinaryDirectory() { char full_path[PATH_MAX + 1]; int length = readlink("/proc/self/exe", full_path, PATH_MAX + 1); - CHECK_GT(length, 0); + ABSL_CHECK_GT(length, 0); return std::string( ::mediapipe::file::Dirname(absl::string_view(full_path, length))); } @@ -195,7 +197,7 @@ absl::Status CompareImageFrames(const ImageFrame& image1, return CompareDiff(image1, image2, max_color_diff, max_alpha_diff, max_avg_diff, diff_image); default: - LOG(FATAL) << ImageFrame::InvalidFormatString(image1.Format()); + ABSL_LOG(FATAL) << ImageFrame::InvalidFormatString(image1.Format()); } } @@ -227,7 +229,9 @@ absl::Status CompareAndSaveImageOutput( auto status = CompareImageFrames(**expected, actual, options.max_color_diff, options.max_alpha_diff, options.max_avg_diff, diff_img); - ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff")); + if (diff_img) { + ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff")); + } return status; } @@ -311,6 +315,13 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path, // Returns the path to the output if successful. absl::StatusOr<std::string> SavePngTestOutput( const mediapipe::ImageFrame& image, absl::string_view prefix) { + absl::flat_hash_set<ImageFormat::Format> supported_formats = { + ImageFormat::GRAY8, ImageFormat::SRGB, ImageFormat::SRGBA, + ImageFormat::LAB8, ImageFormat::SBGRA}; + if (!supported_formats.contains(image.Format())) { + return absl::CancelledError( + absl::StrFormat("Format %d can not be saved to PNG.", image.Format())); + } std::string now_string = absl::FormatTime(absl::Now()); std::string output_relative_path = absl::StrCat(prefix, "_", now_string, ".png"); @@ -326,15 +337,15 @@ absl::StatusOr<std::string> SavePngTestOutput( bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path) { int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { - LOG(ERROR) << "could not open test graph: " << path - << ", error: " << strerror(errno); + ABSL_LOG(ERROR) << "could not open test graph: " << path - hmm + << ", error: " << strerror(errno); return false; } proto_ns::io::FileInputStream input(fd); bool success = proto->ParseFromZeroCopyStream(&input); close(fd); if (!success) { - LOG(ERROR) << "could not parse test graph: " << path; + ABSL_LOG(ERROR) << "could not parse test graph: " << path; } return success; } @@ -345,7 +356,7 @@ std::unique_ptr<ImageFrame> GenerateLuminanceImage( const int height = original_image.Height(); const int channels = original_image.NumberOfChannels(); if (channels != 3 && channels != 4) { - LOG(ERROR) << "Invalid number of image channels: " << channels; + ABSL_LOG(ERROR) << "Invalid number of image channels: " << channels; return nullptr; } auto luminance_image = diff --git a/mediapipe/framework/tool/text_to_binary_graph.cc b/mediapipe/framework/tool/text_to_binary_graph.cc index b6b38dea7..046f07518 100644 --- a/mediapipe/framework/tool/text_to_binary_graph.cc +++ b/mediapipe/framework/tool/text_to_binary_graph.cc @@ -21,9 +21,11 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator.pb.h" #include
"mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -33,10 +35,10 @@ ABSL_FLAG(std::string, proto_source, "", ABSL_FLAG(std::string, proto_output, "", "An output template file in binary CalculatorGraphTemplate form."); -#define EXIT_IF_ERROR(status) \ - if (!status.ok()) { \ - LOG(ERROR) << status; \ - return EXIT_FAILURE; \ +#define EXIT_IF_ERROR(status) \ + if (!status.ok()) { \ + ABSL_LOG(ERROR) << status; \ + return EXIT_FAILURE; \ } namespace mediapipe { diff --git a/mediapipe/framework/tool/validate_type.cc b/mediapipe/framework/tool/validate_type.cc index 4c97a310a..38c04fa87 100644 --- a/mediapipe/framework/tool/validate_type.cc +++ b/mediapipe/framework/tool/validate_type.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" @@ -78,7 +79,7 @@ absl::Status RunGenerateAndValidateTypes( const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets, const std::string& package) { - CHECK(output_side_packets); + ABSL_CHECK(output_side_packets); // Get static access to functions. ASSIGN_OR_RETURN( auto static_access, diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h index e26efa039..f03f48ce7 100644 --- a/mediapipe/framework/type_map.h +++ b/mediapipe/framework/type_map.h @@ -64,6 +64,8 @@ #include #include "absl/base/macros.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/demangle.h" #include "mediapipe/framework/port/status.h" @@ -127,7 +129,7 @@ class StaticMap { } static void GetKeys(std::vector* keys) { - CHECK(keys); + ABSL_CHECK(keys); keys->clear(); const MapType& internal_map = GetMap()->internal_map_; for (typename MapType::const_iterator i = internal_map.begin(); @@ -158,12 +160,12 @@ class StaticMap { // Type has been already registered. const MediaPipeTypeData& existing_data = it->second.second; - CHECK_EQ(existing_data.type_id, value.type_id) + ABSL_CHECK_EQ(existing_data.type_id, value.type_id) << "Found inconsistent type ids (" << existing_data.type_id << " vs " << value.type_id << ") during mediapipe type registration. Previous definition at " << it->second.first << " and current definition at " << file_and_line; - CHECK_EQ(existing_data.type_string, value.type_string) + ABSL_CHECK_EQ(existing_data.type_string, value.type_string) << "Found inconsistent type strings (" << existing_data.type_string << " vs " << value.type_string << ") during mediapipe type registration. Previous registration at " @@ -171,29 +173,31 @@ class StaticMap { << file_and_line; if (value.serialize_fn && value.deserialize_fn) { // Doesn't allow to redefine the existing type serialization functions. 
- CHECK(!existing_data.serialize_fn && !existing_data.deserialize_fn) + ABSL_CHECK(!existing_data.serialize_fn && !existing_data.deserialize_fn) << "Attempting to redefine serialization functions of type " << value.type_string << ", that have been defined at " << it->second.first << ", at " << file_and_line; const std::string previous_file_and_line = it->second.first; it->second.first = file_and_line; it->second.second = value; - LOG(WARNING) << "Redo mediapipe type registration of type " - << value.type_string << " with serialization function at " - << file_and_line << ". It was registered at " - << previous_file_and_line; + ABSL_LOG(WARNING) << "Redo mediapipe type registration of type " + << value.type_string + << " with serialization function at " << file_and_line + << ". It was registered at " + << previous_file_and_line; } else if (!value.serialize_fn && !value.deserialize_fn) { // Prefers type registration with serialization functions. If type has // been registered with some serialization functions, the // non-serialization version will be ignored. - LOG(WARNING) << "Ignore mediapipe type registration of type " - << value.type_string << " at " << file_and_line - << ", since type has been registered with serialization " - "functions at " - << it->second.first; + ABSL_LOG(WARNING) + << "Ignore mediapipe type registration of type " + << value.type_string << " at " << file_and_line + << ", since type has been registered with serialization " + "functions at " + << it->second.first; } else { // Doesn't allow to only have one of serialize_fn and deserialize_fn. - LOG(FATAL) + ABSL_LOG(FATAL) << "Invalid mediapipe type registration at " << file_and_line << ". Serialization functions should be provided at the same time."; } @@ -241,9 +245,9 @@ class StaticMap { #define DEFINE_MEDIAPIPE_TYPE_MAP(MapName, KeyType) \ class MapName : public type_map_internal::StaticMap<MapName, KeyType> {}; // Defines a map from unique typeid number to MediaPipeTypeData. -DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeIdToMediaPipeTypeData, size_t); +DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeIdToMediaPipeTypeData, size_t) // Defines a map from unique type string to MediaPipeTypeData. -DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); +DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string) // MEDIAPIPE_REGISTER_TYPE can be used to register a type. // Convention: @@ -272,17 +276,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \ SET_MEDIAPIPE_TYPE_MAP_VALUE( \ mediapipe::PacketTypeIdToMediaPipeTypeData, \ - mediapipe::tool::GetTypeHash< \ - mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \ + mediapipe::TypeId::Of< \ + mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \ + .hash_code(), \ (mediapipe::MediaPipeTypeData{ \ - mediapipe::tool::GetTypeHash< \ - mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \ + mediapipe::TypeId::Of< \ + mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \ + .hash_code(), \ type_name, serialize_fn, deserialize_fn})); \ SET_MEDIAPIPE_TYPE_MAP_VALUE( \ mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \ (mediapipe::MediaPipeTypeData{ \ - mediapipe::tool::GetTypeHash< \ - mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \ + mediapipe::TypeId::Of< \ + mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \ + .hash_code(), \ type_name, serialize_fn, deserialize_fn})); // End define MEDIAPIPE_REGISTER_TYPE.
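For orientation on the macro above: call sites are unchanged by this patch; only the registry key behind them moves from mediapipe::tool::GetTypeHash to TypeId::Of(...).hash_code(). A hedged usage sketch, registering a type without serialization support:

    // Both function pointers are null, so only the type/name mapping is
    // registered; serialization stays unavailable for this type.
    MEDIAPIPE_REGISTER_TYPE(int, "int", nullptr, nullptr);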
diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 15eac3209..4f9182474 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -15,8 +15,11 @@ #include "mediapipe/framework/validated_graph_config.h" #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -33,6 +36,7 @@ #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/status.h" @@ -49,8 +53,6 @@ namespace mediapipe { -namespace { - // Create a debug string name for a set of edges. An edge can be either // a stream or a side packet. std::string DebugEdgeNames( @@ -78,6 +80,8 @@ std::string DebugName(const CalculatorGraphConfig::Node& node_config) { return name; } +namespace { + std::string DebugName(const PacketGeneratorConfig& node_config) { return absl::StrCat( "[", node_config.packet_generator(), ", ", @@ -98,7 +102,7 @@ std::string DebugName(const CalculatorGraphConfig& config, NodeTypeInfo::NodeType node_type, int node_index) { switch (node_type) { case NodeTypeInfo::NodeType::CALCULATOR: - return DebugName(config.node(node_index)); + return mediapipe::DebugName(config.node(node_index)); case NodeTypeInfo::NodeType::PACKET_GENERATOR: return DebugName(config.packet_generator(node_index)); case NodeTypeInfo::NodeType::GRAPH_INPUT_STREAM: @@ -108,8 +112,8 @@ std::string DebugName(const CalculatorGraphConfig& config, case NodeTypeInfo::NodeType::UNKNOWN: /* Fall through. */ {} } - LOG(FATAL) << "Unknown NodeTypeInfo::NodeType: " - << NodeTypeInfo::NodeTypeToString(node_type); + ABSL_LOG(FATAL) << "Unknown NodeTypeInfo::NodeType: " + << NodeTypeInfo::NodeTypeToString(node_type); } // Adds the ExecutorConfigs for predefined executors, if they are not in @@ -158,8 +162,8 @@ std::string NodeTypeInfo::NodeTypeToString(NodeType node_type) { case NodeTypeInfo::NodeType::UNKNOWN: return "Unknown Node"; } - LOG(FATAL) << "Unknown NodeTypeInfo::NodeType: " - << static_cast<int>(node_type); + ABSL_LOG(FATAL) << "Unknown NodeTypeInfo::NodeType: " + << static_cast<int>(node_type); } absl::Status NodeTypeInfo::Initialize( @@ -692,12 +696,13 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode( if (edge_info.back_edge) { // A back edge was specified, but its output side was already seen. if (!need_sorting_ptr) { - LOG(WARNING) << "Input Stream \"" << name - << "\" for node with sorted index " << node_index - << " name " << node_type_info->Contract().GetNodeName() - << " is marked as a back edge, but its output stream is " - "already available. This means it was not necessary " - "to mark it as a back edge."; + ABSL_LOG(WARNING) + << "Input Stream \"" << name << "\" for node with sorted index " + << node_index << " name " + << node_type_info->Contract().GetNodeName() + << " is marked as a back edge, but its output stream is " "already available. 
This means it was not necessary " "to mark it as a back edge."; } } else { edge_info.upstream = iter->second; } @@ -744,7 +749,7 @@ int ValidatedGraphConfig::SorterIndexForNode(NodeTypeInfo::NodeRef node) const { case NodeTypeInfo::NodeType::CALCULATOR: return generators_.size() + node.index; default: - CHECK(false); + ABSL_CHECK(false); } } @@ -900,8 +905,8 @@ absl::Status ValidatedGraphConfig::ValidateSidePacketTypes() { "\"$3\" but the connected output side packet will be of type \"$4\"", side_packet.name, NodeTypeInfo::NodeTypeToString(side_packet.parent_node.type), - mediapipe::DebugName(config_, side_packet.parent_node.type, - side_packet.parent_node.index), + DebugName(config_, side_packet.parent_node.type, + side_packet.parent_node.index), side_packet.packet_type->DebugTypeName(), output_side_packets_[side_packet.upstream] .packet_type->DebugTypeName())); diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index 95ecccbb4..ec46b62b4 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -16,15 +16,18 @@ #define MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_ #include +#include #include #include "absl/container/flat_hash_set.h" +#include "google/protobuf/repeated_ptr_field.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/map_util.h" +#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/status_handler.pb.h" @@ -34,6 +37,12 @@ namespace mediapipe { class ValidatedGraphConfig; +std::string DebugEdgeNames( + const std::string& edge_type, + const proto_ns::RepeatedPtrField<ProtoString>& edges); + +std::string DebugName(const CalculatorGraphConfig::Node& node_config); + // Type information for a graph node (Calculator, Generator, etc). class NodeTypeInfo { public: diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index c785e5624..b7c1a27d8 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -38,7 +38,10 @@ cc_library( srcs = ["gpu_service.cc"], hdrs = ["gpu_service.h"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:graph_service"] + select({ + deps = [ + "//mediapipe/framework:graph_service", + "@com_google_absl//absl/base:core_headers", + ] + select({ "//conditions:default": [ ":gpu_shared_data_internal", ], }) @@ -201,6 +204,8 @@ cc_library( "//mediapipe/framework/port:threadpool", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/debugging:leak_check", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -232,6 +237,8 @@ cc_library( ":gpu_buffer_format", ":gpu_buffer_storage", ":gpu_buffer_storage_image_frame", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", # TODO: remove this dependency.
Some other teams' tests # depend on having an indirect image_frame dependency, need to be @@ -292,6 +299,7 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ] + select({ @@ -328,6 +336,7 @@ cc_library( "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:logging", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", ] + select({ "//conditions:default": [ ":gl_base", @@ -364,6 +373,8 @@ cc_library( ":image_frame_view", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) @@ -394,6 +405,7 @@ cc_library( ":pixel_buffer_pool_util", "//mediapipe/framework/port:logging", "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/synchronization", ], ) @@ -417,6 +429,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/synchronization", ], ) @@ -433,6 +446,7 @@ cc_library( ":image_frame_view", "//mediapipe/framework/formats:frame_buffer", "//mediapipe/framework/formats:image_frame", + "@com_google_absl//absl/log:absl_check", ], ) @@ -472,8 +486,8 @@ cc_library( "//mediapipe/framework/formats:yuv_image", "//mediapipe/util/frame_buffer:frame_buffer_util", "//third_party/libyuv", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) @@ -630,6 +644,8 @@ cc_library( "//mediapipe/framework:executor", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", ] + select({ "//conditions:default": [], "//mediapipe:apple": [ @@ -751,6 +767,10 @@ cc_library( deps = [ ":gl_base", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -813,11 +833,12 @@ cc_library( "//mediapipe/framework/deps:registration", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ] + select({ @@ -843,6 +864,8 @@ objc_library( "//mediapipe/objc:mediapipe_framework_ios", "//third_party/apple_frameworks:CoreVideo", "//third_party/apple_frameworks:Metal", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@google_toolbox_for_mac//:GTM_Defines", ], ) @@ -985,6 +1008,7 @@ cc_library( "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -1116,7 +1140,7 @@ objc_library( alwayslink = 1, ) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" test_suite( name = "ios", @@ 
-1203,5 +1227,6 @@ mediapipe_cc_test( "//mediapipe/framework/formats:yuv_image", "//mediapipe/framework/port:gtest_main", "//third_party/libyuv", + "@com_google_absl//absl/log:absl_check", ], ) diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index c0703e6ee..c66483698 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,9 +14,11 @@ #import "mediapipe/gpu/MPPMetalHelper.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #import "mediapipe/gpu/gpu_buffer.h" -#import "mediapipe/gpu/graph_support.h" #import "mediapipe/gpu/gpu_service.h" +#import "mediapipe/gpu/graph_support.h" #import "mediapipe/gpu/metal_shared_resources.h" #import "GTMDefines.h" @@ -78,14 +80,13 @@ class MetalHelperLegacySupport { - (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets { auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContext(); if (cc) { - CHECK_EQ(&inputSidePackets, &cc->InputSidePackets()); + ABSL_CHECK_EQ(&inputSidePackets, &cc->InputSidePackets()); return [self initWithCalculatorContext:cc]; } // TODO: remove when we can. - LOG(WARNING) - << "CalculatorContext not available. If this calculator uses " - "CalculatorBase, call initWithCalculatorContext instead."; + ABSL_LOG(WARNING) << "CalculatorContext not available. If this calculator uses " + "CalculatorBase, call initWithCalculatorContext instead."; mediapipe::GpuSharedData* gpu_shared = inputSidePackets.Tag(mediapipe::kGpuSharedTagName).Get<mediapipe::GpuSharedData*>(); @@ -96,14 +97,13 @@ class MetalHelperLegacySupport { + (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets { auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract(); if (cc) { - CHECK_EQ(inputSidePackets, &cc->InputSidePackets()); + ABSL_CHECK_EQ(inputSidePackets, &cc->InputSidePackets()); return [self updateContract:cc]; } // TODO: remove when we can. - LOG(WARNING) - << "CalculatorContract not available. If you're calling this " - "from a GetContract method, call updateContract instead."; + ABSL_LOG(WARNING) << "CalculatorContract not available. 
If you're calling this " "from a GetContract method, call updateContract instead."; auto id = inputSidePackets->GetId(mediapipe::kGpuSharedTagName, 0); RET_CHECK(id.IsValid()) << "A " << mediapipe::kGpuSharedTagName @@ -180,7 +180,7 @@ class MetalHelperLegacySupport { NULL, _gpuResources->metal_shared().resources().mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, &texture); - CHECK_EQ(err, kCVReturnSuccess); + ABSL_CHECK_EQ(err, kCVReturnSuccess); return texture; } diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index 6e077ae6e..07ac7373a 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -17,6 +17,7 @@ #include #include "CoreFoundation/CFBase.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/objc/CFHolder.h" #include "mediapipe/objc/util.h" @@ -27,7 +28,7 @@ CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, CvTextureCacheManager* texture_caches) { OSType cv_format = CVPixelFormatForGpuBufferFormat(format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; + ABSL_CHECK_NE(cv_format, -1) << "unsupported pixel format"; pool_ = MakeCFHolderAdopting( /* keep count is 0 because the age param keeps buffers around anyway */ CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); @@ -58,7 +59,7 @@ CFHolder<CVPixelBufferRef> CvPixelBufferPoolWrapper::GetBuffer() { ++threshold; } } - CHECK(!err) << "Error creating pixel buffer: " << err; + ABSL_CHECK(!err) << "Error creating pixel buffer: " << err; count_ = threshold; return MakeCFHolderAdopting(buffer); } @@ -73,11 +74,11 @@ void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } CFHolder<CVPixelBufferRef> CvPixelBufferPoolWrapper::CreateBufferWithoutPool( const internal::GpuBufferSpec& spec) { OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; + ABSL_CHECK_NE(cv_format, -1) << "unsupported pixel format"; CVPixelBufferRef buffer; CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; + ABSL_CHECK(!err) << "Error creating pixel buffer: " << err; return MakeCFHolderAdopting(buffer); } diff --git a/mediapipe/gpu/cv_texture_cache_manager.cc b/mediapipe/gpu/cv_texture_cache_manager.cc index b977a8993..0c4d2306c 100644 --- a/mediapipe/gpu/cv_texture_cache_manager.cc +++ b/mediapipe/gpu/cv_texture_cache_manager.cc @@ -14,6 +14,7 @@ #include "mediapipe/gpu/cv_texture_cache_manager.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -32,8 +33,8 @@ void CvTextureCacheManager::FlushTextureCaches() { void CvTextureCacheManager::RegisterTextureCache(CVTextureCacheType cache) { absl::MutexLock lock(&mutex_); - CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) + ABSL_CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) << "Attempting to register a texture cache twice"; texture_caches_.emplace_back(cache); } @@ -42,13 +43,13 @@ void CvTextureCacheManager::UnregisterTextureCache(CVTextureCacheType cache) { absl::MutexLock lock(&mutex_); auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); - CHECK(it != texture_caches_.end()) + 
ABSL_CHECK(it != texture_caches_.end()) << "Attempting to unregister an unknown texture cache"; texture_caches_.erase(it); } CvTextureCacheManager::~CvTextureCacheManager() { - CHECK_EQ(texture_caches_.size(), 0) + ABSL_CHECK_EQ(texture_caches_.size(), 0) << "Failed to unregister texture caches before deleting manager"; } diff --git a/mediapipe/gpu/frame_buffer_view.h b/mediapipe/gpu/frame_buffer_view.h index 76d773a5e..a6192e521 100644 --- a/mediapipe/gpu/frame_buffer_view.h +++ b/mediapipe/gpu/frame_buffer_view.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index 9b217ddfd..763ac387a 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gl_calculator_helper.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/legacy_calculator_support.h" @@ -36,7 +38,7 @@ void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, } absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { - CHECK(cc); + ABSL_CHECK(cc); auto gpu_service = cc->Service(kGpuService); RET_CHECK(gpu_service.IsAvailable()) << "GPU service not available. Did you forget to call " @@ -71,12 +73,12 @@ absl::Status GlCalculatorHelper::SetupInputSidePackets( PacketTypeSet* input_side_packets) { auto cc = LegacyCalculatorSupport::Scoped<CalculatorContract>::current(); if (cc) { - CHECK_EQ(input_side_packets, &cc->InputSidePackets()); + ABSL_CHECK_EQ(input_side_packets, &cc->InputSidePackets()); return UpdateContract(cc); } // TODO: remove when we can. - LOG(WARNING) + ABSL_LOG(WARNING) << "CalculatorContract not available. If you're calling this " "from a GetContract method, call GlCalculatorHelper::UpdateContract " "instead."; @@ -183,9 +185,9 @@ GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( const ImageFrame& image_frame) { #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); - // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only - // deals with absl::Status in MediaPipe OSS. - CHECK_OK(maybe_buffer.status()); + // Converts absl::StatusOr to absl::Status since ABSL_CHECK_OK() currently + // only deals with absl::Status in MediaPipe OSS. + ABSL_CHECK_OK(maybe_buffer.status()); return GpuBuffer(std::move(maybe_buffer).value()); #else return GpuBuffer(GlTextureBuffer::Create(image_frame)); @@ -194,8 +196,8 @@ GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width, int* height) { - CHECK(width); - CHECK(height); + ABSL_CHECK(width); + ABSL_CHECK(height); *width = pixel_buffer.width(); *height = pixel_buffer.height(); } @@ -211,6 +213,18 @@ GlTexture GlCalculatorHelper::CreateDestinationTexture(int width, int height, return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView<GlTextureView>(0)); } +GlTexture GlCalculatorHelper::CreateDestinationTexture( + const ImageFrame& image_frame) { + // TODO: ensure buffer pool is used when creating textures out of + // ImageFrame. 
+ GpuBuffer gpu_buffer = GpuBufferCopyingImageFrame(image_frame); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView<GlTextureView>(0)); +} + +GlTexture GlCalculatorHelper::CreateDestinationTexture(GpuBuffer& gpu_buffer) { + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView<GlTextureView>(0)); +} + GlTexture GlCalculatorHelper::CreateSourceTexture( const mediapipe::Image& image) { return CreateSourceTexture(image.GetGpuBuffer()); diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index af897bbe9..b6430860f 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -135,6 +135,12 @@ class GlCalculatorHelper { // This is deprecated because: 1) it encourages the use of GlTexture as a // long-lived object; 2) it requires copying the ImageFrame's contents, // which may not always be necessary. + // + // WARNING: do NOT use the result as a destination texture which will be + // sent to downstream calculators, as it may lead to synchronization issues. + // The result is meant to be a short-lived object, local to a single + // calculator and single GL thread. Use `CreateDestinationTexture` instead + // if you need a destination texture. ABSL_DEPRECATED("Use `GpuBufferWithImageFrame`.") GlTexture CreateSourceTexture(const ImageFrame& image_frame); @@ -156,6 +162,17 @@ class GlCalculatorHelper { int output_width, int output_height, GpuBufferFormat format = GpuBufferFormat::kBGRA32); + // Allows user-provided buffers to be used as rendering destinations. + GlTexture CreateDestinationTexture(GpuBuffer& buffer); + + // Creates a destination texture by copying and uploading the passed image + // frame. + // + // WARNING: note that this function creates a new texture every time and + // doesn't use MediaPipe's gpu buffer pool. + // TODO: ensure buffer pool is used when creating textures out of + // ImageFrame. + GlTexture CreateDestinationTexture(const ImageFrame& image_frame); + // The OpenGL name of the output framebuffer. GLuint framebuffer() const; @@ -196,7 +213,7 @@ class GlCalculatorHelper { // This class should be the main way to interface with GL memory within a single // calculator. This is the preferred way to utilize the memory pool inside of // the helper, because GlTexture manages efficiently releasing memory back into -// the pool. A GPU backed Image can be extracted from the unerlying +// the pool. A GPU backed Image can be extracted from the underlying // memory. 
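As a usage note on the two CreateDestinationTexture overloads introduced above, here is a hedged sketch of a GL-thread task that uploads an ImageFrame and renders into it; RenderOverFrame is a hypothetical helper, and the actual draw calls are elided:

    #include "absl/status/status.h"
    #include "mediapipe/framework/formats/image_frame.h"
    #include "mediapipe/gpu/gl_calculator_helper.h"

    absl::Status RenderOverFrame(mediapipe::GlCalculatorHelper& helper,
                                 const mediapipe::ImageFrame& frame) {
      return helper.RunInGlContext([&]() -> absl::Status {
        // Copies and uploads the frame; no buffer pool is used (see WARNING).
        mediapipe::GlTexture dst = helper.CreateDestinationTexture(frame);
        helper.BindFramebuffer(dst);
        // ... issue GL draw calls here ...
        glFlush();
        return absl::OkStatus();
      });
    }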
class GlTexture { public: diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 99b995dda..5eff88b92 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -22,10 +22,11 @@ #include #include "absl/base/dynamic_annotations.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" @@ -59,27 +60,27 @@ static void SetThreadName(const char* name) { thread_name[sizeof(thread_name) - 1] = '\0'; int res = pthread_setname_np(pthread_self(), thread_name); if (res != 0) { - LOG_FIRST_N(INFO, 1) << "Can't set pthread names: name: \"" << name - << "\"; error: " << res; + ABSL_LOG_FIRST_N(INFO, 1) + << "Can't set pthread names: name: \"" << name << "\"; error: " << res; } #elif __APPLE__ pthread_setname_np(name); #endif - ANNOTATE_THREAD_NAME(name); + ABSL_ANNOTATE_THREAD_NAME(name); } GlContext::DedicatedThread::DedicatedThread() { - CHECK_EQ(pthread_create(&gl_thread_id_, nullptr, ThreadBody, this), 0); + ABSL_CHECK_EQ(pthread_create(&gl_thread_id_, nullptr, ThreadBody, this), 0); } GlContext::DedicatedThread::~DedicatedThread() { if (IsCurrentThread()) { - CHECK(self_destruct_); - CHECK_EQ(pthread_detach(gl_thread_id_), 0); + ABSL_CHECK(self_destruct_); + ABSL_CHECK_EQ(pthread_detach(gl_thread_id_), 0); } else { // Give an invalid job to signal termination. PutJob({}); - CHECK_EQ(pthread_join(gl_thread_id_, nullptr), 0); + ABSL_CHECK_EQ(pthread_join(gl_thread_id_, nullptr), 0); } } @@ -168,7 +169,7 @@ void GlContext::DedicatedThread::RunWithoutWaiting(GlVoidFunction gl_func) { // non-calculator tasks in the presence of GL source calculators, calculator // tasks must always be scheduled as new tasks, or another solution needs to // be set up to avoid starvation. See b/78522434. - CHECK(gl_func); + ABSL_CHECK(gl_func); PutJob(std::move(gl_func)); } @@ -236,9 +237,10 @@ absl::Status GlContext::GetGlExtensions() { // platforms to avoid possible undefined symbol or runtime errors. #if (GL_VERSION_3_0 || GL_ES_VERSION_3_0) && !defined(__EMSCRIPTEN__) if (!SymbolAvailable(&glGetStringi)) { - LOG(ERROR) << "GL major version > 3.0 indicated, but glGetStringi not " - << "defined. Falling back to deprecated GL extensions querying " - << "method."; + ABSL_LOG(ERROR) + << "GL major version > 3.0 indicated, but glGetStringi not " + << "defined. Falling back to deprecated GL extensions querying " + << "method."; return absl::InternalError("glGetStringi not defined, but queried"); } int num_extensions = 0; @@ -269,7 +271,7 @@ absl::Status GlContext::GetGlExtensionsCompat() { const GLubyte* res = glGetString(GL_EXTENSIONS); if (glGetError() != 0 || res == nullptr) { - LOG(ERROR) << "Error querying for GL extensions"; + ABSL_LOG(ERROR) << "Error querying for GL extensions"; return absl::InternalError("Error querying for GL extensions"); } const char* signed_res = reinterpret_cast<const char*>(res); @@ -297,7 +299,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { } else { // This may happen when using SwiftShader, but the numeric versions are // available and will be used instead. 
- LOG(WARNING) << "failed to get GL_VERSION string"; + ABSL_LOG(WARNING) << "failed to get GL_VERSION string"; } // We will decide later whether we want to use the version numbers we query @@ -315,8 +317,8 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // parse the version string. if (!ParseGlVersion(version_string, &gl_major_version_, &gl_minor_version_)) { - LOG(WARNING) << "invalid GL_VERSION format: '" << version_string - << "'; assuming 2.0"; + ABSL_LOG(WARNING) << "invalid GL_VERSION format: '" << version_string + << "'; assuming 2.0"; gl_major_version_ = 2; gl_minor_version_ = 0; } @@ -330,17 +332,19 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // for more details. if (gl_major_version_from_context_creation > 0 && gl_major_version_ != gl_major_version_from_context_creation) { - LOG(WARNING) << "Requested a context with major GL version " - << gl_major_version_from_context_creation - << " but context reports major version " << gl_major_version_ - << ". Setting to " << gl_major_version_from_context_creation - << ".0"; + ABSL_LOG(WARNING) << "Requested a context with major GL version " + << gl_major_version_from_context_creation + << " but context reports major version " + << gl_major_version_ << ". Setting to " + << gl_major_version_from_context_creation << ".0"; gl_major_version_ = gl_major_version_from_context_creation; gl_minor_version_ = 0; } - LOG(INFO) << "GL version: " << gl_major_version_ << "." << gl_minor_version_ - << " (" << version_string << ")"; + ABSL_LOG(INFO) << "GL version: " << gl_major_version_ << "." + << gl_minor_version_ << " (" << version_string + << "), renderer: " << glGetString(GL_RENDERER); + { auto status = GetGlExtensions(); if (!status.ok()) { @@ -387,7 +391,7 @@ GlContext::~GlContext() { clear_attachments(); return ExitContext(nullptr); }); - LOG_IF(ERROR, !status.ok()) + ABSL_LOG_IF(ERROR, !status.ok()) << "Failed to deactivate context on thread: " << status; if (thread_->IsCurrentThread()) { thread_.release()->SelfDestruct(); @@ -401,7 +405,7 @@ GlContext::~GlContext() { clear_attachments(); return absl::OkStatus(); }); - LOG_IF(ERROR, !status.ok()) << status; + ABSL_LOG_IF(ERROR, !status.ok()) << status; } } DestroyContext(); @@ -466,7 +470,7 @@ void GlContext::RunWithoutWaiting(GlVoidFunction gl_func) { return absl::OkStatus(); }); if (!status.ok()) { - LOG(ERROR) << "Error in RunWithoutWaiting: " << status; + ABSL_LOG(ERROR) << "Error in RunWithoutWaiting: " << status; } } } @@ -492,10 +496,10 @@ absl::Status GlContext::SwitchContext(ContextBinding* saved_context, } // Check that the context object is consistent with the native context. 
if (old_context_obj && saved_context) { - DCHECK(old_context_obj->context_ == saved_context->context); + ABSL_DCHECK(old_context_obj->context_ == saved_context->context); } if (new_context_obj) { - DCHECK(new_context_obj->context_ == new_context.context); + ABSL_DCHECK(new_context_obj->context_ == new_context.context); } if (new_context_obj && (old_context_obj == new_context_obj)) { @@ -535,7 +539,7 @@ GlContext::ContextBinding GlContext::ThisContextBinding() { } absl::Status GlContext::EnterContext(ContextBinding* saved_context) { - DCHECK(HasContext()); + ABSL_DCHECK(HasContext()); return SwitchContext(saved_context, ThisContextBinding()); } @@ -846,7 +850,7 @@ bool GlContext::IsAnyContextCurrent() { std::shared_ptr<GlSyncPoint> GlContext::CreateSyncTokenForCurrentExternalContext( const std::shared_ptr<GlContext>& delegate_graph_context) { - CHECK(delegate_graph_context); + ABSL_CHECK(delegate_graph_context); if (!IsAnyContextCurrent()) return nullptr; if (delegate_graph_context->ShouldUseFenceSync()) { return std::shared_ptr<GlSyncPoint>( @@ -897,7 +901,7 @@ void GlContext::WaitForGlFinishCountPast(int64_t count_to_pass) { // from the GlContext, and we must wait for gl_finish_count_ to pass it. // Therefore, we need to do at most one more glFinish call. This DCHECK // is used for documentation and sanity-checking purposes. - DCHECK(gl_finish_count_ >= count_to_pass); + ABSL_DCHECK(gl_finish_count_ >= count_to_pass); if (gl_finish_count_ == count_to_pass) { glFinish(); GlFinishCalled(); } @@ -918,7 +922,7 @@ void GlContext::WaitForGlFinishCountPast(int64_t count_to_pass) { // it can signal the right condition variable if it is asked to do a // glFinish. absl::MutexLock other_lock(&other->mutex_); - DCHECK(!other->context_waiting_on_); + ABSL_DCHECK(!other->context_waiting_on_); other->context_waiting_on_ = this; } // We do not schedule this action using Run because we don't necessarily @@ -962,12 +966,12 @@ void GlContext::WaitForGlFinishCountPast(int64_t count_to_pass) { } void GlContext::WaitSyncToken(const std::shared_ptr<GlSyncPoint>& token) { - CHECK(token); + ABSL_CHECK(token); token->Wait(); } bool GlContext::SyncTokenIsReady(const std::shared_ptr<GlSyncPoint>& token) { - CHECK(token); + ABSL_CHECK(token); return token->IsReady(); } @@ -980,7 +984,7 @@ bool GlContext::CheckForGlErrors() { return CheckForGlErrors(false); } bool GlContext::CheckForGlErrors(bool force) { #if UNSAFE_EMSCRIPTEN_SKIP_GL_ERROR_HANDLING if (!force) { - LOG_FIRST_N(WARNING, 1) << "OpenGL error checking is disabled"; + ABSL_LOG_FIRST_N(WARNING, 1) << "OpenGL error checking is disabled"; return false; } #endif @@ -992,23 +996,23 @@ bool GlContext::CheckForGlErrors(bool force) { had_error = true; switch (error) { case GL_INVALID_ENUM: - LOG(INFO) << "Found unchecked GL error: GL_INVALID_ENUM"; + ABSL_LOG(INFO) << "Found unchecked GL error: GL_INVALID_ENUM"; break; case GL_INVALID_VALUE: - LOG(INFO) << "Found unchecked GL error: GL_INVALID_VALUE"; + ABSL_LOG(INFO) << "Found unchecked GL error: GL_INVALID_VALUE"; break; case GL_INVALID_OPERATION: - LOG(INFO) << "Found unchecked GL error: GL_INVALID_OPERATION"; + ABSL_LOG(INFO) << "Found unchecked GL error: GL_INVALID_OPERATION"; break; case GL_INVALID_FRAMEBUFFER_OPERATION: - LOG(INFO) + ABSL_LOG(INFO) << "Found unchecked GL error: GL_INVALID_FRAMEBUFFER_OPERATION"; break; case GL_OUT_OF_MEMORY: - LOG(INFO) << "Found unchecked GL error: GL_OUT_OF_MEMORY"; + ABSL_LOG(INFO) << "Found unchecked GL error: GL_OUT_OF_MEMORY"; break; default: - LOG(INFO) << "Found 
unchecked GL error: UNKNOWN ERROR"; break; } } @@ -1020,16 +1024,16 @@ void GlContext::LogUncheckedGlErrors(bool had_gl_errors) { // TODO: ideally we would print a backtrace here, or at least // the name of the current calculator, to make it easier to find the // culprit. In practice, getting a backtrace from Android without crashing - // is nearly impossible, so screw it. Just change this to LOG(FATAL) when - // you want to debug. - LOG(WARNING) << "Ignoring unchecked GL error."; + // is nearly impossible, so screw it. Just change this to ABSL_LOG(FATAL) + // when you want to debug. + ABSL_LOG(WARNING) << "Ignoring unchecked GL error."; } } const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, int plane) { std::shared_ptr<GlContext> ctx = GlContext::GetCurrent(); - CHECK(ctx != nullptr); + ABSL_CHECK(ctx != nullptr); return GlTextureInfoForGpuBufferFormat(format, plane, ctx->GetGlVersion()); } diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 4f2390404..bb3e6a597 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -22,6 +22,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/executor.h" #include "mediapipe/framework/mediapipe_profiling.h" @@ -295,7 +296,7 @@ class GlContext : public std::enable_shared_from_this<GlContext> { // TODO: const result? template <class T> T& GetCachedAttachment(const Attachment<T>& attachment) { - DCHECK(IsCurrent()); + ABSL_DCHECK(IsCurrent()); internal::AttachmentPtr<void>& entry = attachments_[&attachment]; if (entry == nullptr) { entry = attachment.factory()(*this); } @@ -454,8 +455,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> { // Number of glFinish calls completed on the GL thread. // Changes should be guarded by mutex_. However, we use simple atomic // loads for efficiency on the fast path. 
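The comment above describes the concurrency scheme for the counters declared just below: writers hold mutex_, while hot-path readers do plain atomic loads. A self-contained sketch of the same pattern (Counter is illustrative, not a MediaPipe class):

    #include <atomic>
    #include <cstdint>

    #include "absl/synchronization/mutex.h"

    class Counter {
     public:
      // Fast path: an atomic load, no lock acquisition.
      int64_t Get() const { return count_.load(std::memory_order_relaxed); }

      // Slow path: changes happen under the mutex, as the comment requires.
      void Increment() {
        absl::MutexLock lock(&mu_);
        count_.fetch_add(1, std::memory_order_relaxed);
      }

     private:
      absl::Mutex mu_;
      std::atomic<int64_t> count_ = 0;  // C++17 inline init; no ATOMIC_VAR_INIT
    };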
- std::atomic<int64_t> gl_finish_count_ = ATOMIC_VAR_INIT(0); - std::atomic<int64_t> gl_finish_count_target_ = ATOMIC_VAR_INIT(0); + std::atomic<int64_t> gl_finish_count_ = 0; + std::atomic<int64_t> gl_finish_count_target_ = 0; GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr; diff --git a/mediapipe/gpu/gl_context_eagl.cc b/mediapipe/gpu/gl_context_eagl.cc index 865813c21..5beb9d49f 100644 --- a/mediapipe/gpu/gl_context_eagl.cc +++ b/mediapipe/gpu/gl_context_eagl.cc @@ -15,7 +15,6 @@ #include #include "absl/memory/memory.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index f8784bbbc..d573b6978 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -14,10 +14,11 @@ #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" @@ -58,7 +59,7 @@ static void EglThreadExitCallback(void* key_value) { static void MakeEglReleaseThreadKey() { int err = pthread_key_create(&egl_release_thread_key, EglThreadExitCallback); if (err) { - LOG(ERROR) << "cannot create pthread key: " << err; + ABSL_LOG(ERROR) << "cannot create pthread key: " << err; } } @@ -81,8 +82,8 @@ static absl::StatusOr<EGLDisplay> GetInitializedDefaultEglDisplay() { EGLint minor = 0; EGLBoolean egl_initialized = eglInitialize(display, &major, &minor); RET_CHECK(egl_initialized) << "Unable to initialize EGL"; - LOG(INFO) << "Successfully initialized EGL. Major : " << major - << " Minor: " << minor; + ABSL_LOG(INFO) << "Successfully initialized EGL. Major : " << major + << " Minor: " << minor; return display; } @@ -114,7 +115,7 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, absl::Status GlContext::CreateContextInternal(EGLContext share_context, int gl_version) { - CHECK(gl_version == 2 || gl_version == 3); + ABSL_CHECK(gl_version == 2 || gl_version == 3); const EGLint config_attr[] = { // clang-format off @@ -180,8 +181,9 @@ absl::Status GlContext::CreateContext(EGLContext share_context) { auto status = CreateContextInternal(share_context, 3); if (!status.ok()) { - LOG(WARNING) << "Creating a context with OpenGL ES 3 failed: " << status; - LOG(WARNING) << "Fall back on OpenGL ES 2."; + ABSL_LOG(WARNING) << "Creating a context with OpenGL ES 3 failed: " + << status; + ABSL_LOG(WARNING) << "Fall back on OpenGL ES 2."; status = CreateContextInternal(share_context, 2); } MP_RETURN_IF_ERROR(status); @@ -208,13 +210,13 @@ void GlContext::DestroyContext() { if (eglMakeCurrent(display_, surface_, surface_, context_)) { glUseProgram(0); } else { - LOG(ERROR) << "eglMakeCurrent() returned error " << std::showbase - << std::hex << eglGetError(); + ABSL_LOG(ERROR) << "eglMakeCurrent() returned error " << std::showbase + << std::hex << eglGetError(); } return SetCurrentContextBinding(saved_context); }; auto status = thread_ ? 
thread_->Run(detach_program) : detach_program(); - LOG_IF(ERROR, !status.ok()) << status; + ABSL_LOG_IF(ERROR, !status.ok()) << status; } #endif // __ANDROID__ @@ -236,21 +238,21 @@ void GlContext::DestroyContext() { if (IsCurrent()) { if (!eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT)) { - LOG(ERROR) << "eglMakeCurrent() returned error " << std::showbase - << std::hex << eglGetError(); + ABSL_LOG(ERROR) << "eglMakeCurrent() returned error " << std::showbase + << std::hex << eglGetError(); } } if (surface_ != EGL_NO_SURFACE) { if (!eglDestroySurface(display_, surface_)) { - LOG(ERROR) << "eglDestroySurface() returned error " << std::showbase - << std::hex << eglGetError(); + ABSL_LOG(ERROR) << "eglDestroySurface() returned error " << std::showbase + << std::hex << eglGetError(); } surface_ = EGL_NO_SURFACE; } if (context_ != EGL_NO_CONTEXT) { if (!eglDestroyContext(display_, context_)) { - LOG(ERROR) << "eglDestroyContext() returned error " << std::showbase - << std::hex << eglGetError(); + ABSL_LOG(ERROR) << "eglDestroyContext() returned error " << std::showbase + << std::hex << eglGetError(); } context_ = EGL_NO_CONTEXT; } diff --git a/mediapipe/gpu/gl_context_nsgl.cc b/mediapipe/gpu/gl_context_nsgl.cc index 561474ad8..82d92a00a 100644 --- a/mediapipe/gpu/gl_context_nsgl.cc +++ b/mediapipe/gpu/gl_context_nsgl.cc @@ -14,8 +14,8 @@ #include +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" @@ -83,7 +83,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) { if (!pixel_format_) { // On several Forge machines, the default config fails. For now let's do // this. - LOG(WARNING) + ABSL_LOG(WARNING) << "failed to create pixel format; trying without acceleration"; NSOpenGLPixelFormatAttribute attrs_no_accel[] = {NSOpenGLPFAColorSize, 24, @@ -102,7 +102,8 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) { // Try to query pixel format from shared context. 
if (!context_) { - LOG(WARNING) << "Requested context not created, using queried context."; + ABSL_LOG(WARNING) + << "Requested context not created, using queried context."; CGLContextObj cgl_ctx = static_cast<CGLContextObj>([share_context CGLContextObj]); CGLPixelFormatObj cgl_fmt = diff --git a/mediapipe/gpu/gl_context_webgl.cc b/mediapipe/gpu/gl_context_webgl.cc index 25cbed83d..0f14581b6 100644 --- a/mediapipe/gpu/gl_context_webgl.cc +++ b/mediapipe/gpu/gl_context_webgl.cc @@ -14,6 +14,8 @@ #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" @@ -48,7 +50,7 @@ GlContext::StatusOrGlContext GlContext::Create( absl::Status GlContext::CreateContextInternal( EMSCRIPTEN_WEBGL_CONTEXT_HANDLE external_context, int webgl_version) { - CHECK(webgl_version == 1 || webgl_version == 2); + ABSL_CHECK(webgl_version == 1 || webgl_version == 2); EmscriptenWebGLContextAttributes attrs; emscripten_webgl_init_context_attributes(&attrs); @@ -78,7 +80,7 @@ absl::Status GlContext::CreateContextInternal( // Check for failure if (context_handle <= 0) { - LOG(INFO) << "Couldn't create webGL " << webgl_version << " context."; + ABSL_LOG(INFO) << "Couldn't create webGL " << webgl_version << " context."; return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "emscripten_webgl_create_context() returned error " << context_handle; @@ -103,32 +105,32 @@ absl::Status GlContext::CreateContext( auto status = CreateContextInternal(external_context, 2); if (!status.ok()) { - LOG(WARNING) << "Creating a context with WebGL 2 failed: " << status; - LOG(WARNING) << "Fall back on WebGL 1."; + ABSL_LOG(WARNING) << "Creating a context with WebGL 2 failed: " << status; + ABSL_LOG(WARNING) << "Fall back on WebGL 1."; status = CreateContextInternal(external_context, 1); } MP_RETURN_IF_ERROR(status); - LOG(INFO) << "Successfully created a WebGL context with major version " - << gl_major_version_ << " and handle " << context_; - + VLOG(1) << "Successfully created a WebGL context with major version " + << gl_major_version_ << " and handle " << context_; return absl::OkStatus(); } void GlContext::DestroyContext() { if (thread_) { // For now, we force web MediaPipe to be single-threaded, so error here. - LOG(ERROR) << "thread_ should not exist in DestroyContext() on web."; + ABSL_LOG(ERROR) << "thread_ should not exist in DestroyContext() on web."; } // Destroy the context and surface. 
if (context_ != 0) { EMSCRIPTEN_RESULT res = emscripten_webgl_destroy_context(context_); if (res != EMSCRIPTEN_RESULT_SUCCESS) { - LOG(ERROR) << "emscripten_webgl_destroy_context() returned error " << res; + ABSL_LOG(ERROR) << "emscripten_webgl_destroy_context() returned error " + << res; } else { - LOG(INFO) << "Successfully destroyed WebGL context with handle " - << context_; + ABSL_LOG(INFO) << "Successfully destroyed WebGL context with handle " + << context_; } context_ = 0; } diff --git a/mediapipe/gpu/gl_scaler_calculator.cc b/mediapipe/gpu/gl_scaler_calculator.cc index fa06c8854..14540b52d 100644 --- a/mediapipe/gpu/gl_scaler_calculator.cc +++ b/mediapipe/gpu/gl_scaler_calculator.cc @@ -104,6 +104,7 @@ class GlScalerCalculator : public CalculatorBase { bool vertical_flip_output_; bool horizontal_flip_output_; FrameScaleMode scale_mode_ = FrameScaleMode::kStretch; + bool use_nearest_neighbor_interpolation_ = false; }; REGISTER_CALCULATOR(GlScalerCalculator); @@ -186,7 +187,8 @@ absl::Status GlScalerCalculator::Open(CalculatorContext* cc) { scale_mode_ = FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch); } - + use_nearest_neighbor_interpolation_ = + options.use_nearest_neighbor_interpolation(); if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) { const auto& dimensions = TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1) @@ -297,6 +299,11 @@ absl::Status GlScalerCalculator::Process(CalculatorContext* cc) { glBindTexture(src2.target(), src2.name()); } + if (use_nearest_neighbor_interpolation_) { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); + } + MP_RETURN_IF_ERROR(renderer->GlRender( src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_, rotation_, horizontal_flip_output_, vertical_flip_output_, diff --git a/mediapipe/gpu/gl_scaler_calculator.proto b/mediapipe/gpu/gl_scaler_calculator.proto index 99c0d439a..f746a30f8 100644 --- a/mediapipe/gpu/gl_scaler_calculator.proto +++ b/mediapipe/gpu/gl_scaler_calculator.proto @@ -19,7 +19,7 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; import "mediapipe/gpu/scale_mode.proto"; -// Next id: 8. +// Next id: 9. message GlScalerCalculatorOptions { extend CalculatorOptions { optional GlScalerCalculatorOptions ext = 166373014; } @@ -39,4 +39,7 @@ message GlScalerCalculatorOptions { // Flip the output texture horizontally. This is applied after rotation. optional bool flip_horizontal = 5; optional ScaleMode.Mode scale_mode = 6; + // Whether to use nearest neighbor interpolation. Defaults to linear + // interpolation. + optional bool use_nearest_neighbor_interpolation = 8 [default = false]; } diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index ad867c2be..f56dbc849 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" @@ -39,6 +40,10 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes }; // SURFACE: unique_ptr to an EglSurfaceHolder to draw to. // // See GlSurfaceSinkCalculatorOptions for options. 
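To enable the new GlScalerCalculatorOptions field added above from C++, the options proto can be populated roughly as follows; the generated header path is an assumption based on the usual protoc layout:

    #include "mediapipe/gpu/gl_scaler_calculator.pb.h"  // assumed generated header

    mediapipe::GlScalerCalculatorOptions MakeNearestNeighborScalerOptions() {
      mediapipe::GlScalerCalculatorOptions options;
      // New field 8 from this patch; leaving it unset keeps linear filtering.
      options.set_use_nearest_neighbor_interpolation(true);
      return options;
    }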
+// +// NOTE: all GlSurfaceSinkCalculators use a common dedicated shared GL context +// thread by default, which is different from the main GL context thread used by +// the graph. (This applies when MediaPipe runs with multithreading and multiple +// OpenGL contexts.) class GlSurfaceSinkCalculator : public Node { public: static constexpr Input< @@ -95,7 +100,7 @@ absl::Status GlSurfaceSinkCalculator::Process(CalculatorContext* cc) { absl::MutexLock lock(&surface_holder_->mutex); EGLSurface surface = surface_holder_->surface; if (surface == EGL_NO_SURFACE) { - LOG_EVERY_N(INFO, 300) << "GlSurfaceSinkCalculator: no surface"; + ABSL_LOG_EVERY_N(INFO, 300) << "GlSurfaceSinkCalculator: no surface"; return absl::OkStatus(); } diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 69b9889c7..48afbd219 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gl_texture_buffer.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_view.h" @@ -47,6 +49,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(int width, int height, auto buf = absl::make_unique<GlTextureBuffer>(GL_TEXTURE_2D, 0, width, height, format, nullptr); if (!buf->CreateInternal(data, alignment)) { + ABSL_LOG(WARNING) << "Failed to create a GL texture"; return nullptr; } return buf; } @@ -64,7 +67,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create( int actual_ws = image_frame.WidthStep(); int alignment = 0; std::unique_ptr<ImageFrame> temp; - const uint8* data = image_frame.PixelData(); + const uint8_t* data = image_frame.PixelData(); // Let's see if the pixel data is tightly aligned to one of the alignments // supported by OpenGL, preferring 4 if possible since it's the default. @@ -106,7 +109,10 @@ GlTextureBuffer::GlTextureBuffer(GLenum target, GLuint name, int width, bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { auto context = GlContext::GetCurrent(); - if (!context) return false; + if (!context) { + ABSL_LOG(WARNING) << "Cannot create a GL texture without a valid context"; + return false; + } producer_context_ = context; // Save creation GL context. @@ -123,7 +129,7 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { if (info.gl_internal_format == GL_RGBA16F && context->GetGlVersion() != GlVersion::kGLES2 && SymbolAvailable(&glTexStorage2D)) { - CHECK(data == nullptr) << "unimplemented"; + ABSL_CHECK(data == nullptr) << "unimplemented"; glTexStorage2D(target_, 1, info.gl_internal_format, width_, height_); } else { glTexImage2D(target_, 0 /* level */, info.gl_internal_format, width_, @@ -145,10 +151,10 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { // Use the deletion callback to delete the texture on the context // that created it. - CHECK(!deletion_callback_); + ABSL_CHECK(!deletion_callback_); deletion_callback_ = [this, context](std::shared_ptr<GlSyncPoint> sync_token) { - CHECK_NE(name_, 0); + ABSL_CHECK_NE(name_, 0); GLuint name_to_delete = name_; context->RunWithoutWaiting([name_to_delete]() { // Note that we do not wait for consumers to be done before deleting the @@ -167,7 +173,7 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { // normal single-context behavior. E.g. if you do bind, delete, render, // unbind, the object is not deleted until the unbind, and it waits for // the render to finish. 
- DLOG_IF(ERROR, !glIsTexture(name_to_delete)) + ABSL_DLOG_IF(ERROR, !glIsTexture(name_to_delete)) << "Deleting invalid texture id: " << name_to_delete; glDeleteTextures(1, &name_to_delete); }); @@ -196,9 +202,9 @@ void GlTextureBuffer::Reuse() { } void GlTextureBuffer::Updated(std::shared_ptr prod_token) { - CHECK(!producer_sync_) + ABSL_CHECK(!producer_sync_) << "Updated existing texture which had not been marked for reuse!"; - CHECK(prod_token); + ABSL_CHECK(prod_token); producer_sync_ = std::move(prod_token); const auto& synced_context = producer_sync_->GetContext(); if (synced_context) { @@ -212,7 +218,7 @@ void GlTextureBuffer::DidRead(std::shared_ptr cons_token) const { consumer_multi_sync_->Add(std::move(cons_token)); } else { // TODO: change to a CHECK. - LOG_FIRST_N(WARNING, 5) << "unexpected null sync in DidRead"; + ABSL_LOG_FIRST_N(WARNING, 5) << "unexpected null sync in DidRead"; } } @@ -259,11 +265,11 @@ void GlTextureBuffer::WaitForConsumersOnGpu() { GlTextureView GlTextureBuffer::GetReadView(internal::types, int plane) const { auto gl_context = GlContext::GetCurrent(); - CHECK(gl_context); - CHECK_EQ(plane, 0); + ABSL_CHECK(gl_context); + ABSL_CHECK_EQ(plane, 0); // Note that this method is only supposed to be called by GpuBuffer, which // ensures this condition is satisfied. - DCHECK(!weak_from_this().expired()) + ABSL_DCHECK(!weak_from_this().expired()) << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); @@ -280,11 +286,11 @@ GlTextureView GlTextureBuffer::GetReadView(internal::types, GlTextureView GlTextureBuffer::GetWriteView(internal::types, int plane) { auto gl_context = GlContext::GetCurrent(); - CHECK(gl_context); - CHECK_EQ(plane, 0); + ABSL_CHECK(gl_context); + ABSL_CHECK_EQ(plane, 0); // Note that this method is only supposed to be called by GpuBuffer, which // ensures this condition is satisfied. - DCHECK(!weak_from_this().expired()) + ABSL_DCHECK(!weak_from_this().expired()) << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); @@ -341,7 +347,7 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view, // won't overflow the buffer with glReadPixels, we'd also need to check or // reset several glPixelStore parameters (e.g. what if someone had the // ill-advised idea of setting GL_PACK_SKIP_PIXELS?). - CHECK(view.gl_context()); + ABSL_CHECK(view.gl_context()); GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format, view.plane(), view.gl_context()->GetGlVersion()); diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index f785571a1..7b9140646 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -91,9 +91,9 @@ class GlTextureBuffer // TODO: turn into a single call? 
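A side note on the gl_texture_buffer.cc hunks above: `Create()` and `CreateInternal()` now log before returning `nullptr`/`false`, so allocation failures are attributable in the log. A short sketch of a call site that surfaces the null result as a status; the wrapper is hypothetical, and it assumes it runs on a thread with a current GL context (the condition `CreateInternal()` now checks explicitly):

```cpp
#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/gpu/gl_texture_buffer.h"

// Hypothetical helper: must be called with a current GL context.
absl::StatusOr<std::unique_ptr<mediapipe::GlTextureBuffer>>
CreateTextureOrError(int width, int height,
                     mediapipe::GpuBufferFormat format) {
  auto buffer = mediapipe::GlTextureBuffer::Create(width, height, format);
  if (buffer == nullptr) {
    // "Failed to create a GL texture" has already been logged by Create().
    return absl::InternalError("GlTextureBuffer::Create failed");
  }
  return buffer;
}
```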
GLuint name() const { return name_; } GLenum target() const { return target_; } - int width() const { return width_; } - int height() const { return height_; } - GpuBufferFormat format() const { return format_; } + int width() const override { return width_; } + int height() const override { return height_; } + GpuBufferFormat format() const override { return format_; } GlTextureView GetReadView(internal::types, int plane) const override; diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index 628e86099..0eb7a1c5d 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -4,6 +4,7 @@ #include #include "absl/functional/bind_front.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "mediapipe/framework/port/logging.h" @@ -127,10 +128,11 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( TypeId view_provider_type, bool for_writing) const { auto* chosen_storage = GpuBuffer::GetStorageForView(view_provider_type, for_writing); - CHECK(chosen_storage) << "no view provider found for requested view " - << view_provider_type.name() << "; storages available: " - << (holder_ ? holder_->DebugString() : "invalid"); - DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); + ABSL_CHECK(chosen_storage) + << "no view provider found for requested view " + << view_provider_type.name() << "; storages available: " + << (holder_ ? holder_->DebugString() : "invalid"); + ABSL_DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); return *chosen_storage; } diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index b9a88aa53..20cc05ead 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -20,6 +20,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gpu_buffer_format.h" @@ -72,8 +73,10 @@ class GpuBuffer { // are not portable. Applications and calculators should normally obtain // GpuBuffers in a portable way from the framework, e.g. using // GpuBufferMultiPool. 
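The next hunk makes `GpuBuffer`'s storage-based constructor `ABSL_CHECK` that the storage is non-null. For the rare code that constructs a `GpuBuffer` directly from a storage (which, per the comment above, is non-portable), a hedged sketch of validating first, so a null storage becomes a recoverable error rather than a crash; the helper is hypothetical:

```cpp
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/gpu/gpu_buffer.h"

// Hypothetical guard around the now-checked constructor.
absl::StatusOr<mediapipe::GpuBuffer> MakeGpuBuffer(
    std::shared_ptr<mediapipe::internal::GpuBufferStorage> storage) {
  if (storage == nullptr) {
    return absl::InvalidArgumentError(
        "Cannot construct GpuBuffer: null storage");
  }
  return mediapipe::GpuBuffer(std::move(storage));
}
```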
- explicit GpuBuffer(std::shared_ptr storage) - : holder_(std::make_shared(std::move(storage))) {} + explicit GpuBuffer(std::shared_ptr storage) { + ABSL_CHECK(storage) << "Cannot construct GpuBuffer with null storage"; + holder_ = std::make_shared(std::move(storage)); + } #if !MEDIAPIPE_DISABLE_GPU && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // This is used to support backward-compatible construction of GpuBuffer from diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index a820f04d6..646fb383f 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -15,6 +15,7 @@ #include "mediapipe/gpu/gpu_buffer_format.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/deps/no_destructor.h" #include "mediapipe/framework/port/logging.h" @@ -28,6 +29,12 @@ namespace mediapipe { #define GL_HALF_FLOAT 0x140B #endif // GL_HALF_FLOAT +#ifdef __EMSCRIPTEN__ +#ifndef GL_HALF_FLOAT_OES +#define GL_HALF_FLOAT_OES 0x8D61 +#endif // GL_HALF_FLOAT_OES +#endif // __EMSCRIPTEN__ + #if !MEDIAPIPE_DISABLE_GPU #ifdef GL_ES_VERSION_2_0 static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { @@ -48,6 +55,12 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { case GL_RG8: info->gl_internal_format = info->gl_format = GL_RG_EXT; return; +#ifdef __EMSCRIPTEN__ + case GL_RGBA16F: + info->gl_internal_format = GL_RGBA; + info->gl_type = GL_HALF_FLOAT_OES; + return; +#endif // __EMSCRIPTEN__ default: return; } @@ -88,6 +101,10 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, #endif // TARGET_OS_OSX }}, + {GpuBufferFormat::kOneComponent8Alpha, + { + {GL_ALPHA, GL_ALPHA, GL_UNSIGNED_BYTE, 1}, + }}, {GpuBufferFormat::kOneComponent8Red, { {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, @@ -173,16 +190,16 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, } auto iter = format_info->find(format); - CHECK(iter != format_info->end()) + ABSL_CHECK(iter != format_info->end()) << "unsupported format: " << static_cast>(format); const auto& planes = iter->second; #ifndef __APPLE__ - CHECK_EQ(planes.size(), 1) + ABSL_CHECK_EQ(planes.size(), 1) << "multiplanar formats are not supported on this platform"; #endif - CHECK_GE(plane, 0) << "invalid plane number"; - CHECK_LT(plane, planes.size()) << "invalid plane number"; + ABSL_CHECK_GE(plane, 0) << "invalid plane number"; + ABSL_CHECK_LT(plane, planes.size()) << "invalid plane number"; return planes[plane]; } #endif // MEDIAPIPE_DISABLE_GPU @@ -209,6 +226,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kRGBA32: // TODO: this likely maps to ImageFormat::SRGBA case GpuBufferFormat::kGrayHalf16: + case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponentHalf16: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 5d77afeb6..06eabda77 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -43,6 +43,7 @@ enum class GpuBufferFormat : uint32_t { kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'), kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'), kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'), + kOneComponent8Alpha = MEDIAPIPE_FOURCC('A', '0', '0', '8'), kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'), kTwoComponent8 = 
MEDIAPIPE_FOURCC('2', 'C', '0', '8'), kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'), @@ -101,6 +102,7 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { return kCVPixelFormatType_OneComponent32Float; case GpuBufferFormat::kOneComponent8: return kCVPixelFormatType_OneComponent8; + case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Red: return -1; case GpuBufferFormat::kTwoComponent8: diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index 7cac32b7f..ba048351b 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -2,6 +2,8 @@ #include +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/objc/util.h" @@ -17,11 +19,11 @@ typedef CVOpenGLESTextureRef CVTextureType; GpuBufferStorageCvPixelBuffer::GpuBufferStorageCvPixelBuffer( int width, int height, GpuBufferFormat format) { OSType cv_format = CVPixelFormatForGpuBufferFormat(format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; + ABSL_CHECK_NE(cv_format, -1) << "unsupported pixel format"; CVPixelBufferRef buffer; CVReturn err = CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; + ABSL_CHECK(!err) << "Error creating pixel buffer: " << err; adopt(buffer); } @@ -29,13 +31,13 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( int plane, GlTextureView::DoneWritingFn done_writing) const { CVReturn err; auto gl_context = GlContext::GetCurrent(); - CHECK(gl_context); + ABSL_CHECK(gl_context); #if TARGET_OS_OSX CVTextureType cv_texture_temp; err = CVOpenGLTextureCacheCreateTextureFromImage( kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL, &cv_texture_temp); - CHECK(cv_texture_temp && !err) + ABSL_CHECK(cv_texture_temp && !err) << "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err; CFHolder cv_texture; cv_texture.adopt(cv_texture_temp); @@ -53,7 +55,7 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( GL_TEXTURE_2D, info.gl_internal_format, width() / info.downscale, height() / info.downscale, info.gl_format, info.gl_type, plane, &cv_texture_temp); - CHECK(cv_texture_temp && !err) + ABSL_CHECK(cv_texture_temp && !err) << "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err; CFHolder cv_texture; cv_texture.adopt(cv_texture_temp); @@ -73,12 +75,12 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( #if TARGET_IPHONE_SIMULATOR static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, const GlTextureView& view) { - CHECK(pixel_buffer); + ABSL_CHECK(pixel_buffer); auto ctx = GlContext::GetCurrent().get(); if (!ctx) ctx = view.gl_context(); ctx->Run([pixel_buffer, &view, ctx] { CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) + ABSL_CHECK(err == kCVReturnSuccess) << "CVPixelBufferLockBaseAddress failed: " << err; OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); @@ -113,10 +115,10 @@ static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, view.target(), 0, 0); glBindFramebuffer(GL_FRAMEBUFFER, 0); } else { - LOG(ERROR) << "unsupported pixel format: " << pixel_format; + ABSL_LOG(ERROR) << "unsupported pixel format: " 
<< pixel_format; } err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) + ABSL_CHECK(err == kCVReturnSuccess) << "CVPixelBufferUnlockBaseAddress failed: " << err; }); } @@ -149,7 +151,7 @@ static std::shared_ptr ConvertFromImageFrame( std::shared_ptr frame) { auto status_or_buffer = CreateCVPixelBufferForImageFrame(frame->image_frame()); - CHECK(status_or_buffer.ok()); + ABSL_CHECK(status_or_buffer.ok()); return std::make_shared( std::move(status_or_buffer).value()); } diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.cc b/mediapipe/gpu/gpu_buffer_storage_image_frame.cc index 1cd661d37..7f46e2975 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.cc +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/formats/frame_buffer.h" #include "mediapipe/framework/formats/image_frame.h" @@ -43,7 +44,7 @@ std::shared_ptr ImageFrameToFrameBuffer( std::shared_ptr image_frame) { FrameBuffer::Format format = FrameBufferFormatForImageFrameFormat(image_frame->Format()); - CHECK(format != FrameBuffer::Format::kUNKNOWN) + ABSL_CHECK(format != FrameBuffer::Format::kUNKNOWN) << "Invalid format. Only SRGB, SRGBA and GRAY8 are supported."; const FrameBuffer::Dimension dimension{/*width=*/image_frame->Width(), /*height=*/image_frame->Height()}; diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index 542791f98..3b805e8f2 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc index 4b0913b96..87fb8957d 100644 --- a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "absl/log/check.h" -#include "absl/log/log.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "libyuv/video_common.h" #include "mediapipe/framework/formats/frame_buffer.h" #include "mediapipe/framework/formats/image_frame.h" @@ -87,7 +87,7 @@ std::shared_ptr YuvImageToFrameBuffer( FrameBuffer::Dimension dimension{/*width=*/yuv_image->width(), /*height=*/yuv_image->height()}; std::vector planes; - CHECK(yuv_image->mutable_data(0) != nullptr && yuv_image->stride(0) > 0) + ABSL_CHECK(yuv_image->mutable_data(0) != nullptr && yuv_image->stride(0) > 0) << "Invalid YuvImage. 
Expected plane at index 0 to be non-null and have " "stride > 0."; planes.emplace_back( @@ -97,7 +97,8 @@ std::shared_ptr YuvImageToFrameBuffer( switch (format) { case FrameBuffer::Format::kNV12: case FrameBuffer::Format::kNV21: { - CHECK(yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0) + ABSL_CHECK(yuv_image->mutable_data(1) != nullptr && + yuv_image->stride(1) > 0) << "Invalid YuvImage. Expected plane at index 1 to be non-null and " "have stride > 0."; planes.emplace_back( @@ -108,8 +109,9 @@ std::shared_ptr YuvImageToFrameBuffer( } case FrameBuffer::Format::kYV12: case FrameBuffer::Format::kYV21: { - CHECK(yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0 && - yuv_image->mutable_data(2) != nullptr && yuv_image->stride(2) > 0) + ABSL_CHECK( + yuv_image->mutable_data(1) != nullptr && yuv_image->stride(1) > 0 && + yuv_image->mutable_data(2) != nullptr && yuv_image->stride(2) > 0) << "Invalid YuvImage. Expected planes at indices 1 and 2 to be " "non-null and have stride > 0."; planes.emplace_back( @@ -123,7 +125,7 @@ std::shared_ptr YuvImageToFrameBuffer( break; } default: - LOG(FATAL) + ABSL_LOG(FATAL) << "Invalid format. Only FOURCC_NV12, FOURCC_NV21, FOURCC_YV12 and " "FOURCC_I420 are supported."; } @@ -148,7 +150,7 @@ std::shared_ptr YuvImageToImageFrame( auto rgb_buffer = FrameBuffer(planes, yuv_buffer->dimension(), FrameBuffer::Format::kRGB); // Convert. - CHECK_OK(frame_buffer::Convert(*yuv_buffer, &rgb_buffer)); + ABSL_CHECK_OK(frame_buffer::Convert(*yuv_buffer, &rgb_buffer)); return image_frame; } @@ -156,8 +158,8 @@ std::shared_ptr YuvImageToImageFrame( GpuBufferStorageYuvImage::GpuBufferStorageYuvImage( std::shared_ptr yuv_image) { - CHECK(GpuBufferFormatForFourCC(yuv_image->fourcc()) != - GpuBufferFormat::kUnknown) + ABSL_CHECK(GpuBufferFormatForFourCC(yuv_image->fourcc()) != + GpuBufferFormat::kUnknown) << "Invalid format. Only FOURCC_NV12, FOURCC_NV21, FOURCC_YV12 and " "FOURCC_I420 are supported."; yuv_image_ = yuv_image; @@ -195,7 +197,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, break; } default: - LOG(FATAL) + ABSL_LOG(FATAL) << "Invalid format. Only kNV12, kNV21, kYV12 and kYV21 are supported"; } } @@ -223,6 +225,6 @@ std::shared_ptr GpuBufferStorageYuvImage::GetWriteView( internal::types) { // Not supported on purpose: writes into the resulting ImageFrame cannot // easily be ported back to the original YUV image. - LOG(FATAL) << "GetWriteView is not supported."; + ABSL_LOG(FATAL) << "GetWriteView is not supported."; } } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.h b/mediapipe/gpu/gpu_buffer_storage_yuv_image.h index 6b34f4948..cf6ffcd0e 100644 --- a/mediapipe/gpu/gpu_buffer_storage_yuv_image.h +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc index c9527880a..7222273e6 100644 --- a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc +++ b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc @@ -27,6 +27,22 @@ namespace mediapipe { // Convert an input image (GpuBuffer or ImageFrame) to ImageFrame. 
+// +// NOTE: all GpuBufferToImageFrameCalculators use a common dedicated shared GL +// context thread by default, which is different from the main GL context thread +// used by the graph (if MediaPipe is configured to use multithreading and +// multiple OpenGL contexts). +// +// IMPORTANT: the graph writer must make sure the input GpuBuffer-backed OpenGL +// texture is not in use before the calculator starts processing, and is not +// used by any other code until the calculator returns: +// - pixel transfer involves attaching the GpuBuffer's backing texture as a +// logical buffer to the currently bound framebuffer. +// - if that texture is also bound and enabled for texturing, this may lead +// to a "feedback loop" and undefined results. +// See OpenGL ES 3.0 spec, section 4.4.3, "Feedback Loops between Textures and +// the Framebuffer". +// class GpuBufferToImageFrameCalculator : public CalculatorBase { public: GpuBufferToImageFrameCalculator() {} diff --git a/mediapipe/gpu/gpu_origin.proto b/mediapipe/gpu/gpu_origin.proto index f4db83537..9d4ae2aa1 100644 --- a/mediapipe/gpu/gpu_origin.proto +++ b/mediapipe/gpu/gpu_origin.proto @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -syntax = "proto2"; +syntax = "proto3"; package mediapipe; diff --git a/mediapipe/gpu/gpu_service.h b/mediapipe/gpu/gpu_service.h index 65fecd0b8..dd3bd3bf5 100644 --- a/mediapipe/gpu/gpu_service.h +++ b/mediapipe/gpu/gpu_service.h @@ -15,6 +15,7 @@ #ifndef MEDIAPIPE_GPU_GPU_SERVICE_H_ #define MEDIAPIPE_GPU_GPU_SERVICE_H_ +#include "absl/base/attributes.h" #include "mediapipe/framework/graph_service.h" #if !MEDIAPIPE_DISABLE_GPU @@ -29,7 +30,7 @@ class GpuResources { }; #endif // MEDIAPIPE_DISABLE_GPU -extern const GraphService<GpuResources> kGpuService; +ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index f542f0bb2..b9b9c26f0 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gpu_shared_data_internal.h" +#include "absl/base/attributes.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/deps/no_destructor.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/gpu/gl_context.h" @@ -116,10 +118,10 @@ GpuResources::~GpuResources() { #endif // __APPLE__ } -extern const GraphService<GpuResources> kGpuService; +ABSL_CONST_INIT extern const GraphService<GpuResources> kGpuService; absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { - CHECK(node->Contract().ServiceRequests().contains(kGpuService.key)); + ABSL_CHECK(node->Contract().ServiceRequests().contains(kGpuService.key)); std::string node_id = node->GetCalculatorState().NodeName(); std::string node_type = node->GetCalculatorState().CalculatorType(); std::string context_key; diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 2a8331db8..8b56ee1c5 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -24,6 +24,11 @@ namespace mediapipe { // Convert ImageFrame to GpuBuffer. +// +// NOTE: all ImageFrameToGpuBufferCalculators use a common dedicated shared GL +// context thread by default, which is different from the main GL context thread +// used by the graph (if MediaPipe is configured to use multithreading and +// multiple OpenGL contexts).
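Taken together with the GpuBufferToImageFrameCalculator note earlier in this hunk, the two converters form a CPU/GPU round trip. A minimal sketch with hypothetical stream names; per the IMPORTANT note above, nothing else may bind or sample the GpuBuffer's texture while the download runs:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch: ImageFrame -> GpuBuffer -> ImageFrame. Both converters run on the
// shared dedicated GL context thread described in the NOTEs, not on the
// graph's main GL context thread.
mediapipe::CalculatorGraphConfig MakeRoundTripConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "cpu_in"
    output_stream: "cpu_out"
    node {
      calculator: "ImageFrameToGpuBufferCalculator"
      input_stream: "cpu_in"
      output_stream: "gpu_frame"
    }
    node {
      calculator: "GpuBufferToImageFrameCalculator"
      input_stream: "gpu_frame"
      output_stream: "cpu_out"
    }
  )pb");
}
```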
class ImageFrameToGpuBufferCalculator : public CalculatorBase { public: ImageFrameToGpuBufferCalculator() {} @@ -71,11 +76,10 @@ absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { #else const auto& input = cc->Inputs().Index(0).Get(); helper_.RunInGlContext([this, &input, &cc]() { - auto src = helper_.CreateSourceTexture(input); - auto output = src.GetFrame(); - glFlush(); + GlTexture dst = helper_.CreateDestinationTexture(input); + std::unique_ptr output = dst.GetFrame(); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - src.Release(); + dst.Release(); }); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); diff --git a/mediapipe/gpu/shader_util.cc b/mediapipe/gpu/shader_util.cc index 5de7e24f5..f50b5e6c9 100644 --- a/mediapipe/gpu/shader_util.cc +++ b/mediapipe/gpu/shader_util.cc @@ -16,6 +16,15 @@ #include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "mediapipe/framework/port/logging.h" #if DEBUG @@ -26,7 +35,7 @@ if (log_length > 0) { \ GLchar* log = static_cast(malloc(log_length)); \ glGet##type##InfoLog(object, log_length, &log_length, log); \ - LOG(INFO) << #type " " action " log:\n" << log; \ + ABSL_LOG(INFO) << #type " " action " log:\n" << log; \ free(log); \ } \ } while (0) @@ -41,15 +50,32 @@ if (log_length > 0) { \ GLchar* log = static_cast(malloc(log_length)); \ glGet##type##InfoLog(object, log_length, &log_length, log); \ - LOG(ERROR) << #type " " action " log:\n" << log; \ + ABSL_LOG(ERROR) << #type " " action " log:\n" << log; \ free(log); \ } \ } while (0) namespace mediapipe { +namespace { constexpr int kMaxShaderInfoLength = 1024; +std::string AddLineNumbers(const GLchar* source) { + // Use format "%ni %s", with n=1 for 1..9 lines, n=2 for 10..99 lines etc. + // Note that StrFormat needs either a constexpr format or a ParsedFormat. 
+ std::vector lines = absl::StrSplit(source, '\n'); + std::string format = absl::StrFormat( + "%%%ii %%s", static_cast(ceilf(log10(1 + lines.size())))); + auto parsed_format = absl::ParsedFormat<'i', 's'>::New(format); + ABSL_CHECK(parsed_format); + for (int n = 0; n < lines.size(); n++) { + lines[n] = absl::StrFormat(*parsed_format, n + 1, lines[n]); + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace + GLint GlhCompileShader(GLenum target, const GLchar* source, GLuint* shader, bool force_log_errors) { *shader = glCreateShader(target); @@ -70,13 +96,14 @@ GLint GlhCompileShader(GLenum target, const GLchar* source, GLuint* shader, GLint status; glGetShaderiv(*shader, GL_COMPILE_STATUS, &status); - LOG_IF(ERROR, status == GL_FALSE) << "Failed to compile shader:\n" << source; + ABSL_LOG_IF(ERROR, status == GL_FALSE) << "Failed to compile shader:\n" + << AddLineNumbers(source); if (status == GL_FALSE) { int length = 0; GLchar cmessage[kMaxShaderInfoLength]; glGetShaderInfoLog(*shader, kMaxShaderInfoLength, &length, cmessage); - LOG(ERROR) << "Error message: " << std::string(cmessage, length); + ABSL_LOG(ERROR) << "Error message: " << std::string(cmessage, length); } return status; } @@ -95,7 +122,8 @@ GLint GlhLinkProgram(GLuint program, bool force_log_errors) { GL_DEBUG_LOG(Program, program, "link"); glGetProgramiv(program, GL_LINK_STATUS, &status); - LOG_IF(ERROR, status == GL_FALSE) << "Failed to link program " << program; + ABSL_LOG_IF(ERROR, status == GL_FALSE) + << "Failed to link program " << program; return status; } @@ -108,7 +136,8 @@ GLint GlhValidateProgram(GLuint program) { GL_DEBUG_LOG(Program, program, "validate"); glGetProgramiv(program, GL_VALIDATE_STATUS, &status); - LOG_IF(ERROR, status == GL_FALSE) << "Failed to validate program " << program; + ABSL_LOG_IF(ERROR, status == GL_FALSE) + << "Failed to validate program " << program; return status; } @@ -141,6 +170,9 @@ GLint GlhCreateProgram(const GLchar* vert_src, const GLchar* frag_src, } ok = GlhLinkProgram(*program, force_log_errors); + + glDetachShader(*program, frag_shader); + glDetachShader(*program, vert_shader); } if (vert_shader) glDeleteShader(vert_shader); @@ -168,7 +200,8 @@ bool CompileShader(GLenum shader_type, const std::string& shader_source, GLint compiled; glGetShaderiv(*shader, GL_COMPILE_STATUS, &compiled); if (!compiled) { - VLOG(2) << "Unable to compile shader:\n" << shader_source; + VLOG(2) << "Unable to compile shader:\n" + << AddLineNumbers(shader_source_cstr); GL_ERROR_LOG(Shader, *shader, "compile"); glDeleteShader(*shader); *shader = 0; diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/BUILD b/mediapipe/graphs/instant_motion_tracking/calculators/BUILD index 93af68c21..cdfd911d4 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/BUILD +++ b/mediapipe/graphs/instant_motion_tracking/calculators/BUILD @@ -63,6 +63,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/graphs/object_detection_3d/calculators:model_matrix_cc_proto", "//mediapipe/modules/objectron/calculators:box", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc b/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc index c003135bd..a73589a8c 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc +++ 
b/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc @@ -18,6 +18,7 @@ #include "Eigen/Core" #include "Eigen/Dense" #include "Eigen/Geometry" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -116,8 +117,8 @@ class MatricesManagerCalculator : public CalculatorBase { return user_scaling.scale_factor; } } - LOG(WARNING) << "Cannot find sticker_id: " << sticker_id - << ", returning 1.0f scaling"; + ABSL_LOG(WARNING) << "Cannot find sticker_id: " << sticker_id + << ", returning 1.0f scaling"; return 1.0f; } @@ -129,8 +130,8 @@ class MatricesManagerCalculator : public CalculatorBase { return rotation.rotation_radians; } } - LOG(WARNING) << "Cannot find sticker_id: " << sticker_id - << ", returning 0.0f rotation"; + ABSL_LOG(WARNING) << "Cannot find sticker_id: " << sticker_id + << ", returning 0.0f rotation"; return 0.0f; } }; @@ -221,8 +222,9 @@ absl::Status MatricesManagerCalculator::Process(CalculatorContext* cc) { model_matrix = asset_matrices_gif->add_model_matrix(); } else { // Asset 3D if (render_data[render_idx] != 1) { - LOG(ERROR) << "render id: " << render_data[render_idx] - << " is not supported. Fall back to using render_id = 1."; + ABSL_LOG(ERROR) + << "render id: " << render_data[render_idx] + << " is not supported. Fall back to using render_id = 1."; } model_matrix = asset_matrices_1->add_model_matrix(); } @@ -379,8 +381,8 @@ DiagonalMatrix3f MatricesManagerCalculator::GetDefaultRenderScaleDiagonal( break; } default: { - LOG(INFO) << "Unsupported render_id: " << render_id - << ", returning default render_scale"; + ABSL_LOG(INFO) << "Unsupported render_id: " << render_id + << ", returning default render_scale"; break; } } diff --git a/mediapipe/graphs/object_detection_3d/calculators/BUILD b/mediapipe/graphs/object_detection_3d/calculators/BUILD index d4c5c496b..c491baf28 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/BUILD +++ b/mediapipe/graphs/object_detection_3d/calculators/BUILD @@ -74,6 +74,8 @@ cc_library( "//mediapipe/gpu:shader_util", "//mediapipe/modules/objectron/calculators:camera_parameters_cc_proto", "//mediapipe/util/android:asset_manager_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], alwayslink = 1, ) diff --git a/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc b/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc index a92020ff0..5dee74a25 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc +++ b/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc @@ -19,6 +19,8 @@ #include #endif +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" @@ -35,7 +37,7 @@ namespace { #if defined(GL_DEBUG) #define GLCHECK(command) \ command; \ - if (int err = glGetError()) LOG(ERROR) << "GL error detected: " << err; + if (int err = glGetError()) ABSL_LOG(ERROR) << "GL error detected: " << err; #else #define GLCHECK(command) command #endif @@ -355,12 +357,13 @@ bool GlAnimationOverlayCalculator::ReadBytesFromAsset(AAsset *asset, } // At least log any I/O errors encountered. 
if (bytes_read < 0) { - LOG(ERROR) << "Error reading from AAsset: " << bytes_read; + ABSL_LOG(ERROR) << "Error reading from AAsset: " << bytes_read; return false; } if (bytes_left > 0) { // Reached EOF before reading in specified number of bytes. - LOG(WARNING) << "Reached EOF before reading in specified number of bytes."; + ABSL_LOG(WARNING) + << "Reached EOF before reading in specified number of bytes."; return false; } return true; @@ -374,7 +377,7 @@ bool GlAnimationOverlayCalculator::LoadAnimationAndroid( Singleton::get(); AAssetManager *asset_manager = mediapipe_asset_manager->GetAssetManager(); if (!asset_manager) { - LOG(ERROR) << "Failed to access Android asset manager."; + ABSL_LOG(ERROR) << "Failed to access Android asset manager."; return false; } @@ -382,7 +385,7 @@ bool GlAnimationOverlayCalculator::LoadAnimationAndroid( AAsset *asset = AAssetManager_open(asset_manager, filename.c_str(), AASSET_MODE_STREAMING); if (!asset) { - LOG(ERROR) << "Failed to open animation asset: " << filename; + ABSL_LOG(ERROR) << "Failed to open animation asset: " << filename; return false; } @@ -400,14 +403,14 @@ bool GlAnimationOverlayCalculator::LoadAnimationAndroid( triangle_mesh.vertices.reset(new float[lengths[0]]); if (!ReadBytesFromAsset(asset, (void *)triangle_mesh.vertices.get(), sizeof(float) * lengths[0])) { - LOG(ERROR) << "Failed to read vertices for frame " << frame_count_; + ABSL_LOG(ERROR) << "Failed to read vertices for frame " << frame_count_; return false; } // Try to read in texture coordinates (4-byte floats) triangle_mesh.texture_coords.reset(new float[lengths[1]]); if (!ReadBytesFromAsset(asset, (void *)triangle_mesh.texture_coords.get(), sizeof(float) * lengths[1])) { - LOG(ERROR) << "Failed to read tex-coords for frame " << frame_count_; + ABSL_LOG(ERROR) << "Failed to read tex-coords for frame " << frame_count_; return false; } // Try to read in indices (2-byte shorts) @@ -415,7 +418,7 @@ bool GlAnimationOverlayCalculator::LoadAnimationAndroid( triangle_mesh.triangle_indices.reset(new int16[lengths[2]]); if (!ReadBytesFromAsset(asset, (void *)triangle_mesh.triangle_indices.get(), sizeof(int16) * lengths[2])) { - LOG(ERROR) << "Failed to read indices for frame " << frame_count_; + ABSL_LOG(ERROR) << "Failed to read indices for frame " << frame_count_; return false; } @@ -426,9 +429,10 @@ bool GlAnimationOverlayCalculator::LoadAnimationAndroid( } AAsset_close(asset); - LOG(INFO) << "Finished parsing " << frame_count_ << " animation frames."; + ABSL_LOG(INFO) << "Finished parsing " << frame_count_ << " animation frames."; if (meshes->empty()) { - LOG(ERROR) << "No animation frames were parsed! Erroring out calculator."; + ABSL_LOG(ERROR) + << "No animation frames were parsed! 
Erroring out calculator."; return false; } return true; @@ -439,7 +443,7 @@ bool GlAnimationOverlayCalculator::LoadAnimationAndroid( bool GlAnimationOverlayCalculator::LoadAnimation(const std::string &filename) { std::ifstream infile(filename.c_str(), std::ifstream::binary); if (!infile) { - LOG(ERROR) << "Error opening asset with filename: " << filename; + ABSL_LOG(ERROR) << "Error opening asset with filename: " << filename; return false; } @@ -462,7 +466,7 @@ bool GlAnimationOverlayCalculator::LoadAnimation(const std::string &filename) { infile.read((char *)(triangle_mesh.vertices.get()), sizeof(float) * lengths[0]); if (!infile) { - LOG(ERROR) << "Failed to read vertices for frame " << frame_count_; + ABSL_LOG(ERROR) << "Failed to read vertices for frame " << frame_count_; return false; } @@ -471,8 +475,8 @@ bool GlAnimationOverlayCalculator::LoadAnimation(const std::string &filename) { infile.read((char *)(triangle_mesh.texture_coords.get()), sizeof(float) * lengths[1]); if (!infile) { - LOG(ERROR) << "Failed to read texture coordinates for frame " - << frame_count_; + ABSL_LOG(ERROR) << "Failed to read texture coordinates for frame " + << frame_count_; return false; } @@ -482,8 +486,8 @@ bool GlAnimationOverlayCalculator::LoadAnimation(const std::string &filename) { infile.read((char *)(triangle_mesh.triangle_indices.get()), sizeof(int16_t) * lengths[2]); if (!infile) { - LOG(ERROR) << "Failed to read triangle indices for frame " - << frame_count_; + ABSL_LOG(ERROR) << "Failed to read triangle indices for frame " + << frame_count_; return false; } @@ -493,9 +497,10 @@ bool GlAnimationOverlayCalculator::LoadAnimation(const std::string &filename) { frame_count_++; } - LOG(INFO) << "Finished parsing " << frame_count_ << " animation frames."; + ABSL_LOG(INFO) << "Finished parsing " << frame_count_ << " animation frames."; if (triangle_meshes_.empty()) { - LOG(ERROR) << "No animation frames were parsed! Erroring out calculator."; + ABSL_LOG(ERROR) + << "No animation frames were parsed! 
Erroring out calculator."; return false; } return true; @@ -506,8 +511,8 @@ bool GlAnimationOverlayCalculator::LoadAnimation(const std::string &filename) { void GlAnimationOverlayCalculator::ComputeAspectRatioAndFovFromCameraParameters( const CameraParametersProto &camera_parameters, float *aspect_ratio, float *vertical_fov_degrees) { - CHECK(aspect_ratio != nullptr); - CHECK(vertical_fov_degrees != nullptr); + ABSL_CHECK(aspect_ratio != nullptr); + ABSL_CHECK(vertical_fov_degrees != nullptr); *aspect_ratio = camera_parameters.portrait_width() / camera_parameters.portrait_height(); *vertical_fov_degrees = @@ -560,7 +565,7 @@ absl::Status GlAnimationOverlayCalculator::Open(CalculatorContext *cc) { cc->InputSidePackets().Tag("MASK_ASSET").Get(); loaded_animation = LoadAnimationAndroid(mask_asset_name, &mask_meshes_); if (!loaded_animation) { - LOG(ERROR) << "Failed to load mask asset."; + ABSL_LOG(ERROR) << "Failed to load mask asset."; return absl::UnknownError("Failed to load mask asset."); } } @@ -569,7 +574,7 @@ absl::Status GlAnimationOverlayCalculator::Open(CalculatorContext *cc) { loaded_animation = LoadAnimation(asset_name); #endif if (!loaded_animation) { - LOG(ERROR) << "Failed to load animation asset."; + ABSL_LOG(ERROR) << "Failed to load animation asset."; return absl::UnknownError("Failed to load animation asset."); } @@ -608,7 +613,7 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( current_model_matrices->clear(); for (int i = 0; i < model_matrices.model_matrix_size(); ++i) { const auto &model_matrix = model_matrices.model_matrix(i); - CHECK(model_matrix.matrix_entries_size() == kNumMatrixEntries) + ABSL_CHECK(model_matrix.matrix_entries_size() == kNumMatrixEntries) << "Invalid Model Matrix"; current_model_matrices->emplace_back(); ModelMatrix &new_matrix = current_model_matrices->back(); @@ -669,8 +674,8 @@ absl::Status GlAnimationOverlayCalculator::Process(CalculatorContext *cc) { height = input_frame->height(); dst = helper_.CreateSourceTexture(*input_frame); } else { - LOG(ERROR) << "Unable to consume input video frame for overlay!"; - LOG(ERROR) << "Status returned was: " << result.status(); + ABSL_LOG(ERROR) << "Unable to consume input video frame for overlay!"; + ABSL_LOG(ERROR) << "Status returned was: " << result.status(); dst = helper_.CreateDestinationTexture(width, height); } } else if (!has_video_stream_) { @@ -699,7 +704,7 @@ absl::Status GlAnimationOverlayCalculator::Process(CalculatorContext *cc) { GL_RENDERBUFFER, renderbuffer_)); GLenum status = GLCHECK(glCheckFramebufferStatus(GL_FRAMEBUFFER)); if (status != GL_FRAMEBUFFER_COMPLETE) { - LOG(ERROR) << "Incomplete framebuffer with status: " << status; + ABSL_LOG(ERROR) << "Incomplete framebuffer with status: " << status; } GLCHECK(glClear(GL_DEPTH_BUFFER_BIT)); diff --git a/mediapipe/java/com/google/mediapipe/components/AudioDataProducer.java b/mediapipe/java/com/google/mediapipe/components/AudioDataProducer.java index 4f18f4706..5d042562e 100644 --- a/mediapipe/java/com/google/mediapipe/components/AudioDataProducer.java +++ b/mediapipe/java/com/google/mediapipe/components/AudioDataProducer.java @@ -14,8 +14,10 @@ package com.google.mediapipe.components; +import javax.annotation.Nullable; + /** Lightweight abstraction for an object that can produce audio data. */ public interface AudioDataProducer { /** Set the consumer that receives the audio data from this producer. 
*/ - void setAudioConsumer(AudioDataConsumer consumer); + void setAudioConsumer(@Nullable AudioDataConsumer consumer); } diff --git a/mediapipe/java/com/google/mediapipe/components/BUILD b/mediapipe/java/com/google/mediapipe/components/BUILD index a1ec17548..630bc94c3 100644 --- a/mediapipe/java/com/google/mediapipe/components/BUILD +++ b/mediapipe/java/com/google/mediapipe/components/BUILD @@ -71,7 +71,10 @@ android_library( "AudioDataProducer.java", ], visibility = ["//visibility:public"], - deps = ["@maven//:com_google_guava_guava"], + deps = [ + "@maven//:com_google_code_findbugs_jsr305", + "@maven//:com_google_guava_guava", + ], ) # MicrophoneHelper that provides access to audio data from a microphone diff --git a/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java b/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java index 7732ed17d..591b6c987 100644 --- a/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java +++ b/mediapipe/java/com/google/mediapipe/components/GlSurfaceViewRenderer.java @@ -34,6 +34,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; import javax.microedition.khronos.egl.EGLConfig; import javax.microedition.khronos.opengles.GL10; @@ -231,7 +232,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer { } /** Returns the texture left, right, bottom, and top visible boundaries. */ - protected float[] calculateTextureBoundary() { + public float[] calculateTextureBoundary() { // TODO: compute scale from surfaceTexture size. float scaleWidth = frameWidth > 0 ? (float) surfaceWidth / (float) frameWidth : 1.0f; float scaleHeight = frameHeight > 0 ? (float) surfaceHeight / (float) frameHeight : 1.0f; @@ -303,7 +304,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer { } // Use this when the texture is not a SurfaceTexture. - public void setNextFrame(TextureFrame frame) { + public void setNextFrame(@Nullable TextureFrame frame) { if (surfaceTexture != null) { Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */); } diff --git a/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java index 20c63c069..242cd616a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java @@ -78,17 +78,21 @@ public class AppTextureFrame implements TextureFrame { * Use {@link waitUntilReleasedWithGpuSync} whenever possible. */ public void waitUntilReleased() throws InterruptedException { + GlSyncToken tokenToRelease = null; synchronized (this) { while (inUse && releaseSyncToken == null) { wait(); } if (releaseSyncToken != null) { - releaseSyncToken.waitOnCpu(); - releaseSyncToken.release(); + tokenToRelease = releaseSyncToken; inUse = false; releaseSyncToken = null; } } + if (tokenToRelease != null) { + tokenToRelease.waitOnCpu(); + tokenToRelease.release(); + } } /** @@ -98,17 +102,21 @@ public class AppTextureFrame implements TextureFrame { * TextureFrame. 
*/ public void waitUntilReleasedWithGpuSync() throws InterruptedException { + GlSyncToken tokenToRelease = null; synchronized (this) { while (inUse && releaseSyncToken == null) { wait(); } if (releaseSyncToken != null) { - releaseSyncToken.waitOnGpu(); - releaseSyncToken.release(); + tokenToRelease = releaseSyncToken; inUse = false; releaseSyncToken = null; } } + if (tokenToRelease != null) { + tokenToRelease.waitOnGpu(); + tokenToRelease.release(); + } } /** diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index dd5f8f1da..78ae61d06 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -50,7 +50,6 @@ android_library( "MediaPipeRunner.java", ], visibility = [ - "//java/com/google/android/libraries/camera/effects:__subpackages__", "//mediapipe/java/com/google/mediapipe:__subpackages__", ], exports = [ diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index 04265cab5..e71749d09 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -237,6 +237,10 @@ public class PacketCreator { return Packet.create(nativeCreateInt32Array(mediapipeGraph.getNativeHandle(), data)); } + public Packet createInt32Pair(int first, int second) { + return Packet.create(nativeCreateInt32Pair(mediapipeGraph.getNativeHandle(), first, second)); + } + public Packet createFloat32Array(float[] data) { return Packet.create(nativeCreateFloat32Array(mediapipeGraph.getNativeHandle(), data)); } @@ -449,6 +453,8 @@ public class PacketCreator { private native long nativeCreateInt32Array(long context, int[] data); + private native long nativeCreateInt32Pair(long context, int first, int second); + private native long nativeCreateFloat32Array(long context, float[] data); private native long nativeCreateFloat32Vector(long context, float[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 1c1daadcc..5ea12872a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -239,7 +239,7 @@ public final class PacketGetter { /** * Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer - * array has the the same size of image list packet, and assumes the output buffer stores pixels + * array has the same size of image list packet, and assumes the output buffer stores pixels * contiguously. It returns false if this assumption does not hold. * *

If deepCopy is true, it assumes the given buffersArray has allocated the required size of diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index d9508c1f7..a34e97954 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java index d6f50bf30..0211c808d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java index 988cdf542..0b11c6de0 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java index 6fbcac214..264668575 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java index 68c53b0c4..242404ad0 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java index a650e4c33..a8bd90d2a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. 
All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java index 82dbe32ca..a631a93de 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java index 946beae37..4622189a6 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java index f9f343e93..eb9f3ecb9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java index 674073b5b..7002b6f80 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java index 9783935d4..48b8c33c9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java index 6005ce77b..dff6481e9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java index 9e719715d..af3372fea 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java index 864c76df2..d9b85af70 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java index 76bb5a5ec..5fca757c5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 778790b1c..0a985f87c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -95,13 +95,14 @@ cc_library( "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:core_proto", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:singleton", "//mediapipe/framework/port:status", "//mediapipe/framework/port:threadpool", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/tool:executor_util", "//mediapipe/framework/tool:name_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -138,8 +139,8 @@ cc_library( hdrs = ["jni_util.h"], deps = [ ":class_registry", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/synchronization", ] + select({ "//conditions:default": [ @@ -173,8 +174,8 @@ cc_library( ":class_registry", ":loose_headers", ":mediapipe_framework_jni", - "//mediapipe/framework/port:logging", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ] + select({
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc index cda84ac16..a40112b2a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/android_packet_creator_jni.cc @@ -19,11 +19,11 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" @@ -49,26 +49,26 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromBitmap( void* pixel_addr = nullptr; int result = AndroidBitmap_lockPixels(env, bitmap, &pixel_addr); if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_lockPixels() failed with result code " - << result; + ABSL_LOG(ERROR) << "AndroidBitmap_lockPixels() failed with result code " + << result; return nullptr; } if (format == mediapipe::ImageFormat::SRGBA) { const int64_t buffer_size = stride * height; if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "Bitmap stride: " << stride - << " times bitmap height: " << height - << " is not equal to the expected size: " - << image_frame->PixelDataSize(); + ABSL_LOG(ERROR) << "Bitmap stride: " << stride + << " times bitmap height: " << height + << " is not equal to the expected size: " + << image_frame->PixelDataSize(); return nullptr; } std::memcpy(image_frame->MutablePixelData(), pixel_addr, image_frame->PixelDataSize()); } else if (format == mediapipe::ImageFormat::SRGB) { if (stride != width * 4) { - LOG(ERROR) << "Bitmap stride: " << stride - << "is not equal to 4 times bitmap width: " << width; + ABSL_LOG(ERROR) << "Bitmap stride: " << stride + << "is not equal to 4 times bitmap width: " << width; return nullptr; } const uint8_t* rgba_data = static_cast<const uint8_t*>(pixel_addr); @@ -76,14 +76,14 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromBitmap( image_frame->MutablePixelData(), image_frame->WidthStep()); } else { - LOG(ERROR) << "unsupported image format: " << format; + ABSL_LOG(ERROR) << "unsupported image format: " << format; return nullptr; } result = AndroidBitmap_unlockPixels(env, bitmap); if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_unlockPixels() failed with result code " - << result; + ABSL_LOG(ERROR) << "AndroidBitmap_unlockPixels() failed with result code " + << result; return nullptr; } @@ -98,7 +98,8 @@ JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD( AndroidBitmapInfo info; int result = AndroidBitmap_getInfo(env, bitmap, &info); if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " << result; + ABSL_LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " + << result; return 0L; } @@ -117,7 +118,8 @@ JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD( AndroidBitmapInfo info; int result = AndroidBitmap_getInfo(env, bitmap, &info); if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " << result; + ABSL_LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " + << result; return 0L; } @@ -135,7 +137,8 @@ JNIEXPORT jlong JNICALL ANDROID_PACKET_CREATOR_METHOD(nativeCreateRgbaImage)( AndroidBitmapInfo info; int result = AndroidBitmap_getInfo(env, bitmap, &info); if (result != ANDROID_BITMAP_RESULT_SUCCESS) { - LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " << result; + ABSL_LOG(ERROR) << "AndroidBitmap_getInfo() failed with result code " + << result; return 0L; }
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index d565187d9..f129b1a7c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -18,6 +18,7 @@ #include +#include "absl/log/absl_log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" @@ -75,7 +76,7 @@ class CallbackHandler { // The jobject global reference is managed by the Graph directly. // So no-op here. if (java_callback_) { - LOG(ERROR) << "Java callback global reference is not released."; + ABSL_LOG(ERROR) << "Java callback global reference is not released."; } } @@ -135,7 +136,8 @@ Graph::~Graph() { // Cleans up the jni objects.
JNIEnv* env = mediapipe::java::GetJNIEnv(); if (env == nullptr) { - LOG(ERROR) << "Can't attach to java thread, no jni clean up performed."; + ABSL_LOG(ERROR) + << "Can't attach to java thread, no jni clean up performed."; return; } for (const auto& handler : callback_handlers_) { @@ -219,12 +221,12 @@ absl::Status Graph::AddMultiStreamCallbackHandler( int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { if (!graph_config()) { - LOG(ERROR) << "Graph is not loaded!"; + ABSL_LOG(ERROR) << "Graph is not loaded!"; return 0; } #if MEDIAPIPE_DISABLE_GPU - LOG(FATAL) << "GPU support has been disabled in this build!"; + ABSL_LOG(FATAL) << "GPU support has been disabled in this build!"; #else CalculatorGraphConfig::Node* sink_node = graph_config()->add_node(); sink_node->set_name(mediapipe::tool::GetUnusedNodeName( @@ -291,7 +293,7 @@ CalculatorGraphConfig Graph::GetCalculatorGraphConfig() { CalculatorGraph temp_graph; absl::Status status = InitializeGraph(&temp_graph); if (!status.ok()) { - LOG(ERROR) << "GetCalculatorGraphConfig failed:\n" << status.message(); + ABSL_LOG(ERROR) << "GetCalculatorGraphConfig failed:\n" << status.message(); } return temp_graph.Config(); } @@ -416,13 +418,13 @@ absl::Status Graph::RunGraphUntilClose(JNIEnv* env) { CalculatorGraph calculator_graph; absl::Status status = InitializeGraph(&calculator_graph); if (!status.ok()) { - LOG(ERROR) << status.message(); + ABSL_LOG(ERROR) << status.message(); running_graph_.reset(nullptr); return status; } // TODO: gpu & services set up! status = calculator_graph.Run(CreateCombinedSidePackets()); - LOG(INFO) << "Graph run finished."; + ABSL_LOG(INFO) << "Graph run finished."; return status; } @@ -440,9 +442,9 @@ absl::Status Graph::StartRunningGraph(JNIEnv* env) { // Set the mode for adding packets to graph input streams. 
running_graph_->SetGraphInputStreamAddMode(graph_input_stream_add_mode_); if (VLOG_IS_ON(2)) { - LOG(INFO) << "input packet streams:"; + ABSL_LOG(INFO) << "input packet streams:"; for (auto& name : graph_config()->input_stream()) { - LOG(INFO) << name; + ABSL_LOG(INFO) << name; } } absl::Status status; @@ -450,7 +452,7 @@ absl::Status Graph::StartRunningGraph(JNIEnv* env) { if (gpu_resources_) { status = running_graph_->SetGpuResources(gpu_resources_); if (!status.ok()) { - LOG(ERROR) << status.message(); + ABSL_LOG(ERROR) << status.message(); running_graph_.reset(nullptr); return status; } @@ -461,7 +463,7 @@ absl::Status Graph::StartRunningGraph(JNIEnv* env) { status = running_graph_->SetServicePacket(*service_packet.first, service_packet.second); if (!status.ok()) { - LOG(ERROR) << status.message(); + ABSL_LOG(ERROR) << status.message(); running_graph_.reset(nullptr); return status; } @@ -469,15 +471,15 @@ status = InitializeGraph(running_graph_.get()); if (!status.ok()) { - LOG(ERROR) << status.message(); + ABSL_LOG(ERROR) << status.message(); running_graph_.reset(nullptr); return status; } - LOG(INFO) << "Start running the graph, waiting for inputs."; + ABSL_LOG(INFO) << "Start running the graph, waiting for inputs."; status = running_graph_->StartRun(CreateCombinedSidePackets(), stream_headers_); if (!status.ok()) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status; running_graph_.reset(nullptr); return status; } @@ -520,12 +522,12 @@ absl::Status Graph::CloseInputStream(std::string stream_name) { if (!running_graph_) { return absl::FailedPreconditionError("Graph must be running."); } - LOG(INFO) << "Close input stream: " << stream_name; + ABSL_LOG(INFO) << "Close input stream: " << stream_name; return running_graph_->CloseInputStream(stream_name); } absl::Status Graph::CloseAllInputStreams() { - LOG(INFO) << "Close all input streams."; + ABSL_LOG(INFO) << "Close all input streams."; if (!running_graph_) { return absl::FailedPreconditionError("Graph must be running."); } @@ -533,7 +535,7 @@ absl::Status Graph::CloseAllInputStreams() { } absl::Status Graph::CloseAllPacketSources() { - LOG(INFO) << "Close all input streams."; + ABSL_LOG(INFO) << "Close all input streams."; if (!running_graph_) { return absl::FailedPreconditionError("Graph must be running."); } @@ -564,7 +566,7 @@ void Graph::SetInputSidePacket(const std::string& stream_name, void Graph::SetStreamHeader(const std::string& stream_name, const Packet& packet) { stream_headers_[stream_name] = packet; - LOG(INFO) << stream_name << " stream header being set."; + ABSL_LOG(INFO) << stream_name << " stream header being set."; } void Graph::SetGraphInputStreamAddMode( @@ -580,7 +582,7 @@ mediapipe::GpuResources* Graph::GetGpuResources() const { absl::Status Graph::SetParentGlContext(int64_t java_gl_context) { #if MEDIAPIPE_DISABLE_GPU - LOG(FATAL) << "GPU support has been disabled in this build!"; + ABSL_LOG(FATAL) << "GPU support has been disabled in this build!"; #else if (gpu_resources_) { return absl::AlreadyExistsError(
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index dd99cccd4..a658d01cc 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -14,6 +14,8 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h" +#include "absl/log/absl_log.h" +#include "absl/strings/str_format.h" #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_buffer.h" @@ -89,9 +91,20 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)( JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) { + if (!consumerSyncToken) return; + GlTextureBufferSharedPtr* buffer = reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle); mediapipe::GlSyncToken& token = *reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken); + // The check below attempts to detect when an invalid or already-deleted + // `consumerSyncToken` is passed. (That results in undefined behavior; + // `DidRead` may still succeed, causing a later crash that masks the + // actual problem.) + if (token.use_count() == 0) { + ABSL_LOG_FIRST_N(ERROR, 5) + << absl::StrFormat("invalid sync token ref: %d", consumerSyncToken); + return; + } (*buffer)->DidRead(token); }
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc index 88a1366b9..6ccf8d7e9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc @@ -16,8 +16,8 @@ #include +#include "absl/log/absl_log.h" #include "absl/synchronization/mutex.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h" namespace { @@ -38,7 +38,7 @@ class JvmThread { case JNI_OK: break; case JNI_EDETACHED: - LOG(INFO) << "GetEnv: not attached"; + ABSL_LOG(INFO) << "GetEnv: not attached"; if (jvm_->AttachCurrentThread( #ifdef __ANDROID__ &jni_env_, #else reinterpret_cast<void**>(&jni_env_), #endif // __ANDROID__ nullptr) != 0) { - LOG(ERROR) << "Failed to attach to java thread."; + ABSL_LOG(ERROR) << "Failed to attach to java thread."; break; } attached_ = true; break; case JNI_EVERSION: - LOG(ERROR) << "GetEnv: jni version not supported."; + ABSL_LOG(ERROR) << "GetEnv: jni version not supported."; break; default: - LOG(ERROR) << "GetEnv: unknown status."; + ABSL_LOG(ERROR) << "GetEnv: unknown status."; break; } } @@ -83,7 +83,7 @@ static pthread_once_t key_once = PTHREAD_ONCE_INIT; static void ThreadExitCallback(void* key_value) { JvmThread* jvm_thread = reinterpret_cast<JvmThread*>(key_value); // Detach the thread when thread exits. - LOG(INFO) << "Exiting thread. Detach thread."; + ABSL_LOG(INFO) << "Exiting thread. Detach thread."; delete jvm_thread; } @@ -187,7 +187,7 @@ bool SetJavaVM(JNIEnv* env) { absl::MutexLock lock(&g_jvm_mutex); if (!g_jvm) { if (env->GetJavaVM(&g_jvm) != JNI_OK) { - LOG(ERROR) << "Can not get the Java VM instance!"; + ABSL_LOG(ERROR) << "Can not get the Java VM instance!"; g_jvm = nullptr; return false; } }
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index f7430e6e8..56ddd5e09 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -27,6 +28,7 @@ #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/time_series_header.pb.h" #include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" @@ -481,6 +483,15 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Array)( return CreatePacketWithContext(context, packet); } +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Pair)( + JNIEnv* env, jobject thiz, jlong context, jint first, jint second) { + static_assert(std::is_same<jint, int32_t>::value, "jint must be int32_t"); + + mediapipe::Packet packet = mediapipe::MakePacket<std::pair<int32_t, int32_t>>( + std::make_pair(first, second)); + return CreatePacketWithContext(context, packet); +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateStringFromByteArray)( JNIEnv* env, jobject thiz, jlong context, jbyteArray data) { jsize count = env->GetArrayLength(data);
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index b3b1043fb..92f48261c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -118,6 +118,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloat32Vector)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Array)( JNIEnv* env, jobject thiz, jlong context, jintArray data); +JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateInt32Pair)( + JNIEnv* env, jobject thiz, jlong context, jint first, jint second); + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateStringFromByteArray)( JNIEnv* env, jobject thiz, jlong context, jbyteArray data);
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index bef275b40..3f96a404d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -14,8 +14,8 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/register_natives.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_format.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/class_registry.h" #if defined(__ANDROID__) @@ -65,9 +65,10 @@ void RegisterNativesVector(JNIEnv *env, jclass cls, // in exchange for flexibility to list out all registrations without worrying // about usage subset by client
Java projects. if (!cls || methods.empty()) { - LOG(INFO) << "Skipping registration and clearing exception. Class or " - "native methods not found, may be unused and/or trimmed by " - "Proguard."; + ABSL_LOG(INFO) + << "Skipping registration and clearing exception. Class or " + "native methods not found, may be unused and/or trimmed by " + "Proguard."; env->ExceptionClear(); return; } @@ -81,7 +82,7 @@ void RegisterNativesVector(JNIEnv *env, jclass cls, } // Fatal crash if registration fails. if (env->RegisterNatives(cls, methods_array, methods.size()) < 0) { - LOG(FATAL) + ABSL_LOG(FATAL) << "Failed during native method registration, so likely the " "signature of a method is incorrect. Make sure there are no typos " "and " diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc index 51d693b20..2ac43e57e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc @@ -17,6 +17,8 @@ #include #endif // __ANDROID__ +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/egl_surface_holder.h" @@ -51,7 +53,7 @@ JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetSurface)( JNIEnv* env, jobject thiz, jlong context, jlong packet, jobject surface) { #ifdef __ANDROID__ mediapipe::GlContext* gl_context = GetGlContext(context); - CHECK(gl_context) << "GPU shared data not created"; + ABSL_CHECK(gl_context) << "GPU shared data not created"; mediapipe::EglSurfaceHolder* surface_holder = GetSurfaceHolder(packet); // ANativeWindow_fromSurface must not be called on the GL thread, it is a @@ -99,14 +101,14 @@ JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetSurface)( ANativeWindow_release(window); } #else - LOG(FATAL) << "setSurface is only supported on Android"; + ABSL_LOG(FATAL) << "setSurface is only supported on Android"; #endif // __ANDROID__ } JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetEglSurface)( JNIEnv* env, jobject thiz, jlong context, jlong packet, jlong surface) { mediapipe::GlContext* gl_context = GetGlContext(context); - CHECK(gl_context) << "GPU shared data not created"; + ABSL_CHECK(gl_context) << "GPU shared data not created"; auto egl_surface = reinterpret_cast(surface); mediapipe::EglSurfaceHolder* surface_holder = GetSurfaceHolder(packet); EGLSurface old_surface = EGL_NO_SURFACE; diff --git a/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java b/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java index 4dd35f865..381864484 100644 --- a/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java +++ b/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java @@ -67,6 +67,7 @@ public class ExternalTextureRenderer { private float[] textureTransformMatrix = new float[16]; private boolean flipY; private int rotation = Surface.ROTATION_0; + private boolean doExplicitCpuSync = true; /** Call this to setup the shader program before rendering. */ public void setup() { @@ -101,6 +102,14 @@ public class ExternalTextureRenderer { this.rotation = rotation; } + /** + * Configures whether the renderer should do an explicit CPU synchronization using glFinish upon + * each {@link #render} call. Defaults to true. 
+ */ + public void setDoExplicitCpuSync(boolean doExplicitCpuSync) { + this.doExplicitCpuSync = doExplicitCpuSync; + } + /** * Renders the surfaceTexture to the framebuffer with optional vertical flip. * @@ -150,8 +159,11 @@ public class ExternalTextureRenderer { GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0); ShaderUtil.checkGlError("glBindTexture"); - // TODO: add sync and go back to glFlush() - GLES20.glFinish(); + if (doExplicitCpuSync) { + + // TODO: add sync and go back to glFlush() + GLES20.glFinish(); + } } /** diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 03fc41757..8817f2835 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -1,4 +1,4 @@ -# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2019-2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -197,7 +197,6 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []): name = name + "_opencv_cc_lib", srcs = select({ "//mediapipe:android_arm64": ["@android_opencv//:libopencv_java3_so_arm64-v8a"], - "//mediapipe:android_armeabi": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"], "//mediapipe:android_arm": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"], "//mediapipe:android_x86": ["@android_opencv//:libopencv_java3_so_x86"], "//mediapipe:android_x86_64": ["@android_opencv//:libopencv_java3_so_x86_64"], diff --git a/mediapipe/model_maker/BUILD b/mediapipe/model_maker/BUILD index cb312072f..e3995e134 100644 --- a/mediapipe/model_maker/BUILD +++ b/mediapipe/model_maker/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py index 533edebf7..8c87c12df 100644 --- a/mediapipe/model_maker/__init__.py +++ b/mediapipe/model_maker/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +13,15 @@ # limitations under the License. +from mediapipe.model_maker.python.vision.core import image_utils from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.core.utils import model_util + from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.text import text_classifier from mediapipe.model_maker.python.vision import object_detector +from mediapipe.model_maker.python.vision import face_stylizer # Remove duplicated and non-public API del python diff --git a/mediapipe/model_maker/models/gesture_recognizer/BUILD b/mediapipe/model_maker/models/gesture_recognizer/BUILD index 5ead0e618..c57d7a2c9 100644 --- a/mediapipe/model_maker/models/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/models/gesture_recognizer/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +19,7 @@ load( licenses(["notice"]) -package( - default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"], -) +package(default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"]) mediapipe_files( srcs = [ diff --git a/mediapipe/model_maker/models/text_classifier/BUILD b/mediapipe/model_maker/models/text_classifier/BUILD index dc6210a7d..460d6cfd1 100644 --- a/mediapipe/model_maker/models/text_classifier/BUILD +++ b/mediapipe/model_maker/models/text_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +19,7 @@ load( licenses(["notice"]) -package( - default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"], -) +package(default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"]) mediapipe_files( srcs = [ diff --git a/mediapipe/model_maker/python/BUILD b/mediapipe/model_maker/python/BUILD index fe101f293..42681fadb 100644 --- a/mediapipe/model_maker/python/BUILD +++ b/mediapipe/model_maker/python/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ package_group( package_group( name = "1p_client", packages = [ + "//cloud/ml/applications/vision/model_garden/model_oss/mediapipe/...", "//research/privacy/learning/fl_eval/pcvr/...", ], ) diff --git a/mediapipe/model_maker/python/__init__.py b/mediapipe/model_maker/python/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/__init__.py +++ b/mediapipe/model_maker/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/BUILD b/mediapipe/model_maker/python/core/BUILD index 636a1a720..a73e545d3 100644 --- a/mediapipe/model_maker/python/core/BUILD +++ b/mediapipe/model_maker/python/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,10 @@ # Placeholder for internal Python strict library and test compatibility macro. -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = [ + "//cloud/ml/applications/vision/model_garden/model_oss/mediapipe:__subpackages__", + "//mediapipe:__subpackages__", +]) licenses(["notice"]) diff --git a/mediapipe/model_maker/python/core/__init__.py b/mediapipe/model_maker/python/core/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/core/__init__.py +++ b/mediapipe/model_maker/python/core/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD index 70a62e8f7..4364b7744 100644 --- a/mediapipe/model_maker/python/core/data/BUILD +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,9 +17,7 @@ licenses(["notice"]) -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) py_library( name = "data_util", @@ -59,3 +57,14 @@ py_test( srcs = ["classification_dataset_test.py"], deps = [":classification_dataset"], ) + +py_library( + name = "cache_files", + srcs = ["cache_files.py"], +) + +py_test( + name = "cache_files_test", + srcs = ["cache_files_test.py"], + deps = [":cache_files"], +) diff --git a/mediapipe/model_maker/python/core/data/__init__.py b/mediapipe/model_maker/python/core/data/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/core/data/__init__.py +++ b/mediapipe/model_maker/python/core/data/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/data/cache_files.py b/mediapipe/model_maker/python/core/data/cache_files.py new file mode 100644 index 000000000..13d3d5b61 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/cache_files.py @@ -0,0 +1,112 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common TFRecord cache files library.""" + +import dataclasses +import os +import tempfile +from typing import Any, Mapping, Sequence + +import tensorflow as tf +import yaml + + +# Suffix of the metadata file name. +METADATA_FILE_SUFFIX = '_metadata.yaml' + + +@dataclasses.dataclass(frozen=True) +class TFRecordCacheFiles: + """TFRecordCacheFiles dataclass to store and load cached TFRecord files. + + Attributes: + cache_prefix_filename: The cache prefix filename. This is usually provided + as a hash of the original data source to avoid different data sources + resulting in the same cache file. + cache_dir: The cache directory to save the TFRecord and metadata files. + When cache_dir is None, a temporary folder will be created and will not + be removed automatically after training, so that it can be reused later. + num_shards: Number of shards for output tfrecord files.
+ """ + + cache_prefix_filename: str = 'cache_prefix' + cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp) + num_shards: int = 1 + + def __post_init__(self): + if not tf.io.gfile.exists(self.cache_dir): + tf.io.gfile.makedirs(self.cache_dir) + if not self.cache_prefix_filename: + raise ValueError('cache_prefix_filename cannot be empty.') + if self.num_shards <= 0: + raise ValueError( + f'num_shards must be greater than 0, got {self.num_shards}' + ) + + @property + def cache_prefix(self) -> str: + """The cache prefix including the cache directory and the cache prefix filename.""" + return os.path.join(self.cache_dir, self.cache_prefix_filename) + + @property + def tfrecord_files(self) -> Sequence[str]: + """The TFRecord files.""" + tfrecord_files = [ + self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards) + for i in range(self.num_shards) + ] + return tfrecord_files + + @property + def metadata_file(self) -> str: + """The metadata file.""" + return self.cache_prefix + METADATA_FILE_SUFFIX + + def get_writers(self) -> Sequence[tf.io.TFRecordWriter]: + """Gets an array of TFRecordWriter objects. + + Note that these writers should each be closed using .close() when done. + + Returns: + Array of TFRecordWriter objects + """ + return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files] + + def save_metadata(self, metadata): + """Writes metadata to file. + + Args: + metadata: A dictionary of metadata content to write. Exact format is + dependent on the specific dataset, but typically includes a 'size' and + 'label_names' entry. + """ + with tf.io.gfile.GFile(self.metadata_file, 'w') as f: + yaml.dump(metadata, f) + + def load_metadata(self) -> Mapping[Any, Any]: + """Reads metadata from file. + + Returns: + Dictionary object containing metadata + """ + if not tf.io.gfile.exists(self.metadata_file): + return {} + with tf.io.gfile.GFile(self.metadata_file, 'r') as f: + metadata = yaml.load(f, Loader=yaml.FullLoader) + return metadata + + def is_cached(self) -> bool: + """Checks whether this CacheFiles is already cached.""" + all_cached_files = list(self.tfrecord_files) + [self.metadata_file] + return all(tf.io.gfile.exists(f) for f in all_cached_files) diff --git a/mediapipe/model_maker/python/core/data/cache_files_test.py b/mediapipe/model_maker/python/core/data/cache_files_test.py new file mode 100644 index 000000000..ac727b3fe --- /dev/null +++ b/mediapipe/model_maker/python/core/data/cache_files_test.py @@ -0,0 +1,77 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
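A minimal usage sketch of the TFRecordCacheFiles workflow above (not part of the patch; the prefix, directory, and metadata values are illustrative). It mirrors the sequence that the test below exercises: probe the cache, write the shard files, then complete the cache by saving metadata.

    import tensorflow as tf
    from mediapipe.model_maker.python.core.data import cache_files

    cf = cache_files.TFRecordCacheFiles(
        cache_prefix_filename='dataset_fingerprint',  # illustrative hash name
        cache_dir='/tmp/mm_cache',
        num_shards=2,
    )
    if not cf.is_cached():
      for writer in cf.get_writers():
        # A real loader would serialize one tf.train.Example per record here.
        writer.write(tf.train.Example().SerializeToString())
        writer.close()
      # is_cached() only reports True once the metadata file exists as well.
      cf.save_metadata({'size': 1, 'label_names': ['a', 'b']})
    metadata = cf.load_metadata()

Note that is_cached() checks both the shard files and the metadata file, which is why save_metadata() is the step that completes the cache.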
+ +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import cache_files + + +class CacheFilesTest(tf.test.TestCase): + + def test_tfrecord_cache_files(self): + cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='tfrecord', + cache_dir='/tmp/cache_dir', + num_shards=2, + ) + self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord') + self.assertEqual( + cf.metadata_file, + '/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX, + ) + expected_tfrecord_files = [ + '/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2) + for i in range(2) + ] + self.assertEqual(cf.tfrecord_files, expected_tfrecord_files) + + # Writing TFRecord Files + self.assertFalse(cf.is_cached()) + for tfrecord_file in cf.tfrecord_files: + self.assertFalse(tf.io.gfile.exists(tfrecord_file)) + writers = cf.get_writers() + for writer in writers: + writer.close() + for tfrecord_file in cf.tfrecord_files: + self.assertTrue(tf.io.gfile.exists(tfrecord_file)) + self.assertFalse(cf.is_cached()) + + # Writing Metadata Files + original_metadata = {'size': 10, 'label_names': ['label1', 'label2']} + cf.save_metadata(original_metadata) + self.assertTrue(cf.is_cached()) + metadata = cf.load_metadata() + self.assertEqual(metadata, original_metadata) + + def test_recordio_cache_files_error(self): + with self.assertRaisesRegex( + ValueError, 'cache_prefix_filename cannot be empty' + ): + cache_files.TFRecordCacheFiles( + cache_prefix_filename='', + cache_dir='/tmp/cache_dir', + num_shards=2, + ) + with self.assertRaisesRegex( + ValueError, 'num_shards must be greater than 0, got 0' + ): + cache_files.TFRecordCacheFiles( + cache_prefix_filename='tfrecord', + cache_dir='/tmp/cache_dir', + num_shards=0, + ) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/classification_dataset.py b/mediapipe/model_maker/python/core/data/classification_dataset.py index 073e79638..352caca6f 100644 --- a/mediapipe/model_maker/python/core/data/classification_dataset.py +++ b/mediapipe/model_maker/python/core/data/classification_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. """Common classification dataset library.""" -from typing import List, Tuple +from typing import List, Optional, Tuple import tensorflow as tf @@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds class ClassificationDataset(ds.Dataset): """Dataset Loader for classification models.""" - def __init__(self, dataset: tf.data.Dataset, size: int, - label_names: List[str]): + def __init__( + self, + dataset: tf.data.Dataset, + label_names: List[str], + size: Optional[int] = None, + ): super().__init__(dataset, size) self._label_names = label_names diff --git a/mediapipe/model_maker/python/core/data/classification_dataset_test.py b/mediapipe/model_maker/python/core/data/classification_dataset_test.py index 82e74b04e..dfcea7da6 100644 --- a/mediapipe/model_maker/python/core/data/classification_dataset_test.py +++ b/mediapipe/model_maker/python/core/data/classification_dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase): value: A value variable stored by the mock dataset class for testing. """ - def __init__(self, dataset: tf.data.Dataset, size: int, - label_names: List[str], value: Any): - super().__init__(dataset=dataset, size=size, label_names=label_names) + def __init__( + self, + dataset: tf.data.Dataset, + label_names: List[str], + value: Any, + size: int, + ): + super().__init__(dataset=dataset, label_names=label_names, size=size) self.value = value def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]: @@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase): # Create data loader from sample data. ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) data = MagicClassificationDataset( - dataset=ds, size=len(ds), label_names=label_names, value=magic_value) + dataset=ds, label_names=label_names, value=magic_value, size=len(ds) + ) # Train/Test data split. fraction = .25 diff --git a/mediapipe/model_maker/python/core/data/data_util.py b/mediapipe/model_maker/python/core/data/data_util.py index 8c6b9145f..88efa896c 100644 --- a/mediapipe/model_maker/python/core/data/data_util.py +++ b/mediapipe/model_maker/python/core/data/data_util.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/data/data_util_test.py b/mediapipe/model_maker/python/core/data/data_util_test.py index 56ac832c3..8bed8ef7c 100644 --- a/mediapipe/model_maker/python/core/data/data_util_test.py +++ b/mediapipe/model_maker/python/core/data/data_util_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py index 3b4182c14..0cfccb149 100644 --- a/mediapipe/model_maker/python/core/data/dataset.py +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,15 +56,14 @@ class Dataset(object): def size(self) -> Optional[int]: """Returns the size of the dataset. - Note that this function may return None becuase the exact size of the - dataset isn't a necessary parameter to create an instance of this class, - and tf.data.Dataset donesn't support a function to get the length directly - since it's lazy-loaded and may be infinite. - In most cases, however, when an instance of this class is created by helper - functions like 'from_folder', the size of the dataset will be preprocessed, - and this function can return an int representing the size of the dataset. + Same functionality as calling __len__. See the __len__ method definition for + more information. + + Raises: + TypeError if self._size is not set and the cardinality of self._dataset + is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY. 
""" - return self._size + return self.__len__() def gen_tf_dataset( self, @@ -116,8 +115,22 @@ class Dataset(object): # here. return dataset - def __len__(self): - """Returns the number of element of the dataset.""" + def __len__(self) -> int: + """Returns the number of element of the dataset. + + If size is not set, this method will fallback to using the __len__ method + of the tf.data.Dataset in self._dataset. Calling __len__ on a + tf.data.Dataset instance may throw a TypeError because the dataset may + be lazy-loaded with an unknown size or have infinite size. + + In most cases, however, when an instance of this class is created by helper + functions like 'from_folder', the size of the dataset will be preprocessed, + and the _size instance variable will be already set. + + Raises: + TypeError if self._size is not set and the cardinality of self._dataset + is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY. + """ if self._size is not None: return self._size else: @@ -152,15 +165,25 @@ class Dataset(object): Returns: The splitted two sub datasets. + + Raises: + ValueError: if the provided fraction is not between 0 and 1. + ValueError: if this dataset does not have a set size. """ - assert (fraction > 0 and fraction < 1) + if not (fraction > 0 and fraction < 1): + raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}') + if not self._size: + raise ValueError( + 'Dataset size unknown. Cannot split the dataset when ' + 'the size is unknown.' + ) dataset = self._dataset train_size = int(self._size * fraction) - trainset = self.__class__(dataset.take(train_size), train_size, *args) + trainset = self.__class__(dataset.take(train_size), *args, size=train_size) test_size = self._size - train_size - testset = self.__class__(dataset.skip(train_size), test_size, *args) + testset = self.__class__(dataset.skip(train_size), *args, size=test_size) return trainset, testset diff --git a/mediapipe/model_maker/python/core/data/dataset_test.py b/mediapipe/model_maker/python/core/data/dataset_test.py index 9adff127d..7a3f75388 100644 --- a/mediapipe/model_maker/python/core/data/dataset_test.py +++ b/mediapipe/model_maker/python/core/data/dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/data/testdata/BUILD b/mediapipe/model_maker/python/core/data/testdata/BUILD index 54e562d41..b799c3cee 100644 --- a/mediapipe/model_maker/python/core/data/testdata/BUILD +++ b/mediapipe/model_maker/python/core/data/testdata/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py index e6848e0de..92e1856cc 100644 --- a/mediapipe/model_maker/python/core/hyperparameters.py +++ b/mediapipe/model_maker/python/core/hyperparameters.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,8 +15,11 @@ import dataclasses import tempfile +from typing import Mapping, Optional -from typing import Optional +import tensorflow as tf + +from official.common import distribute_utils @dataclasses.dataclass @@ -33,6 +36,8 @@ class BaseHParams: steps_per_epoch: An optional integer indicate the number of training steps per epoch. If not set, the training pipeline calculates the default steps per epoch as the training dataset size divided by batch size. + class_weights: An optional mapping of indices to weights for weighting the + loss function during training. shuffle: True if the dataset is shuffled before training. export_dir: The location of the model checkpoint files. distribution_strategy: A string specifying which Distribution Strategy to @@ -43,10 +48,10 @@ class BaseHParams: documentation for more details: https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy. num_gpus: How many GPUs to use at each worker with the - DistributionStrategies API. The default is -1, which means utilize all - available GPUs. - tpu: The Cloud TPU to use for training. This should be either the name used - when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url. + DistributionStrategies API. The default is 0. + tpu: The TPU resource to be used for training. This should be either the + name used when creating the Cloud TPU, a grpc://ip.address.of.tpu:8470 + url, or an empty string if using a local TPU. """ # Parameters for train configuration @@ -54,6 +59,7 @@ class BaseHParams: batch_size: int epochs: int steps_per_epoch: Optional[int] = None + class_weights: Optional[Mapping[int, float]] = None # Dataset-related parameters shuffle: bool = False @@ -63,5 +69,16 @@ class BaseHParams: # Parameters for hardware acceleration distribution_strategy: str = 'off' - num_gpus: int = -1 # default value of -1 means use all available GPUs + num_gpus: int = 0 tpu: str = '' + _strategy: tf.distribute.Strategy = dataclasses.field(init=False) + + def __post_init__(self): + self._strategy = distribute_utils.get_distribution_strategy( + distribution_strategy=self.distribution_strategy, + num_gpus=self.num_gpus, + tpu_address=self.tpu, + ) + + def get_strategy(self): + return self._strategy diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD index 8c5448556..818d78feb 100644 --- a/mediapipe/model_maker/python/core/tasks/BUILD +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,9 +15,7 @@ # Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict test compatibility macro. -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/model_maker/python/core/tasks/__init__.py b/mediapipe/model_maker/python/core/tasks/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/core/tasks/__init__.py +++ b/mediapipe/model_maker/python/core/tasks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index bfe0f027f..d504defbe 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel): self._model: tf.keras.Model = None self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None self._loss_function: Union[str, tf.keras.losses.Loss] = None - self._metric_function: Union[str, tf.keras.metrics.Metric] = None + self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None self._callbacks: Sequence[tf.keras.callbacks.Callback] = None self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None @@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel): self._model.compile( optimizer=self._optimizer, loss=self._loss_function, - metrics=[self._metric_function]) + metrics=self._metric_functions, + ) latest_checkpoint = ( tf.train.latest_checkpoint(checkpoint_path) @@ -109,7 +110,9 @@ class Classifier(custom_model.CustomModel): # dataset is exhausted even if there are epochs remaining. steps_per_epoch=None, validation_data=validation_dataset, - callbacks=self._callbacks) + callbacks=self._callbacks, + class_weight=self._hparams.class_weights, + ) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: """Evaluates the classifier with the provided evaluation dataset. diff --git a/mediapipe/model_maker/python/core/tasks/classifier_test.py b/mediapipe/model_maker/python/core/tasks/classifier_test.py index 6bf3b7a2e..2943825ac 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier_test.py +++ b/mediapipe/model_maker/python/core/tasks/classifier_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/tasks/custom_model.py b/mediapipe/model_maker/python/core/tasks/custom_model.py index 188bf62cc..55f5a6db3 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/tasks/custom_model_test.py b/mediapipe/model_maker/python/core/tasks/custom_model_test.py index ad77d4ecd..afb418c44 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model_test.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. 
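The classifier.py changes above dovetail with the new BaseHParams fields from hyperparameters.py: the model is now compiled with a list of metric functions, and model.fit() receives class_weight from the hparams. A hedged sketch of the new hparams surface (field values are illustrative; note that tf-models' distribute_utils typically resolves the default 'off' strategy to None):

    from mediapipe.model_maker.python.core import hyperparameters as hp

    params = hp.BaseHParams(
        learning_rate=1e-3,
        batch_size=2,
        epochs=10,
        class_weights={0: 1.0, 1: 5.0},  # up-weight a rare class in the loss
        export_dir='/tmp/exported_model',
    )
    # The tf.distribute strategy is built once in __post_init__ and cached.
    strategy = params.get_strategy()

Training code can then hand params.class_weights straight to tf.keras Model.fit(), which is exactly what Classifier._train_model now does.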
diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 43c3d42f9..c5e031245 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,8 +17,13 @@ licenses(["notice"]) -package( - default_visibility = ["//mediapipe:__subpackages__"], +package(default_visibility = ["//mediapipe:__subpackages__"]) + +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), ) py_library( @@ -58,27 +63,69 @@ py_library( py_test( name = "file_util_test", srcs = ["file_util_test.py"], - data = ["//mediapipe/model_maker/python/core/utils/testdata"], + data = [":testdata"], tags = ["requires-net:external"], deps = [":file_util"], ) +py_library( + name = "hub_loader", + srcs = ["hub_loader.py"], +) + +py_test( + name = "hub_loader_test", + srcs = ["hub_loader_test.py"], + data = [":testdata"], + deps = [ + ":hub_loader", + "//mediapipe/tasks/python/test:test_utils", + ], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], srcs_version = "PY3", + deps = [ + ":file_util", + ":model_util", + ], ) py_test( name = "loss_functions_test", srcs = ["loss_functions_test.py"], + tags = [ + "requires-net:external", + ], deps = [":loss_functions"], ) +###################################################################### +# Public target of the MediaPipe Model Maker Quantization Config. + +# Quantization Config is used to export a quantized model. Please refer +# to the specific task documentations such as: +# https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize +# for usage information. +###################################################################### +py_library( + name = "metrics", + srcs = ["metrics.py"], +) + +py_test( + name = "metrics_test", + srcs = ["metrics_test.py"], + deps = [":metrics"], +) + py_library( name = "quantization", srcs = ["quantization.py"], srcs_version = "PY3", + visibility = ["//visibility:public"], deps = ["//mediapipe/model_maker/python/core/data:dataset"], ) diff --git a/mediapipe/model_maker/python/core/utils/__init__.py b/mediapipe/model_maker/python/core/utils/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/core/utils/__init__.py +++ b/mediapipe/model_maker/python/core/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/utils/file_util.py b/mediapipe/model_maker/python/core/utils/file_util.py index 221df94fd..71b5a0a7b 100644 --- a/mediapipe/model_maker/python/core/utils/file_util.py +++ b/mediapipe/model_maker/python/core/utils/file_util.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
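Among the BUILD additions above is a new hub_loader target; its HubKerasLayerV1V2 class (added later in this patch) extends hub.KerasLayer so that TF1 hub modules and TF2 SavedModels can both be loaded and fine-tuned through one code path. A hedged usage sketch mirroring the accompanying test; the module path is an illustrative placeholder:

    import tensorflow as tf
    from mediapipe.model_maker.python.core.utils import hub_loader

    # Accepts either a TF1 hub module or a TF2 SavedModel directory.
    layer = hub_loader.HubKerasLayerV1V2('/path/to/module', trainable=True)
    output = layer(tf.constant(10.0))  # the test's mini modules compute x + 1
    # With trainable=True, a TF1 module's variables are re-exposed as
    # trainable Keras weights, which hub.KerasLayer alone would not do.
    print(len(layer.trainable_variables))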
diff --git a/mediapipe/model_maker/python/core/utils/file_util_test.py b/mediapipe/model_maker/python/core/utils/file_util_test.py index 027756ff0..d5b983929 100644 --- a/mediapipe/model_maker/python/core/utils/file_util_test.py +++ b/mediapipe/model_maker/python/core/utils/file_util_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/utils/hub_loader.py b/mediapipe/model_maker/python/core/utils/hub_loader.py new file mode 100644 index 000000000..a52099884 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/hub_loader.py @@ -0,0 +1,97 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Handles both V1 and V2 modules.""" + +import tensorflow_hub as hub + + +class HubKerasLayerV1V2(hub.KerasLayer): + """Class that loads TF v1 and TF v2 hub modules so they can be fine-tuned. + + TF v1 modules cannot be retrained directly in hub.KerasLayer, so this class + provides a workaround for retraining the whole tf1 model in tf2. In + particular, it extracts self._func._self_unconditional_checkpoint_dependencies + into trainable variables for tf1 modules. + + Doesn't update moving-mean/moving-variance for BatchNormalization during + fine-tuning. + """ + + def _setup_layer(self, trainable=False, **kwargs): + if self._is_hub_module_v1: + self._setup_layer_v1(trainable, **kwargs) + else: + # call _setup_layer from the base class for v2. + super(HubKerasLayerV1V2, self)._setup_layer(trainable, **kwargs) + + def _check_trainability(self): + if self._is_hub_module_v1: + self._check_trainability_v1() + else: + # call _check_trainability from the base class for v2. + super(HubKerasLayerV1V2, self)._check_trainability() + + def _setup_layer_v1(self, trainable=False, **kwargs): + """Constructs keras layer with relevant weights and losses.""" + # Initialize an empty layer, then add_weight() etc. as needed. + super(hub.KerasLayer, self).__init__(trainable=trainable, **kwargs) + + if not self._is_hub_module_v1: + raise ValueError( + 'Only v1 hub modules can be set up by this function.' + ) + + # v2 trainable_variable: + if hasattr(self._func, 'trainable_variables'): + for v in self._func.trainable_variables: + self._add_existing_weight(v, trainable=True) + trainable_variables = {id(v) for v in self._func.trainable_variables} + else: + trainable_variables = set() + + if not hasattr(self._func, '_self_unconditional_checkpoint_dependencies'): + raise ValueError( + "_func doesn't contain attribute " + '_self_unconditional_checkpoint_dependencies.' + ) + dependencies = self._func._self_unconditional_checkpoint_dependencies # pylint: disable=protected-access + + # Adds trainable variables.
+ for dep in dependencies: + if dep.name == 'variables': + for v in dep.ref: + if id(v) not in trainable_variables: + self._add_existing_weight(v, trainable=True) + trainable_variables.add(id(v)) + + # Adds non-trainable variables. + if hasattr(self._func, 'variables'): + for v in self._func.variables: + if id(v) not in trainable_variables: + self._add_existing_weight(v, trainable=False) + + # Forward the callable's regularization losses (if any). + if hasattr(self._func, 'regularization_losses'): + for l in self._func.regularization_losses: + if not callable(l): + raise ValueError( + 'hub.KerasLayer(obj) expects obj.regularization_losses to be an ' + 'iterable of callables, each returning a scalar loss term.' + ) + self.add_loss(self._call_loss_if_trainable(l)) # Supports callables. + + def _check_trainability_v1(self): + """Ignores trainability checks for V1.""" + if self._is_hub_module_v1: + return # Nothing to do. diff --git a/mediapipe/model_maker/python/core/utils/hub_loader_test.py b/mediapipe/model_maker/python/core/utils/hub_loader_test.py new file mode 100644 index 000000000..8ea15b5d1 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/hub_loader_test.py @@ -0,0 +1,59 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import hub_loader +from mediapipe.tasks.python.test import test_utils + + +class HubKerasLayerV1V2Test(tf.test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + ("hub_module_v1_mini", True), + ("saved_model_v2_mini", True), + ("hub_module_v1_mini", False), + ("saved_model_v2_mini", False), + ) + def test_load_with_defaults(self, module_name, trainable): + inputs, expected_outputs = 10.0, 11.0 # Test modules perform increment op. + path = test_utils.get_test_data_path(module_name) + layer = hub_loader.HubKerasLayerV1V2(path, trainable=trainable) + output = layer(inputs) + self.assertEqual(output, expected_outputs) + + def test_trainable_variable(self): + path = test_utils.get_test_data_path("hub_module_v1_mini_train") + layer = hub_loader.HubKerasLayerV1V2(path, trainable=True) + # Checks trainable variables. + self.assertLen(layer.trainable_variables, 2) + self.assertEqual(layer.trainable_variables[0].name, "a:0") + self.assertEqual(layer.trainable_variables[1].name, "b:0") + self.assertEqual(layer.variables, layer.trainable_variables) + # Checks non-trainable variables. + self.assertEmpty(layer.non_trainable_variables) + + layer = hub_loader.HubKerasLayerV1V2(path, trainable=False) + # Checks trainable variables. + self.assertEmpty(layer.trainable_variables) + # Checks non-trainable variables. 
+ self.assertLen(layer.non_trainable_variables, 2) + self.assertEqual(layer.non_trainable_variables[0].name, "a:0") + self.assertEqual(layer.non_trainable_variables[1].name, "b:0") + self.assertEqual(layer.variables, layer.non_trainable_variables) + + +if __name__ == "__main__": + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py index 5b0aa32bf..c741e4282 100644 --- a/mediapipe/model_maker/python/core/utils/loss_functions.py +++ b/mediapipe/model_maker/python/core/utils/loss_functions.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,21 @@ # limitations under the License. """Loss function utility library.""" -from typing import Optional, Sequence +import abc +from typing import Mapping, Sequence +import dataclasses +from typing import Any, Optional +import numpy as np import tensorflow as tf +from mediapipe.model_maker.python.core.utils import file_util +from mediapipe.model_maker.python.core.utils import model_util +from official.modeling import tf_utils + + +_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz' + class FocalLoss(tf.keras.losses.Loss): """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf). @@ -45,11 +56,10 @@ class FocalLoss(tf.keras.losses.Loss): ```python model.compile(optimizer='sgd', loss=FocalLoss(gamma)) ``` - """ def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None): - """Constructor. + """Initializes FocalLoss. Args: gamma: Focal loss gamma, as described in class docs. @@ -103,3 +113,297 @@ class FocalLoss(tf.keras.losses.Loss): # By default, this function uses "sum_over_batch_size" reduction for the # loss per batch. return tf.reduce_sum(losses) / batch_size + + +class SparseFocalLoss(FocalLoss): + """Sparse implementation of Focal Loss. + + This is the same as FocalLoss, except the labels are expected to be class ids + instead of 1-hot encoded vectors. See FocalLoss class documentation defined + in this same file for more details. + + Example usage: + >>> y_true = [1, 2] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> gamma = 2 + >>> focal_loss = SparseFocalLoss(gamma, 3) + >>> focal_loss(y_true, y_pred).numpy() + 0.9326 + + >>> # Calling with 'sample_weight'. + >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() + 0.6528 + """ + + def __init__( + self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None + ): + """Initializes SparseFocalLoss. + + Args: + gamma: Focal loss gamma, as described in class docs. + num_classes: Number of classes. + class_weight: A weight to apply to the loss, one for each class. The + weight is applied for each input where the ground truth label matches. + """ + super().__init__(gamma, class_weight=class_weight) + self._num_classes = num_classes + + def __call__( + self, + y_true: tf.Tensor, + y_pred: tf.Tensor, + sample_weight: Optional[tf.Tensor] = None, + ) -> tf.Tensor: + y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32) + y_true_one_hot = tf.one_hot(y_true, self._num_classes) + return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight) + + +@dataclasses.dataclass +class PerceptualLossWeight: + """The weight for each perceptual loss. 
+ + Attributes: + l1: weight for L1 loss. + content: weight for content loss. + style: weight for style loss. + """ + + l1: float = 1.0 + content: float = 1.0 + style: float = 1.0 + + +class ImagePerceptualQualityLoss(tf.keras.losses.Loss): + """Image perceptual quality loss. + + It computes a weighted sum of the VGGPerceptualLoss and the L1 loss. + """ + + def __init__( + self, + loss_weight: Optional[PerceptualLossWeight] = None, + reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE, + ): + """Initializes ImagePerceptualQualityLoss.""" + self._loss_weight = loss_weight + self._losses = {} + self._vgg_loss = VGGPerceptualLoss(self._loss_weight) + self._reduction = reduction + + def _l1_loss( + self, + reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE, + ) -> Any: + """Calculates L1 loss.""" + return tf.keras.losses.MeanAbsoluteError(reduction) + + def __call__( + self, + img1: tf.Tensor, + img2: tf.Tensor, + ) -> tf.Tensor: + """Computes image perceptual quality loss.""" + loss_value = [] + if self._loss_weight is None: + self._loss_weight = PerceptualLossWeight() + + if self._loss_weight.content > 0 or self._loss_weight.style > 0: + vgg_loss = self._vgg_loss(img1, img2) + vgg_loss_value = tf.math.add_n(vgg_loss.values()) + loss_value.append(vgg_loss_value) + if self._loss_weight.l1 > 0: + l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2) + l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1) + loss_value.append(l1_loss_value) + total_loss = tf.math.add_n(loss_value) + return total_loss + + +class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta): + """Base class for perceptual loss model.""" + + def __init__( + self, + feature_weight: Optional[Sequence[float]] = None, + loss_weight: Optional[PerceptualLossWeight] = None, + ): + """Instantiates perceptual loss. + + Args: + feature_weight: The weight coefficients of the multiple model-extracted + features used for calculating the perceptual loss. + loss_weight: The weight coefficients between `style_loss` and + `content_loss`. + """ + super().__init__() + self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y)) + self._loss_style = tf.constant(0.0) + self._loss_content = tf.constant(0.0) + self._feature_weight = feature_weight + self._loss_weight = loss_weight + + def __call__( + self, + img1: tf.Tensor, + img2: tf.Tensor, + ) -> Mapping[str, tf.Tensor]: + """Computes perceptual loss between two images. + + Args: + img1: First batch of images. The pixel values should be normalized to [-1, + 1]. + img2: Second batch of images. The pixel values should be normalized to + [-1, 1]. + + Returns: + A mapping between loss name and loss tensors. + """ + x_features = self._compute_features(img1) + y_features = self._compute_features(img2) + + if self._loss_weight is None: + self._loss_weight = PerceptualLossWeight() + + # If _feature_weight is not initialized, initialize it as a list with every + # element equal to 1.0. + if self._feature_weight is None: + self._feature_weight = [1.0] * len(x_features) + + # If _feature_weight is shorter than the feature list, raise a ValueError. + # Otherwise, only the first len(x_features) weights are used for computing + # the loss.
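+ # For example (illustrative values): with five extracted features, a + # feature_weight of [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0] is accepted and only + # its first five entries are consumed by zip() in the loss computations + # below, while a feature_weight of [1.0, 2.0] raises the ValueError.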
+ if len(self._feature_weight) < len(x_features): + raise ValueError( + f'Input feature weight length {len(self._feature_weight)} is smaller' + f' than feature length {len(x_features)}' + ) + + if self._loss_weight.style > 0.0: + self._loss_style = tf_utils.safe_mean( + self._loss_weight.style + * self._get_style_loss(x_feats=x_features, y_feats=y_features) + ) + if self._loss_weight.content > 0.0: + self._loss_content = tf_utils.safe_mean( + self._loss_weight.content + * self._get_content_loss(x_feats=x_features, y_feats=y_features) + ) + + return {'style_loss': self._loss_style, 'content_loss': self._loss_content} + + @abc.abstractmethod + def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]: + """Computes features from the given image tensor. + + Args: + img: Image tensor. + + Returns: + A list of multi-scale feature maps. + """ + + def _get_content_loss( + self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor] + ) -> tf.Tensor: + """Gets weighted multi-scale content loss. + + Args: + x_feats: Multi-scale features of the reconstructed image. + y_feats: Multi-scale features of the target image. + + Returns: + A scalar tensor for the content loss. + """ + content_losses = [] + for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats): + content_loss = self._loss_op(x_feat, y_feat) * coef + content_losses.append(content_loss) + return tf.math.reduce_sum(content_losses) + + def _get_style_loss( + self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor] + ) -> tf.Tensor: + """Gets weighted multi-scale style loss. + + Args: + x_feats: Multi-scale features of the reconstructed image. + y_feats: Multi-scale features of the target image. + + Returns: + A scalar tensor for the style loss. + """ + style_losses = [] + for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats): + x_feat_g = _compute_gram_matrix(x_feat) + y_feat_g = _compute_gram_matrix(y_feat) + style_loss = self._loss_op(x_feat_g, y_feat_g) * coef + style_losses.append(style_loss) + + return tf.math.reduce_sum(style_loss) + + +class VGGPerceptualLoss(PerceptualLoss): + """Perceptual loss based on VGG19 pretrained on the ImageNet dataset. + + Reference: + - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution]( + https://arxiv.org/abs/1603.08155) (ECCV 2016) + + Perceptual loss measures high-level perceptual and semantic differences + between images. + """ + + def __init__( + self, + loss_weight: Optional[PerceptualLossWeight] = None, + ): + """Initializes image quality loss essentials. + + Args: + loss_weight: Loss weight coefficients. + """ + super().__init__( + feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]), + loss_weight=loss_weight, + ) + + rgb_mean = tf.constant([0.485, 0.456, 0.406]) + rgb_std = tf.constant([0.229, 0.224, 0.225]) + + self._rgb_mean = tf.reshape(rgb_mean, (1, 1, 1, 3)) + self._rgb_std = tf.reshape(rgb_std, (1, 1, 1, 3)) + + model_path = file_util.DownloadedFiles( + 'vgg_feature_extractor', + _VGG_IMAGENET_PERCEPTUAL_MODEL_URL, + is_folder=True, + ) + self._vgg19 = model_util.load_keras_model(model_path.get_path()) + + def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]: + """Computes VGG19 features.""" + img = (img + 1) / 2.0 + norm_img = (img - self._rgb_mean) / self._rgb_std + # No gradients needed, as it only serves as a frozen feature extractor. + return self._vgg19(norm_img) + + +def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor: + """Computes gram matrix for the feature map. + + Args: + feature: [B, H, W, C] feature map. + + Returns: + [B, C, C] gram matrix.
+ """ + h, w, c = feature.shape[1:].as_list() + feat_reshaped = tf.reshape(feature, shape=(-1, h * w, c)) + feat_gram = tf.matmul( + tf.transpose(feat_reshaped, perm=[0, 2, 1]), feat_reshaped + ) + return feat_gram / (c * h * w) diff --git a/mediapipe/model_maker/python/core/utils/loss_functions_test.py b/mediapipe/model_maker/python/core/utils/loss_functions_test.py index 716c329ef..3a14567ed 100644 --- a/mediapipe/model_maker/python/core/utils/loss_functions_test.py +++ b/mediapipe/model_maker/python/core/utils/loss_functions_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ # limitations under the License. import math -from typing import Optional +import tempfile +from typing import Dict, Optional, Sequence +from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf @@ -21,7 +23,7 @@ import tensorflow as tf from mediapipe.model_maker.python.core.utils import loss_functions -class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase): +class FocalLossTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( dict(testcase_name='no_sample_weight', sample_weight=None), @@ -99,5 +101,245 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase): self.assertNear(loss, expected_loss, 1e-4) +class SparseFocalLossTest(tf.test.TestCase): + + def test_sparse_focal_loss_matches_focal_loss(self): + num_classes = 2 + y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]]) + y_true = tf.constant([1, 0]) + y_true_one_hot = tf.one_hot(y_true, num_classes) + for gamma in [0.0, 0.5, 1.0]: + expected_loss_fn = loss_functions.FocalLoss(gamma=gamma) + loss_fn = loss_functions.SparseFocalLoss( + gamma=gamma, num_classes=num_classes + ) + expected_loss = expected_loss_fn(y_true_one_hot, y_pred) + loss = loss_fn(y_true, y_pred) + self.assertNear(loss, expected_loss, 1e-4) + + +class MockPerceptualLoss(loss_functions.PerceptualLoss): + """A mock class with implementation of abstract methods for testing.""" + + def __init__( + self, + use_mock_loss_op: bool = False, + feature_weight: Optional[Sequence[float]] = None, + loss_weight: Optional[loss_functions.PerceptualLossWeight] = None, + ): + super().__init__(feature_weight=feature_weight, loss_weight=loss_weight) + if use_mock_loss_op: + self._loss_op = lambda x, y: tf.math.reduce_mean(x - y) + + def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]: + return [tf.random.normal(shape=(1, 8, 8, 3))] * 5 + + +class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self._img1 = tf.fill(dims=(8, 8), value=0.2) + self._img2 = tf.fill(dims=(8, 8), value=0.8) + + def test_invalid_feature_weight_raise_value_error(self): + with self.assertRaisesRegex( + ValueError, + 'Input feature weight length 2 is smaller than feature length 5', + ): + MockPerceptualLoss(feature_weight=[1.0, 2.0])( + img1=self._img1, img2=self._img2 + ) + + @parameterized.named_parameters( + dict( + testcase_name='default_loss_weight_and_loss_op', + use_mock_loss_op=False, + feature_weight=None, + loss_weight=None, + loss_values={ + 'style_loss': 0.032839, + 'content_loss': 5.639870, + }, + ), + dict( + testcase_name='style_loss_weight_is_0_default_loss_op', + use_mock_loss_op=False, + feature_weight=None, + 
loss_weight=loss_functions.PerceptualLossWeight(style=0), + loss_values={ + 'style_loss': 0, + 'content_loss': 5.639870, + }, + ), + dict( + testcase_name='content_loss_weight_is_0_default_loss_op', + use_mock_loss_op=False, + feature_weight=None, + loss_weight=loss_functions.PerceptualLossWeight(content=0), + loss_values={ + 'style_loss': 0.032839, + 'content_loss': 0, + }, + ), + dict( + testcase_name='customized_loss_weight_default_loss_op', + use_mock_loss_op=False, + feature_weight=None, + loss_weight=loss_functions.PerceptualLossWeight( + style=1.0, content=2.0 + ), + loss_values={'style_loss': 0.032839, 'content_loss': 11.279739}, + ), + dict( + testcase_name=( + 'customized_feature_weight_and_loss_weight_default_loss_op' + ), + use_mock_loss_op=False, + feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0], + loss_weight=loss_functions.PerceptualLossWeight( + style=1.0, content=2.0 + ), + loss_values={'style_loss': 0.164193, 'content_loss': 33.839218}, + ), + dict( + testcase_name='no_loss_change_if_extra_feature_weight_provided', + use_mock_loss_op=False, + feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + loss_weight=loss_functions.PerceptualLossWeight( + style=1.0, content=2.0 + ), + loss_values={ + 'style_loss': 0.164193, + 'content_loss': 33.839218, + }, + ), + dict( + testcase_name='customized_loss_weight_custom_loss_op', + use_mock_loss_op=True, + feature_weight=None, + loss_weight=loss_functions.PerceptualLossWeight( + style=1.0, content=2.0 + ), + loss_values={'style_loss': 0.000395, 'content_loss': -1.533469}, + ), + ) + def test_weighted_perceptual_loss( + self, + use_mock_loss_op: bool, + feature_weight: Sequence[float], + loss_weight: loss_functions.PerceptualLossWeight, + loss_values: Dict[str, float], + ): + perceptual_loss = MockPerceptualLoss( + use_mock_loss_op=use_mock_loss_op, + feature_weight=feature_weight, + loss_weight=loss_weight, + ) + loss = perceptual_loss(img1=self._img1, img2=self._img2) + self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss']) + self.assertNear(loss['style_loss'], loss_values['style_loss'], 1e-4) + self.assertNear(loss['content_loss'], loss_values['content_loss'], 1e-4) + + +class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + # Mock tempfile.gettempdir() to be unique for each test to avoid race + # condition when downloading model since these tests may run in parallel.
+ mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1) + self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9) + + @parameterized.named_parameters( + dict( + testcase_name='default_loss_weight', + loss_weight=None, + loss_values={ + 'style_loss': 5.8363257e-06, + 'content_loss': 1.7016045, + }, + ), + dict( + testcase_name='customized_loss_weight', + loss_weight=loss_functions.PerceptualLossWeight( + style=10.0, content=20.0 + ), + loss_values={ + 'style_loss': 5.8363257e-05, + 'content_loss': 34.03208, + }, + ), + ) + def test_vgg_perceptual_loss(self, loss_weight, loss_values): + vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight) + loss = vgg_loss(img1=self._img1, img2=self._img2) + self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss']) + self.assertNear( + loss['style_loss'], + loss_values['style_loss'], + loss_values['style_loss'] / 1e5, + ) + self.assertNear( + loss['content_loss'], + loss_values['content_loss'], + loss_values['content_loss'] / 1e5, + ) + + +class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + # Mock tempfile.gettempdir() to be unique for each test to avoid race + # condition when downloading model since these tests may run in parallel. + mock_gettempdir = unittest_mock.patch.object( + tempfile, + 'gettempdir', + return_value=self.create_tempdir(), + autospec=True, + ) + self.mock_gettempdir = mock_gettempdir.start() + self.addCleanup(mock_gettempdir.stop) + self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1) + self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9) + + @parameterized.named_parameters( + dict( + testcase_name='default_loss_weight', + loss_weight=None, + loss_value=2.501612, + ), + dict( + testcase_name='customized_loss_weight_zero_l1', + loss_weight=loss_functions.PerceptualLossWeight( + l1=0.0, style=10.0, content=20.0 + ), + loss_value=34.032139, + ), + dict( + testcase_name='customized_loss_weight_nonzero_l1', + loss_weight=loss_functions.PerceptualLossWeight( + l1=10.0, style=10.0, content=20.0 + ), + loss_value=42.032139, + ), + ) + def test_image_perceptual_quality_loss(self, loss_weight, loss_value): + image_quality_loss = loss_functions.ImagePerceptualQualityLoss( + loss_weight=loss_weight + ) + loss = image_quality_loss(img1=self._img1, img2=self._img2) + self.assertNear(loss, loss_value, 1e-4) + + if __name__ == '__main__': tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/metrics.py b/mediapipe/model_maker/python/core/utils/metrics.py new file mode 100644 index 000000000..310146168 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/metrics.py @@ -0,0 +1,104 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
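A quick sanity check on the ImagePerceptualQualityLoss expectations above: with img1 filled with 0.1 and img2 filled with 0.9, the L1 term is a plain mean absolute error, so the nonzero-l1 case should equal the zero-l1 case plus the scaled L1 term. A minimal check using only the fixture values:

    mae = abs(0.1 - 0.9)  # 0.8 at every pixel, so the batch mean is also 0.8
    l1_term = 10.0 * mae  # 8.0 with l1=10.0
    assert round(34.032139 + l1_term, 6) == 42.032139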
+"""Metrics utility library.""" + +import tensorflow as tf + + +def _get_binary_sparse_metric(metric: tf.metrics.Metric): + """Helper method to create a BinarySparse version of a tf.keras.Metric. + + BinarySparse is an implementation where the update_state(y_true, y_pred) takes + in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only + supports the binary classification case, and that class_id=0 is the negative + class and class_id=1 is the positive class. + + Currently supported tf.metric.Metric classes + 1. BinarySparseRecallAtPrecision + 2. BinarySparsePrecisionAtRecall + + Args: + metric: A tf.metric.Metric class for which we want to generate a + BinarySparse version of this metric. + + Returns: + A class for the BinarySparse version of the specified tf.metrics.Metric + """ + + class BinarySparseMetric(metric): + """A BinarySparse wrapper class for a tf.keras.Metric. + + This class has the same parameters and functions as the underlying + metric class. For example, the parameters for BinarySparseRecallAtPrecision + is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint + is that class_id must be set to 1 (or not specified) for the Binary metric. + """ + + def __init__(self, *args, **kwargs): + if 'class_id' in kwargs and kwargs['class_id'] != 1: + raise ValueError( + f'Custom BinarySparseMetric for class:{metric.__name__} is ' + 'only supported for class_id=1, got class_id=' + f'{kwargs["class_id"]} instead' + ) + else: + kwargs['class_id'] = 1 + super().__init__(*args, **kwargs) + + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32) + y_true_one_hot = tf.one_hot(y_true, 2) + return super().update_state( + y_true_one_hot, y_pred, sample_weight=sample_weight + ) + + return BinarySparseMetric + + +def _get_sparse_metric(metric: tf.metrics.Metric): + """Helper method to create a Sparse version of a tf.keras.Metric. + + Sparse is an implementation where the update_state(y_true, y_pred) takes in + shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes). + + Currently supported tf.metrics.Metric classes: + 1. tf.metrics.Recall + 2. tf.metrics.Precision + + Args: + metric: A tf.metric.Metric class for which we want to generate a Sparse + version of this metric. + + Returns: + A class for the Sparse version of the specified tf.keras.Metric. + """ + + class SparseMetric(metric): + """A Sparse wrapper class for a tf.keras.Metric.""" + + def update_state(self, y_true, y_pred, sample_weight=None): + y_pred = tf.math.argmax(y_pred, axis=-1) + return super().update_state(y_true, y_pred, sample_weight=sample_weight) + + return SparseMetric + + +SparseRecall = _get_sparse_metric(tf.metrics.Recall) +SparsePrecision = _get_sparse_metric(tf.metrics.Precision) +BinarySparseRecallAtPrecision = _get_binary_sparse_metric( + tf.metrics.RecallAtPrecision +) +BinarySparsePrecisionAtRecall = _get_binary_sparse_metric( + tf.metrics.PrecisionAtRecall +) diff --git a/mediapipe/model_maker/python/core/utils/metrics_test.py b/mediapipe/model_maker/python/core/utils/metrics_test.py new file mode 100644 index 000000000..842335273 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/metrics_test.py @@ -0,0 +1,74 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import metrics + + +class SparseMetricTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.y_true = [0, 0, 1, 1, 0, 1] + self.y_pred = [ + [0.9, 0.1], # 0, 0 y + [0.8, 0.2], # 0, 0 y + [0.7, 0.3], # 0, 1 n + [0.6, 0.4], # 0, 1 n + [0.3, 0.7], # 1, 0 y + [0.3, 0.7], # 1, 1 y + ] + self.num_classes = 3 + + def _assert_metric_equals(self, metric, value): + metric.update_state(self.y_true, self.y_pred) + self.assertEqual(metric.result(), value) + + def test_sparse_recall(self): + metric = metrics.SparseRecall() + self._assert_metric_equals(metric, 1 / 3) + + def test_sparse_precision(self): + metric = metrics.SparsePrecision() + self._assert_metric_equals(metric, 1 / 2) + + def test_binary_sparse_recall_at_precision(self): + metric = metrics.BinarySparseRecallAtPrecision(1.0) + self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1 + metric = metrics.BinarySparseRecallAtPrecision(0.4) + self._assert_metric_equals(metric, 1.0) + + def test_binary_sparse_precision_at_recall(self): + metric = metrics.BinarySparsePrecisionAtRecall(1.0) + self._assert_metric_equals(metric, 3 / 4) + metric = metrics.BinarySparsePrecisionAtRecall(0.7) + self._assert_metric_equals(metric, 3 / 4) + + def test_binary_sparse_precision_at_recall_class_id_error(self): + # class_id=1 case should not error + _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1) + # class_id=2 case should error + with self.assertRaisesRegex( + ValueError, + 'Custom BinarySparseMetric for class:PrecisionAtRecall is only' + ' supported for class_id=1, got class_id=2 instead', + ): + _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 7a0b8fcf0..2b1eebf9f 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
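The sparse-metric expectations in the test above can be reproduced by hand, since SparseMetric collapses y_pred with argmax before deferring to the underlying binary metric. A minimal standalone sketch (plain TensorFlow, independent of the diff):

    import tensorflow as tf

    y_true = [0, 0, 1, 1, 0, 1]
    y_pred = tf.math.argmax(
        [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.3, 0.7], [0.3, 0.7]],
        axis=-1,
    )  # -> [0, 0, 0, 0, 1, 1]

    recall = tf.metrics.Recall()  # TP=1 (last row), FN=2 -> recall = 1/3
    recall.update_state(y_true, y_pred)

    precision = tf.metrics.Precision()  # TP=1, FP=1 (fifth row) -> precision = 1/2
    precision.update_state(y_true, y_pred)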
@@ -34,7 +34,8 @@ ESTIMITED_STEPS_PER_EPOCH = 1000 def get_default_callbacks( - export_dir: str) -> Sequence[tf.keras.callbacks.Callback]: + export_dir: str, +) -> Sequence[tf.keras.callbacks.Callback]: """Gets default callbacks.""" summary_dir = os.path.join(export_dir, 'summaries') summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) @@ -43,12 +44,14 @@ def get_default_callbacks( checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( os.path.join(checkpoint_path, 'model-{epoch:04d}'), save_weights_only=True, - period=5) + period=5, + ) return [summary_callback, checkpoint_callback] -def load_keras_model(model_path: str, - compile_on_load: bool = False) -> tf.keras.Model: +def load_keras_model( + model_path: str, compile_on_load: bool = False +) -> tf.keras.Model: """Loads a tensorflow Keras model from file and returns the Keras model. Args: @@ -82,9 +85,11 @@ def load_tflite_model_buffer(model_path: str) -> bytearray: return tflite_model_buffer -def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, - batch_size: Optional[int] = None, - train_data: Optional[dataset.Dataset] = None) -> int: +def get_steps_per_epoch( + steps_per_epoch: Optional[int] = None, + batch_size: Optional[int] = None, + train_data: Optional[dataset.Dataset] = None, +) -> int: """Gets the estimated training steps per epoch. 1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly. @@ -112,6 +117,43 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, return len(train_data) // batch_size +def convert_to_tflite_from_file( + saved_model_file: str, + quantization_config: Optional[quantization.QuantizationConfig] = None, + supported_ops: Tuple[tf.lite.OpsSet, ...] = ( + tf.lite.OpsSet.TFLITE_BUILTINS, + ), + preprocess: Optional[Callable[..., Any]] = None, + allow_custom_ops: bool = False, +) -> bytearray: + """Converts a saved model file to TFLite format. + + Args: + saved_model_file: Path to the saved model file to be converted to TFLite. + quantization_config: Configuration for post-training quantization. + supported_ops: A list of supported ops in the converted TFLite file. + preprocess: A callable to preprocess the representative dataset for + quantization. The callable takes three arguments in order: feature, label, + and is_training. + allow_custom_ops: A boolean flag to enable custom ops in model conversion. + Defaults to False. + + Returns: + bytearray of TFLite model + """ + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_file) + + if quantization_config: + converter = quantization_config.set_converter_with_quantization( + converter, preprocess=preprocess + ) + + converter.allow_custom_ops = allow_custom_ops + converter.target_spec.supported_ops = supported_ops + tflite_model = converter.convert() + return tflite_model + + def convert_to_tflite( model: tf.keras.Model, quantization_config: Optional[quantization.QuantizationConfig] = None, @@ -119,6 +161,7 @@ def convert_to_tflite( tf.lite.OpsSet.TFLITE_BUILTINS, ), preprocess: Optional[Callable[..., Any]] = None, + allow_custom_ops: bool = False, ) -> bytearray: """Converts the input Keras model to TFLite format. @@ -129,22 +172,26 @@ preprocess: A callable to preprocess the representative dataset for quantization. The callable takes three arguments in order: feature, label, and is_training. + allow_custom_ops: A boolean flag to enable custom ops in model conversion. + Defaults to False.
Returns: bytearray of TFLite model """ with tempfile.TemporaryDirectory() as temp_dir: save_path = os.path.join(temp_dir, 'saved_model') - model.save(save_path, include_optimizer=False, save_format='tf') - converter = tf.lite.TFLiteConverter.from_saved_model(save_path) - - if quantization_config: - converter = quantization_config.set_converter_with_quantization( - converter, preprocess=preprocess) - - converter.target_spec.supported_ops = supported_ops - tflite_model = converter.convert() - return tflite_model + model.save( + save_path, + include_optimizer=False, + save_format='tf', + ) + return convert_to_tflite_from_file( + save_path, + quantization_config, + supported_ops, + preprocess, + allow_custom_ops, + ) def save_tflite(tflite_model: bytearray, tflite_file: str) -> None: @@ -159,17 +206,20 @@ def save_tflite(tflite_model: bytearray, tflite_file: str) -> None: with tf.io.gfile.GFile(tflite_file, 'wb') as f: f.write(tflite_model) tf.compat.v1.logging.info( - 'TensorFlow Lite model exported successfully to: %s' % tflite_file) + 'TensorFlow Lite model exported successfully to: %s' % tflite_file + ) class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): """Applies a warmup schedule on a given learning rate decay schedule.""" - def __init__(self, - initial_learning_rate: float, - decay_schedule_fn: Callable[[Any], Any], - warmup_steps: int, - name: Optional[str] = None): + def __init__( + self, + initial_learning_rate: float, + decay_schedule_fn: Callable[[Any], Any], + warmup_steps: int, + name: Optional[str] = None, + ): """Initializes a new instance of the `WarmUp` class. Args: @@ -197,14 +247,15 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): global_step_float < warmup_steps_float, lambda: warmup_learning_rate, lambda: self.decay_schedule_fn(step), - name=name) + name=name, + ) def get_config(self) -> Dict[str, Any]: return { 'initial_learning_rate': self.initial_learning_rate, 'decay_schedule_fn': self.decay_schedule_fn, 'warmup_steps': self.warmup_steps, - 'name': self.name + 'name': self.name, } @@ -238,7 +289,8 @@ class LiteRunner(object): """ if not isinstance(input_tensors, list) and not isinstance( - input_tensors, dict): + input_tensors, dict + ): input_tensors = [input_tensors] interpreter = self.interpreter @@ -246,19 +298,18 @@ class LiteRunner(object): # Reshape inputs for i, input_detail in enumerate(self.input_details): input_tensor = _get_input_tensor( - input_tensors=input_tensors, - input_details=self.input_details, - index=i) + input_tensors=input_tensors, input_details=self.input_details, index=i + ) interpreter.resize_tensor_input( - input_index=input_detail['index'], tensor_size=input_tensor.shape) + input_index=input_detail['index'], tensor_size=input_tensor.shape + ) interpreter.allocate_tensors() # Feed input to the interpreter for i, input_detail in enumerate(self.input_details): input_tensor = _get_input_tensor( - input_tensors=input_tensors, - input_details=self.input_details, - index=i) + input_tensors=input_tensors, input_details=self.input_details, index=i + ) if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): # Quantize the input scale, zero_point = input_detail['quantization'] @@ -289,9 +340,11 @@ def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner': return lite_runner -def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str, - tf.Tensor]], - input_details: Dict[str, Any], index: int) -> tf.Tensor: +def _get_input_tensor( + input_tensors: Union[List[tf.Tensor], Dict[str, 
tf.Tensor]], + input_details: Dict[str, Any], + index: int, +) -> tf.Tensor: """Returns input tensor in `input_tensors` that maps `input_detail[i]`.""" if isinstance(input_tensors, dict): # Gets the mapped input tensor. @@ -299,7 +352,9 @@ def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str, for input_tensor_name, input_tensor in input_tensors.items(): if input_tensor_name in input_detail['name']: return input_tensor - raise ValueError('Input tensors don\'t contains a tensor that mapped the ' - 'input detail %s' % str(input_detail)) + raise ValueError( + "Input tensors don't contain a tensor that maps to the input detail %s" + % str(input_detail) + ) else: return input_tensors[index] diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index 6961a5fc7..57750624f 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/utils/quantization.py b/mediapipe/model_maker/python/core/utils/quantization.py index a1a38cc64..2a8d92244 100644 --- a/mediapipe/model_maker/python/core/utils/quantization.py +++ b/mediapipe/model_maker/python/core/utils/quantization.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/utils/quantization_test.py b/mediapipe/model_maker/python/core/utils/quantization_test.py index 9d27d34ac..57523d405 100644 --- a/mediapipe/model_maker/python/core/utils/quantization_test.py +++ b/mediapipe/model_maker/python/core/utils/quantization_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py index 14d02814e..72fb229c7 100644 --- a/mediapipe/model_maker/python/core/utils/test_util.py +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
@@ -16,7 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from typing import List, Union +from typing import Sequence +from typing import Dict, List, Union # Dependency imports @@ -94,6 +95,17 @@ def is_same_output(tflite_model: bytearray, return np.allclose(lite_output, keras_output, atol=atol) +def run_tflite( + tflite_filename: str, + input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]], +) -> Union[Sequence[tf.Tensor], tf.Tensor]: + """Runs TFLite model inference.""" + with tf.io.gfile.GFile(tflite_filename, "rb") as f: + tflite_model = f.read() + lite_runner = model_util.get_lite_runner(tflite_model) + return lite_runner.run(input_tensors) + + def test_tflite(keras_model: tf.keras.Model, tflite_model: bytearray, size: Union[int, List[int]], diff --git a/mediapipe/model_maker/python/core/utils/testdata/BUILD b/mediapipe/model_maker/python/core/utils/testdata/BUILD deleted file mode 100644 index 8eed72f78..000000000 --- a/mediapipe/model_maker/python/core/utils/testdata/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//mediapipe/model_maker/python/core/utils:__subpackages__"], - licenses = ["notice"], # Apache 2.0 -) - -filegroup( - name = "testdata", - srcs = ["test.txt"], -) diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb new file mode 100644 index 000000000..e60e04a24 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/saved_model.pb differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb new file mode 100644 index 000000000..d65dd8f1d --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini/tfhub_module.pb @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb new file mode 100644 index 000000000..69519fef7 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/saved_model.pb differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb new file mode 100644 index 000000000..d65dd8f1d --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/tfhub_module.pb @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 
b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..3474955ee --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.data-00000-of-00001 @@ -0,0 +1,2 @@ +¿ + diff --git a/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index new file mode 100644 index 000000000..d0e35ab87 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/hub_module_v1_mini_train/variables/variables.index differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb new file mode 100644 index 000000000..314ea74fc Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/saved_model.pb differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 new file mode 100644 index 000000000..09dbb330d Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.data-00000-of-00001 differ diff --git a/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index new file mode 100644 index 000000000..7cfb9ffd4 Binary files /dev/null and b/mediapipe/model_maker/python/core/utils/testdata/saved_model_v2_mini/variables/variables.index differ diff --git a/mediapipe/model_maker/python/text/__init__.py b/mediapipe/model_maker/python/text/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/text/__init__.py +++ b/mediapipe/model_maker/python/text/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/text/core/BUILD b/mediapipe/model_maker/python/text/core/BUILD index e0c53491a..d99f46b77 100644 --- a/mediapipe/model_maker/python/text/core/BUILD +++ b/mediapipe/model_maker/python/text/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,7 @@ # Placeholder for internal Python strict library and test compatibility macro. -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/model_maker/python/text/core/__init__.py b/mediapipe/model_maker/python/text/core/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/text/core/__init__.py +++ b/mediapipe/model_maker/python/text/core/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/text/core/bert_model_options.py b/mediapipe/model_maker/python/text/core/bert_model_options.py index ce5ef6af4..bb8aca963 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_options.py +++ b/mediapipe/model_maker/python/text/core/bert_model_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py index 605435df0..4a847ac33 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_spec.py +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ """Specification for a BERT model.""" import dataclasses -from typing import Dict +from typing import Dict, Union from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.utils import file_util @@ -35,7 +35,9 @@ class BertModelSpec: Transformers for Language Understanding) for more details. Attributes: - downloaded_files: A DownloadedFiles object of the model files + files: Either a TFHub url string which can be passed directly to + hub.KerasLayer or a DownloadedFiles object of the model files. + is_tf2: If True, the checkpoint is TF2 format. Else use TF1 format. hparams: Hyperparameters used for training. model_options: Configurable options for a BERT model. do_lower_case: boolean, whether to lower case the input text. Should be @@ -45,15 +47,28 @@ class BertModelSpec: name: The name of the object. """ - downloaded_files: file_util.DownloadedFiles - hparams: hp.BaseHParams = hp.BaseHParams( - epochs=3, - batch_size=32, - learning_rate=3e-5, - distribution_strategy='mirrored') - model_options: bert_model_options.BertModelOptions = ( - bert_model_options.BertModelOptions()) + files: Union[str, file_util.DownloadedFiles] + is_tf2: bool = True + hparams: hp.BaseHParams = dataclasses.field( + default_factory=lambda: hp.BaseHParams( + epochs=3, + batch_size=32, + learning_rate=3e-5, + distribution_strategy='mirrored', + ) + ) + model_options: bert_model_options.BertModelOptions = dataclasses.field( + default_factory=bert_model_options.BertModelOptions + ) do_lower_case: bool = True tflite_input_name: Dict[str, str] = dataclasses.field( default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) name: str = 'Bert' + + def get_path(self) -> str: + if isinstance(self.files, file_util.DownloadedFiles): + return self.files.get_path() + elif isinstance(self.files, str): + return self.files + else: + raise ValueError(f'files has unsupported type: {type(self.files)}') diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 2c1e2d3d8..e32733e31 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. # Placeholder for internal Python strict test compatibility macro. -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) licenses(["notice"]) @@ -33,11 +31,11 @@ py_library( visibility = ["//visibility:public"], deps = [ ":dataset", + ":hyperparameters", ":model_options", ":model_spec", ":text_classifier", ":text_classifier_options", - "//mediapipe/model_maker/python/core:hyperparameters", ], ) @@ -47,12 +45,18 @@ py_library( deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"], ) +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], + deps = ["//mediapipe/model_maker/python/core:hyperparameters"], +) + py_library( name = "model_spec", srcs = ["model_spec.py"], deps = [ + ":hyperparameters", ":model_options", - "//mediapipe/model_maker/python/core:hyperparameters", "//mediapipe/model_maker/python/core/utils:file_util", "//mediapipe/model_maker/python/text/core:bert_model_spec", ], @@ -63,16 +67,19 @@ py_test( srcs = ["model_spec_test.py"], tags = ["requires-net:external"], deps = [ + ":hyperparameters", ":model_options", ":model_spec", - "//mediapipe/model_maker/python/core:hyperparameters", ], ) py_library( name = "dataset", srcs = ["dataset.py"], - deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], + deps = [ + "//mediapipe/model_maker/python/core/data:cache_files", + "//mediapipe/model_maker/python/core/data:classification_dataset", + ], ) py_test( @@ -84,7 +91,10 @@ py_test( py_library( name = "preprocessor", srcs = ["preprocessor.py"], - deps = [":dataset"], + deps = [ + ":dataset", + "//mediapipe/model_maker/python/core/data:cache_files", + ], ) py_test( @@ -95,6 +105,7 @@ py_test( ":dataset", ":model_spec", ":preprocessor", + "//mediapipe/model_maker/python/core/data:cache_files", ], ) @@ -102,9 +113,9 @@ py_library( name = "text_classifier_options", srcs = ["text_classifier_options.py"], deps = [ + ":hyperparameters", ":model_options", ":model_spec", - "//mediapipe/model_maker/python/core:hyperparameters", ], ) @@ -113,13 +124,16 @@ py_library( srcs = ["text_classifier.py"], deps = [ ":dataset", + ":hyperparameters", ":model_options", ":model_spec", ":preprocessor", ":text_classifier_options", - "//mediapipe/model_maker/python/core:hyperparameters", "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:hub_loader", + "//mediapipe/model_maker/python/core/utils:loss_functions", + "//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", @@ -142,6 +156,7 @@ py_test( ], deps = [ ":text_classifier_import", + "//mediapipe/model_maker/python/core/utils:loss_functions", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/__init__.py 
b/mediapipe/model_maker/python/text/text_classifier/__init__.py index 697461969..7eb0f9259 100644 --- a/mediapipe/model_maker/python/text/text_classifier/__init__.py +++ b/mediapipe/model_maker/python/text/text_classifier/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,19 +13,23 @@ # limitations under the License. """MediaPipe Public Python API for Text Classifier.""" -from mediapipe.model_maker.python.core import hyperparameters from mediapipe.model_maker.python.text.text_classifier import dataset +from mediapipe.model_maker.python.text.text_classifier import hyperparameters from mediapipe.model_maker.python.text.text_classifier import model_options from mediapipe.model_maker.python.text.text_classifier import model_spec from mediapipe.model_maker.python.text.text_classifier import text_classifier from mediapipe.model_maker.python.text.text_classifier import text_classifier_options -HParams = hyperparameters.BaseHParams + +AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams +AverageWordEmbeddingModelOptions = ( + model_options.AverageWordEmbeddingModelOptions +) +BertOptimizer = hyperparameters.BertOptimizer +BertHParams = hyperparameters.BertHParams +BertModelOptions = model_options.BertModelOptions CSVParams = dataset.CSVParameters Dataset = dataset.Dataset -AverageWordEmbeddingModelOptions = ( - model_options.AverageWordEmbeddingModelOptions) -BertModelOptions = model_options.BertModelOptions SupportedModels = model_spec.SupportedModels TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier_options.TextClassifierOptions diff --git a/mediapipe/model_maker/python/text/text_classifier/dataset.py b/mediapipe/model_maker/python/text/text_classifier/dataset.py index 3679b67ae..1f8798df7 100644 --- a/mediapipe/model_maker/python/text/text_classifier/dataset.py +++ b/mediapipe/model_maker/python/text/text_classifier/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,11 +15,15 @@ import csv import dataclasses +import hashlib +import os import random +import tempfile +from typing import List, Optional, Sequence -from typing import Optional, Sequence import tensorflow as tf +from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib from mediapipe.model_maker.python.core.data import classification_dataset @@ -46,21 +50,49 @@ class CSVParameters: class Dataset(classification_dataset.ClassificationDataset): """Dataset library for text classifier.""" + def __init__( + self, + dataset: tf.data.Dataset, + label_names: List[str], + tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None, + size: Optional[int] = None, + ): + super().__init__(dataset, label_names, size) + if not tfrecord_cache_files: + tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles( + cache_prefix_filename="tfrecord", num_shards=1 + ) + self.tfrecord_cache_files = tfrecord_cache_files + @classmethod - def from_csv(cls, - filename: str, - csv_params: CSVParameters, - shuffle: bool = True) -> "Dataset": + def from_csv( + cls, + filename: str, + csv_params: CSVParameters, + shuffle: bool = True, + cache_dir: Optional[str] = None, + num_shards: int = 1, + ) -> "Dataset": """Loads text with labels from a CSV file. Args: filename: Name of the CSV file. csv_params: Parameters used for reading the CSV file. shuffle: If True, randomly shuffle the data. + cache_dir: Optional parameter to specify where to store the preprocessed + dataset. Only used for BERT models. + num_shards: Optional parameter for num shards of the preprocessed dataset. + Note that using more than 1 shard will reorder the dataset. Only used + for BERT models. Returns: Dataset containing (text, label) pairs and other related info. """ + if cache_dir is None: + cache_dir = tempfile.mkdtemp() + # calculate hash for cache based off of files + hasher = hashlib.md5() + hasher.update(os.path.basename(filename).encode("utf-8")) with tf.io.gfile.GFile(filename, "r") as f: reader = csv.DictReader( f, @@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset): quotechar=csv_params.quotechar) lines = list(reader) + for line in lines: + hasher.update(str(line).encode("utf-8")) + if shuffle: random.shuffle(lines) @@ -81,8 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset): index_by_label[line[csv_params.label_column]] for line in lines ] label_index_ds = tf.data.Dataset.from_tensor_slices( - tf.cast(label_indices, tf.int64)) + tf.cast(label_indices, tf.int64) + ) text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds)) + hasher.update(str(num_shards).encode("utf-8")) + cache_prefix_filename = hasher.hexdigest() + tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles( + cache_prefix_filename, cache_dir, num_shards + ) return Dataset( - dataset=text_label_ds, size=len(texts), label_names=label_names) + dataset=text_label_ds, + label_names=label_names, + tfrecord_cache_files=tfrecord_cache_files, + size=len(texts), + ) diff --git a/mediapipe/model_maker/python/text/text_classifier/dataset_test.py b/mediapipe/model_maker/python/text/text_classifier/dataset_test.py index ec9e8fa2d..2fa90b860 100644 --- a/mediapipe/model_maker/python/text/text_classifier/dataset_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase): def test_split(self): ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd']) - data = dataset.Dataset(ds, 4, ['pos', 'neg']) + data = dataset.Dataset(ds, ['pos', 'neg'], size=4) train_data, test_data = data.split(0.5) expected_train_data = [b'good', b'bad'] expected_test_data = [b'neutral', b'odd'] diff --git a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py new file mode 100644 index 000000000..71470edb3 --- /dev/null +++ b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py @@ -0,0 +1,72 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training text classifier models.""" + +import dataclasses +import enum +from typing import Sequence, Union + +from mediapipe.model_maker.python.core import hyperparameters as hp + + +@dataclasses.dataclass +class AverageWordEmbeddingHParams(hp.BaseHParams): + """The hyperparameters for an AverageWordEmbeddingClassifier.""" + + +@enum.unique +class BertOptimizer(enum.Enum): + """Supported Optimizers for Bert Text Classifier.""" + + ADAMW = "adamw" + LAMB = "lamb" + + +@dataclasses.dataclass +class BertHParams(hp.BaseHParams): + """The hyperparameters for a Bert Classifier. + + Attributes: + learning_rate: Learning rate to use for gradient descent training. + end_learning_rate: End learning rate for linear decay. Defaults to 0. + batch_size: Batch size for training. Defaults to 48. + epochs: Number of training iterations over the dataset. Defaults to 2. + optimizer: Optimizer to use for training. Supported values are defined in + BertOptimizer enum: ADAMW and LAMB. + weight_decay: Weight decay of the optimizer. Defaults to 0.01. + desired_precisions: If specified, adds a RecallAtPrecision metric per + desired_precisions[i] entry which tracks the recall given the constraint + on precision. Only supported for binary classification. + desired_recalls: If specified, adds a PrecisionAtRecall metric per + desired_recalls[i] entry which tracks the precision given the constraint + on recall. Only supported for binary classification. + gamma: Gamma parameter for focal loss. To use cross entropy loss, set this + value to 0. Defaults to 2.0.
+ """ + + learning_rate: float = 3e-5 + end_learning_rate: float = 0.0 + + batch_size: int = 48 + epochs: int = 2 + optimizer: BertOptimizer = BertOptimizer.ADAMW + weight_decay: float = 0.01 + + desired_precisions: Sequence[float] = dataclasses.field(default_factory=list) + desired_recalls: Sequence[float] = dataclasses.field(default_factory=list) + + gamma: float = 2.0 + + +HParams = Union[BertHParams, AverageWordEmbeddingHParams] diff --git a/mediapipe/model_maker/python/text/text_classifier/model_options.py b/mediapipe/model_maker/python/text/text_classifier/model_options.py index a3d94bdf7..a2f45e145 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_options.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index d999f6867..01d1432cb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,18 +17,14 @@ import dataclasses import enum import functools -from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.text.core import bert_model_spec +from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import model_options as mo -# BERT-based text classifier spec inherited from BertModelSpec -BertClassifierSpec = bert_model_spec.BertModelSpec -MOBILEBERT_TINY_FILES = file_util.DownloadedFiles( - 'text_classifier/mobilebert_tiny', - 'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz', - is_folder=True, +MOBILEBERT_FILES = ( + 'https://tfhub.dev/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1' ) @@ -43,28 +39,38 @@ class AverageWordEmbeddingClassifierSpec: """ # `learning_rate` is unused for the average word embedding model - hparams: hp.BaseHParams = hp.BaseHParams( - epochs=10, batch_size=32, learning_rate=0) - model_options: mo.AverageWordEmbeddingModelOptions = ( - mo.AverageWordEmbeddingModelOptions()) + hparams: hp.AverageWordEmbeddingHParams = dataclasses.field( + default_factory=lambda: hp.AverageWordEmbeddingHParams( + epochs=10, batch_size=32, learning_rate=0 + ) + ) + model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field( + default_factory=mo.AverageWordEmbeddingModelOptions + ) name: str = 'AverageWordEmbedding' - average_word_embedding_classifier_spec = functools.partial( AverageWordEmbeddingClassifierSpec) + +@dataclasses.dataclass +class BertClassifierSpec(bert_model_spec.BertModelSpec): + """Specification for a Bert classifier model. + + Only overrides the hparams attribute since the rest of the attributes are + inherited from the BertModelSpec. 
+ """ + + hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams) + mobilebert_classifier_spec = functools.partial( BertClassifierSpec, - downloaded_files=MOBILEBERT_TINY_FILES, - hparams=hp.BaseHParams( + files=MOBILEBERT_FILES, + hparams=hp.BertHParams( epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' ), - name='MobileBert', - tflite_input_name={ - 'ids': 'serving_default_input_1:0', - 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0', - }, + name='MobileBERT', + is_tf2=False, ) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index c2d96bac4..d1e578b81 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ from unittest import mock as unittest_mock import tensorflow as tf -from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options from mediapipe.model_maker.python.text.text_classifier import model_spec as ms @@ -42,26 +42,30 @@ class ModelSpecTest(tf.test.TestCase): def test_predefined_bert_spec(self): model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value() self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec) - self.assertEqual(model_spec_obj.name, 'MobileBert') - self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) + self.assertEqual(model_spec_obj.name, 'MobileBERT') + self.assertTrue(model_spec_obj.files) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( - model_spec_obj.tflite_input_name, { - 'ids': 'serving_default_input_1:0', - 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0' - }) + model_spec_obj.tflite_input_name, + { + 'ids': 'serving_default_input_word_ids:0', + 'mask': 'serving_default_input_mask:0', + 'segment_ids': 'serving_default_input_type_ids:0', + }, + ) self.assertEqual( model_spec_obj.model_options, classifier_model_options.BertModelOptions( seq_len=128, do_fine_tuning=True, dropout_rate=0.1)) self.assertEqual( model_spec_obj.hparams, - hp.BaseHParams( + hp.BertHParams( epochs=3, batch_size=48, learning_rate=3e-5, - distribution_strategy='off')) + distribution_strategy='off', + ), + ) def test_predefined_average_word_embedding_spec(self): model_spec_obj = ( @@ -78,15 +82,17 @@ class ModelSpecTest(tf.test.TestCase): dropout_rate=0.2)) self.assertEqual( model_spec_obj.hparams, - hp.BaseHParams( + hp.AverageWordEmbeddingHParams( epochs=10, batch_size=32, learning_rate=0, steps_per_epoch=None, shuffle=False, distribution_strategy='off', - num_gpus=-1, - tpu='')) + num_gpus=0, + tpu='', + ), + ) def test_custom_bert_spec(self): custom_bert_classifier_options = ( @@ -99,7 +105,7 @@ class ModelSpecTest(tf.test.TestCase): custom_bert_classifier_options) def test_custom_average_word_embedding_spec(self): - custom_hparams = hp.BaseHParams( + custom_hparams = hp.AverageWordEmbeddingHParams( learning_rate=0.4, batch_size=64, epochs=10, @@ -108,7 +114,8 @@ 
class ModelSpecTest(tf.test.TestCase):
         export_dir='foo/bar',
         distribution_strategy='mirrored',
         num_gpus=3,
-        tpu='tpu/address')
+        tpu='tpu/address',
+    )
     custom_average_word_embedding_model_options = (
         classifier_model_options.AverageWordEmbeddingModelOptions(
             seq_len=512,
diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
index 0a48f459c..68a5df2fd 100644
--- a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
+++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,16 +15,16 @@
 """Preprocessors for text classification."""
 
 import collections
+import hashlib
 import os
 import re
-import tempfile
 from typing import Mapping, Sequence, Tuple, Union
 
 import tensorflow as tf
 import tensorflow_hub
 
+from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib
 from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
-from official.nlp.data import classifier_data_lib
 from official.nlp.tools import tokenization
 
 
@@ -75,19 +75,20 @@ def _decode_record(
   return bert_features, example["label_ids"]
 
 
-def _single_file_dataset(
-    input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
+def _tfrecord_dataset(
+    tfrecord_files: Sequence[str],
+    name_to_features: Mapping[str, tf.io.FixedLenFeature],
 ) -> tf.data.TFRecordDataset:
-  """Creates a single-file dataset to be passed for BERT custom training.
+  """Creates a dataset from TFRecord files to be passed for BERT custom training.
 
   Args:
-    input_file: Filepath for the dataset.
+    tfrecord_files: Filepaths for the dataset.
     name_to_features: Maps record keys to feature types.
 
   Returns:
     Dataset containing BERT model input features and labels.
   """
-  d = tf.data.TFRecordDataset(input_file)
+  d = tf.data.TFRecordDataset(tfrecord_files)
   d = d.map(
       lambda record: _decode_record(record, name_to_features),
       num_parallel_calls=tf.data.AUTOTUNE)
@@ -221,15 +222,23 @@ class BertClassifierPreprocessor:
     seq_len: Length of the input sequence to the model.
     vocab_file: File containing the BERT vocab.
     tokenizer: BERT tokenizer.
+    model_name: Name of the model provided by the model_spec. Used to associate
+      cached files with a specific BERT model vocab.
   """
 
-  def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
+  def __init__(
+      self, seq_len: int, do_lower_case: bool, uri: str, model_name: str
+  ):
     self._seq_len = seq_len
     # Vocab filepath is tied to the BERT module's URI.
     self._vocab_file = os.path.join(
-        tensorflow_hub.resolve(uri), "assets", "vocab.txt")
-    self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
-                                                 do_lower_case)
+        tensorflow_hub.resolve(uri), "assets", "vocab.txt"
+    )
+    self._do_lower_case = do_lower_case
+    self._tokenizer = tokenization.FullTokenizer(
+        self._vocab_file, self._do_lower_case
+    )
+    self._model_name = model_name
 
   def _get_name_to_features(self):
     """Gets the dictionary mapping record keys to feature types."""
@@ -244,8 +253,62 @@ class BertClassifierPreprocessor:
     """Returns the vocab file of the BertClassifierPreprocessor."""
     return self._vocab_file
 
+  def _get_tfrecord_cache_files(
+      self, ds_cache_files
+  ) -> cache_files_lib.TFRecordCacheFiles:
+    """Helper to regenerate cache prefix filename using preprocessor info.
+
+    We need to update the dataset's cache prefix because the actual cached
+    dataset depends on the preprocessor parameters such as model_name, seq_len,
+    and do_lower_case, in addition to the raw dataset parameters, which are
+    already included in ds_cache_files.cache_prefix_filename.
+
+    Specifically, the new cache_prefix_filename used by the preprocessor will
+    be a hash generated from the following:
+      1. cache_prefix_filename of the initial raw dataset
+      2. model_name
+      3. seq_len
+      4. do_lower_case
+
+    Args:
+      ds_cache_files: TFRecordCacheFiles from the original raw dataset object.
+
+    Returns:
+      A new TFRecordCacheFiles object which incorporates the preprocessor
+      parameters.
+    """
+    hasher = hashlib.md5()
+    hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8"))
+    hasher.update(self._model_name.encode("utf-8"))
+    hasher.update(str(self._seq_len).encode("utf-8"))
+    hasher.update(str(self._do_lower_case).encode("utf-8"))
+    cache_prefix_filename = hasher.hexdigest()
+    return cache_files_lib.TFRecordCacheFiles(
+        cache_prefix_filename,
+        ds_cache_files.cache_dir,
+        ds_cache_files.num_shards,
+    )
+
+  def _process_bert_features(self, text: str) -> Mapping[str, Sequence[int]]:
+    tokens = self._tokenizer.tokenize(text)
+    tokens = tokens[0 : (self._seq_len - 2)]  # account for [CLS] and [SEP]
+    tokens.insert(0, "[CLS]")
+    tokens.append("[SEP]")
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+    while len(input_ids) < self._seq_len:
+      input_ids.append(0)
+      input_mask.append(0)
+    segment_ids = [0] * self._seq_len
+    return {
+        "input_ids": input_ids,
+        "input_mask": input_mask,
+        "segment_ids": segment_ids,
+    }
+
   def preprocess(
-      self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
+      self, dataset: text_classifier_ds.Dataset
+  ) -> text_classifier_ds.Dataset:
     """Preprocesses data into input for a BERT-based classifier.
 
     Args:
@@ -254,32 +317,54 @@ class BertClassifierPreprocessor:
 
     Returns:
       Dataset containing (bert_features, label) data.
     """
-    examples = []
-    for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
-      _validate_text_and_label(text, label)
-      examples.append(
-          classifier_data_lib.InputExample(
-              guid=str(index),
-              text_a=text.numpy()[0].decode("utf-8"),
-              text_b=None,
-              # InputExample expects the label name rather than the int ID
-              label=dataset.label_names[label.numpy()[0]]))
+ tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files) + if not tfrecord_cache_files.is_cached(): + print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}") + writers = tfrecord_cache_files.get_writers() + size = 0 + for index, (text, label) in enumerate(dataset.gen_tf_dataset()): + _validate_text_and_label(text, label) + feature = self._process_bert_features(text.numpy()[0].decode("utf-8")) + def create_int_feature(values): + f = tf.train.Feature( + int64_list=tf.train.Int64List(value=list(values)) + ) + return f - tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord") - classifier_data_lib.file_based_convert_examples_to_features( - examples=examples, - label_list=dataset.label_names, - max_seq_length=self._seq_len, - tokenizer=self._tokenizer, - output_file=tfrecord_file) - preprocessed_ds = _single_file_dataset(tfrecord_file, - self._get_name_to_features()) + features = collections.OrderedDict() + features["input_ids"] = create_int_feature(feature["input_ids"]) + features["input_mask"] = create_int_feature(feature["input_mask"]) + features["segment_ids"] = create_int_feature(feature["segment_ids"]) + features["label_ids"] = create_int_feature([label.numpy()[0]]) + tf_example = tf.train.Example( + features=tf.train.Features(feature=features) + ) + writers[index % len(writers)].write(tf_example.SerializeToString()) + size = index + 1 + for writer in writers: + writer.close() + metadata = {"size": size, "label_names": dataset.label_names} + tfrecord_cache_files.save_metadata(metadata) + else: + print( + f"Using existing cache files at {tfrecord_cache_files.cache_prefix}" + ) + metadata = tfrecord_cache_files.load_metadata() + size = metadata["size"] + label_names = metadata["label_names"] + preprocessed_ds = _tfrecord_dataset( + tfrecord_cache_files.tfrecord_files, self._get_name_to_features() + ) return text_classifier_ds.Dataset( dataset=preprocessed_ds, - size=dataset.size, - label_names=dataset.label_names) + size=size, + label_names=label_names, + tfrecord_cache_files=tfrecord_cache_files, + ) -TextClassifierPreprocessor = ( - Union[BertClassifierPreprocessor, - AverageWordEmbeddingClassifierPreprocessor]) +TextClassifierPreprocessor = Union[ + BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor +] diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py index 2ddc4aea9..ff9015498 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +13,17 @@ # limitations under the License. 
import csv +import io import os import tempfile from unittest import mock as unittest_mock +import mock import numpy as np import numpy.testing as npt import tensorflow as tf +from mediapipe.model_maker.python.core.data import cache_files from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds from mediapipe.model_maker.python.text.text_classifier import model_spec from mediapipe.model_maker.python.text.text_classifier import preprocessor @@ -88,7 +91,8 @@ class PreprocessorTest(tf.test.TestCase): bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=5, do_lower_case=bert_spec.do_lower_case, - uri=bert_spec.downloaded_files.get_path(), + uri=bert_spec.get_path(), + model_name=bert_spec.name, ) preprocessed_dataset = bert_preprocessor.preprocess(dataset) labels = [] @@ -97,18 +101,87 @@ class PreprocessorTest(tf.test.TestCase): self.assertEqual(label.shape, [1]) labels.append(label.numpy()[0]) self.assertSameElements( - features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']) + features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'] + ) for feature in features.values(): self.assertEqual(feature.shape, [1, 5]) input_masks.append(features['input_mask'].numpy()[0]) - npt.assert_array_equal(features['input_type_ids'].numpy()[0], - [0, 0, 0, 0, 0]) + npt.assert_array_equal( + features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0] + ) npt.assert_array_equal( - np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) + np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]) + ) self.assertEqual(labels, [1, 0]) + def test_bert_preprocessor_cache(self): + csv_file = self._get_csv_file() + dataset = text_classifier_ds.Dataset.from_csv( + filename=csv_file, + csv_params=self.CSV_PARAMS_, + cache_dir=self.get_temp_dir(), + ) + bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() + bert_preprocessor = preprocessor.BertClassifierPreprocessor( + seq_len=5, + do_lower_case=bert_spec.do_lower_case, + uri=bert_spec.get_path(), + model_name=bert_spec.name, + ) + ds_cache_files = dataset.tfrecord_cache_files + preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files( + ds_cache_files + ) + self.assertFalse(preprocessed_cache_files.is_cached()) + preprocessed_dataset = bert_preprocessor.preprocess(dataset) + self.assertTrue(preprocessed_cache_files.is_cached()) + self.assertEqual( + preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files + ) + + # The second time running preprocessor, it should load from cache directly + mock_stdout = io.StringIO() + with mock.patch('sys.stdout', mock_stdout): + _ = bert_preprocessor.preprocess(dataset) + self.assertEqual( + mock_stdout.getvalue(), + 'Using existing cache files at' + f' {preprocessed_cache_files.cache_prefix}\n', + ) + + def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case): + bert_preprocessor = preprocessor.BertClassifierPreprocessor( + seq_len=seq_len, + do_lower_case=do_lower_case, + uri=bert_spec.get_path(), + model_name=bert_spec.name, + ) + new_cf = bert_preprocessor._get_tfrecord_cache_files(cf) + return new_cf.cache_prefix_filename + + def test_bert_get_tfrecord_cache_files(self): + # Test to ensure regenerated cache_files have different prefixes + all_cf_prefixes = set() + cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='cache_prefix', + cache_dir=self.get_temp_dir(), + num_shards=1, + ) + mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() + 
all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True)) + all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True)) + all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False)) + new_cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='new_cache_prefix', + cache_dir=self.get_temp_dir(), + num_shards=1, + ) + all_cf_prefixes.add(self._get_new_prefix(new_cf, mobilebert_spec, 5, True)) + + # Each item of all_cf_prefixes should be unique. + self.assertLen(all_cf_prefixes, 4) + if __name__ == '__main__': # Load compressed models from tensorflow_hub - os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD index a581462cf..027bad7e6 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json index 24214a80d..22fb220fb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json @@ -16,8 +16,8 @@ } }, { - "name": "mask", - "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { @@ -27,8 +27,8 @@ } }, { - "name": "segment_ids", - "description": "0 for the first sequence, 1 for the second sequence if exists.", + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 3d932ce90..752752230 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,14 +19,18 @@ import tempfile from typing import Any, Optional, Sequence, Tuple import tensorflow as tf +from tensorflow_addons import optimizers as tfa_optimizers import tensorflow_hub as hub -from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import hub_loader +from mediapipe.model_maker.python.core.utils import loss_functions +from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds +from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_spec as ms from mediapipe.model_maker.python.text.text_classifier import preprocessor @@ -49,27 +53,34 @@ def _validate(options: text_classifier_options.TextClassifierOptions): if options.model_options is None: return - if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and - (options.supported_model != - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): - raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," - f" got {options.supported_model}") - if (isinstance(options.model_options, mo.BertModelOptions) and - (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)): + if isinstance( + options.model_options, mo.AverageWordEmbeddingModelOptions + ) and ( + options.supported_model + != ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER + ): raise ValueError( - f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}") + "Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," + f" got {options.supported_model}" + ) + if isinstance(options.model_options, mo.BertModelOptions) and ( + not isinstance(options.supported_model.value(), ms.BertClassifierSpec) + ): + raise ValueError( + f"Expected a Bert Classifier, got {options.supported_model}" + ) class TextClassifier(classifier.Classifier): """API for creating and training a text classification model.""" - def __init__(self, model_spec: Any, hparams: hp.BaseHParams, - label_names: Sequence[str]): + def __init__( + self, model_spec: Any, label_names: Sequence[str], shuffle: bool + ): super().__init__( - model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) + model_spec=model_spec, label_names=label_names, shuffle=shuffle + ) self._model_spec = model_spec - self._hparams = hparams - self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None @classmethod @@ -106,29 +117,39 @@ class TextClassifier(classifier.Classifier): if options.hparams is None: options.hparams = options.supported_model.value().hparams - if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER: - text_classifier = ( - _BertClassifier.create_bert_classifier(train_data, validation_data, - options, - train_data.label_names)) - elif (options.supported_model == - ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): - text_classifier = ( - _AverageWordEmbeddingClassifier - .create_average_word_embedding_classifier(train_data, validation_data, - options, - train_data.label_names)) + if 
isinstance(options.supported_model.value(), ms.BertClassifierSpec): + text_classifier = _BertClassifier.create_bert_classifier( + train_data, validation_data, options + ) + elif isinstance( + options.supported_model.value(), ms.AverageWordEmbeddingClassifierSpec + ): + text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier( + train_data, validation_data, options + ) else: raise ValueError(f"Unknown model {options.supported_model}") return text_classifier - def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any: + def evaluate( + self, + data: ds.Dataset, + batch_size: int = 32, + desired_precisions: Optional[Sequence[float]] = None, + desired_recalls: Optional[Sequence[float]] = None, + ) -> Any: """Overrides Classifier.evaluate(). Args: data: Evaluation dataset. Must be a TextClassifier Dataset. batch_size: Number of samples per evaluation step. + desired_precisions: If specified, adds a RecallAtPrecision metric per + desired_precisions[i] entry which tracks the recall given the constraint + on precision. Only supported for binary classification. + desired_recalls: If specified, adds a PrecisionAtRecall metric per + desired_recalls[i] entry which tracks the precision given the constraint + on recall. Only supported for binary classification. Returns: The loss value and accuracy. @@ -144,7 +165,28 @@ class TextClassifier(classifier.Classifier): processed_data = self._text_preprocessor.preprocess(data) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) - return self._model.evaluate(dataset) + + with self._hparams.get_strategy().scope(): + return self._model.evaluate(dataset) + + def save_model( + self, + model_name: str = "saved_model", + ): + """Saves the model in SavedModel format. + + For more information, see https://www.tensorflow.org/guide/saved_model. + + Args: + model_name: Name of the saved model. + """ + tf.io.gfile.makedirs(self._hparams.export_dir) + saved_model_file = os.path.join(self._hparams.export_dir, model_name) + self._model.save( + saved_model_file, + include_optimizer=False, + save_format="tf", + ) def export_model( self, @@ -161,20 +203,23 @@ class TextClassifier(classifier.Classifier): path is {self._hparams.export_dir}/{model_name}. quantization_config: The configuration for model quantization. 
""" - if not tf.io.gfile.exists(self._hparams.export_dir): - tf.io.gfile.makedirs(self._hparams.export_dir) + tf.io.gfile.makedirs(self._hparams.export_dir) tflite_file = os.path.join(self._hparams.export_dir, model_name) metadata_file = os.path.join(self._hparams.export_dir, "metadata.json") - tflite_model = model_util.convert_to_tflite( - model=self._model, quantization_config=quantization_config) + self.save_model(model_name="saved_model") + saved_model_file = os.path.join(self._hparams.export_dir, "saved_model") + + tflite_model = model_util.convert_to_tflite_from_file( + saved_model_file, quantization_config=quantization_config + ) vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt") self._save_vocab(vocab_filepath) writer = self._get_metadata_writer(tflite_model, vocab_filepath) tflite_model_with_metadata, metadata_json = writer.populate() model_util.save_tflite(tflite_model_with_metadata, tflite_file) - with open(metadata_file, "w") as f: + with tf.io.gfile.GFile(metadata_file, "w") as f: f.write(metadata_json) @abc.abstractmethod @@ -191,28 +236,39 @@ class _AverageWordEmbeddingClassifier(TextClassifier): _DELIM_REGEX_PATTERN = r"[^\w\']+" - def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec, - model_options: mo.AverageWordEmbeddingModelOptions, - hparams: hp.BaseHParams, label_names: Sequence[str]): - super().__init__(model_spec, hparams, label_names) + def __init__( + self, + model_spec: ms.AverageWordEmbeddingClassifierSpec, + model_options: mo.AverageWordEmbeddingModelOptions, + hparams: hp.AverageWordEmbeddingHParams, + label_names: Sequence[str], + ): + super().__init__(model_spec, label_names, hparams.shuffle) self._model_options = model_options + self._hparams = hparams + self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._loss_function = "sparse_categorical_crossentropy" - self._metric_function = "accuracy" + self._metric_functions = [ + "accuracy", + metrics.SparsePrecision(name="precision", dtype=tf.float32), + metrics.SparseRecall(name="recall", dtype=tf.float32), + ] self._text_preprocessor: ( preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None @classmethod def create_average_word_embedding_classifier( - cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, + cls, + train_data: text_ds.Dataset, + validation_data: text_ds.Dataset, options: text_classifier_options.TextClassifierOptions, - label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier": + ) -> "_AverageWordEmbeddingClassifier": """Creates, trains, and returns an Average Word Embedding classifier. Args: train_data: Training data. validation_data: Validation data. options: Options for creating and training the text classifier. - label_names: Label names used in the data. Returns: An Average Word Embedding classifier. 
@@ -306,28 +362,37 @@ class _BertClassifier(TextClassifier):
 
   _INITIALIZER_RANGE = 0.02
 
-  def __init__(self, model_spec: ms.BertClassifierSpec,
-               model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
-               label_names: Sequence[str]):
-    super().__init__(model_spec, hparams, label_names)
+  def __init__(
+      self,
+      model_spec: ms.BertClassifierSpec,
+      model_options: mo.BertModelOptions,
+      hparams: hp.BertHParams,
+      label_names: Sequence[str],
+  ):
+    super().__init__(model_spec, label_names, hparams.shuffle)
+    self._hparams = hparams
+    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._model_options = model_options
-    self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-    self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
-        "test_accuracy", dtype=tf.float32)
     self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
+    with self._hparams.get_strategy().scope():
+      self._loss_function = loss_functions.SparseFocalLoss(
+          self._hparams.gamma, self._num_classes
+      )
+      self._metric_functions = self._create_metrics()
 
   @classmethod
   def create_bert_classifier(
-      cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
+      cls,
+      train_data: text_ds.Dataset,
+      validation_data: text_ds.Dataset,
       options: text_classifier_options.TextClassifierOptions,
-      label_names: Sequence[str]) -> "_BertClassifier":
+  ) -> "_BertClassifier":
     """Creates, trains, and returns a BERT-based classifier.
 
     Args:
       train_data: Training data.
       validation_data: Validation data.
       options: Options for creating and training the text classifier.
-      label_names: Label names used in the data.
 
     Returns:
       A BERT-based classifier.
@@ -350,8 +415,9 @@ class _BertClassifier(TextClassifier):
     """
     (processed_train_data, processed_validation_data) = (
        self._load_and_run_preprocessor(train_data, validation_data))
-    self._create_model()
-    self._create_optimizer(processed_train_data)
+    with self._hparams.get_strategy().scope():
+      self._create_model()
+      self._create_optimizer(processed_train_data)
     self._train_model(processed_train_data, processed_validation_data)
 
   def _load_and_run_preprocessor(
@@ -369,10 +435,60 @@ class _BertClassifier(TextClassifier):
     self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
         seq_len=self._model_options.seq_len,
         do_lower_case=self._model_spec.do_lower_case,
-        uri=self._model_spec.downloaded_files.get_path(),
+        uri=self._model_spec.get_path(),
+        model_name=self._model_spec.name,
     )
-    return (self._text_preprocessor.preprocess(train_data),
-            self._text_preprocessor.preprocess(validation_data))
+    return (
+        self._text_preprocessor.preprocess(train_data),
+        self._text_preprocessor.preprocess(validation_data),
+    )
+
+  def _create_metrics(self):
+    """Creates metrics for training and evaluation.
+
+    The default metrics are accuracy, precision, and recall.
+
+    For binary classification tasks only (num_classes=2):
+    Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
+    the desired_precisions and desired_recalls fields in BertHParams.
+
+    Returns:
+      A list of tf.keras.Metric subclasses which can be used with model.compile
+    """
+    metric_functions = [
+        tf.keras.metrics.SparseCategoricalAccuracy(
+            "accuracy", dtype=tf.float32
+        ),
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
+    if self._num_classes == 2:
+      if self._hparams.desired_precisions:
+        for desired_precision in self._hparams.desired_precisions:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_precision,
+                  name=f"recall_at_precision_{desired_precision}",
+                  num_thresholds=1000,
+              )
+          )
+      if self._hparams.desired_recalls:
+        for desired_recall in self._hparams.desired_recalls:
+          metric_functions.append(
+              metrics.BinarySparsePrecisionAtRecall(
+                  desired_recall,
+                  name=f"precision_at_recall_{desired_recall}",
+                  num_thresholds=1000,
+              )
+          )
+    else:
+      if self._hparams.desired_precisions or self._hparams.desired_recalls:
+        raise ValueError(
+            "desired_recalls and desired_precisions parameters are binary"
+            " metrics and not supported for num_classes > 2. Found"
+            f" num_classes: {self._num_classes}"
+        )
+    return metric_functions
 
   def _create_model(self):
     """Creates a BERT-based classifier model.
 
@@ -382,30 +498,58 @@
     """
     encoder_inputs = dict(
         input_word_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_word_ids",
+        ),
         input_mask=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_mask",
+        ),
         input_type_ids=tf.keras.layers.Input(
-            shape=(self._model_options.seq_len,), dtype=tf.int32),
+            shape=(self._model_options.seq_len,),
+            dtype=tf.int32,
+            name="input_type_ids",
+        ),
     )
-    encoder = hub.KerasLayer(
-        self._model_spec.downloaded_files.get_path(),
-        trainable=self._model_options.do_fine_tuning,
-    )
-    encoder_outputs = encoder(encoder_inputs)
-    pooled_output = encoder_outputs["pooled_output"]
+    if self._model_spec.is_tf2:
+      encoder = hub.KerasLayer(
+          self._model_spec.get_path(),
+          trainable=self._model_options.do_fine_tuning,
+          load_options=tf.saved_model.LoadOptions(
+              experimental_io_device="/job:localhost"
+          ),
+      )
+      encoder_outputs = encoder(encoder_inputs)
+      pooled_output = encoder_outputs["pooled_output"]
+    else:
+      renamed_inputs = dict(
+          input_ids=encoder_inputs["input_word_ids"],
+          input_mask=encoder_inputs["input_mask"],
+          segment_ids=encoder_inputs["input_type_ids"],
+      )
+      encoder = hub_loader.HubKerasLayerV1V2(
+          self._model_spec.get_path(),
+          signature="tokens",
+          output_key="pooled_output",
+          trainable=self._model_options.do_fine_tuning,
+      )
+      pooled_output = encoder(renamed_inputs)
 
     output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
-        pooled_output)
+        pooled_output
+    )
     initializer = tf.keras.initializers.TruncatedNormal(
-        stddev=self._INITIALIZER_RANGE)
+        stddev=self._INITIALIZER_RANGE
+    )
     output = tf.keras.layers.Dense(
         self._num_classes,
         kernel_initializer=initializer,
         name="output",
         activation="softmax",
-        dtype=tf.float32)(
-            output)
+        dtype=tf.float32,
+    )(output)
     self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
 
   def _create_optimizer(self, train_data: text_ds.Dataset):
@@ -428,18 +572,38 @@ class _BertClassifier(TextClassifier):
     lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
         initial_learning_rate=initial_lr,
         decay_steps=total_steps,
-        end_learning_rate=0.0,
-        power=1.0)
+        end_learning_rate=self._hparams.end_learning_rate,
+        power=1.0,
+    )
     if warmup_steps:
       lr_schedule = model_util.WarmUp(
           initial_learning_rate=initial_lr,
           decay_schedule_fn=lr_schedule,
-          warmup_steps=warmup_steps)
-
-    self._optimizer = tf.keras.optimizers.experimental.AdamW(
-        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
-    self._optimizer.exclude_from_weight_decay(
-        var_names=["LayerNorm", "layer_norm", "bias"])
+          warmup_steps=warmup_steps,
+      )
+    if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
+      self._optimizer = tf.keras.optimizers.experimental.AdamW(
+          lr_schedule,
+          weight_decay=self._hparams.weight_decay,
+          epsilon=1e-6,
+          global_clipnorm=1.0,
+      )
+      self._optimizer.exclude_from_weight_decay(
+          var_names=["LayerNorm", "layer_norm", "bias"]
+      )
+    elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
+      self._optimizer = tfa_optimizers.LAMB(
+          lr_schedule,
+          weight_decay_rate=self._hparams.weight_decay,
+          epsilon=1e-6,
+          exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+          global_clipnorm=1.0,
+      )
+    else:
+      raise ValueError(
+          "BertHParams.optimizer must be set to ADAMW or "
+          f"LAMB. Got {self._hparams.optimizer}."
+      )
 
   def _save_vocab(self, vocab_filepath: str):
     tf.io.gfile.copy(
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py
index 08f4c2ad3..b646a15ad 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -66,14 +66,16 @@ def run(data_dir,
   quantization_config = None
   if (supported_model ==
       text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
-    hparams = text_classifier.HParams(
-        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
+    hparams = text_classifier.AverageWordEmbeddingHParams(
+        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
+    )
   # Warning: This takes extremely long to run on CPU
   elif (
       supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
     quantization_config = quantization.QuantizationConfig.for_dynamic()
-    hparams = text_classifier.HParams(
-        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
+    hparams = text_classifier.BertHParams(
+        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
+    )
 
   # Fine-tunes the model.
   options = text_classifier.TextClassifierOptions(
@@ -82,8 +84,8 @@ def run(data_dir,
                                                   options)
 
   # Gets evaluation results.
-  _, acc = model.evaluate(validation_data)
-  print('Eval accuracy: %f' % acc)
+  metrics = model.evaluate(validation_data)
+  print('Eval accuracy: %f' % metrics[1])
 
   model.export_model(quantization_config=quantization_config)
   model.export_labels(export_dir=options.hparams.export_dir)
diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py
index a02f17347..b61731f16 100644
--- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py
+++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_options.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. 
All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ import dataclasses from typing import Optional -from mediapipe.model_maker.python.core import hyperparameters as hp +from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp from mediapipe.model_maker.python.text.text_classifier import model_options as mo from mediapipe.model_maker.python.text.text_classifier import model_spec as ms @@ -34,5 +34,5 @@ class TextClassifierOptions: architecture of the `supported_model`. """ supported_model: ms.SupportedModels - hparams: Optional[hp.BaseHParams] = None + hparams: Optional[hp.HParams] = None model_options: Optional[mo.TextClassifierModelOptions] = None diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 6aa68a284..122182ddd 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,17 +16,17 @@ import csv import filecmp import os import tempfile -import unittest from unittest import mock as unittest_mock +from absl.testing import parameterized import tensorflow as tf +from mediapipe.model_maker.python.core.utils import loss_functions from mediapipe.model_maker.python.text import text_classifier from mediapipe.tasks.python.test import test_utils -@unittest.skip('b/275624089') -class TextClassifierTest(tf.test.TestCase): +class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( test_utils.get_test_data_path('average_word_embedding_metadata.json')) @@ -66,18 +66,20 @@ class TextClassifierTest(tf.test.TestCase): def test_create_and_train_average_word_embedding_model(self): train_data, validation_data = self._get_data() - options = ( - text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels - .AVERAGE_WORD_EMBEDDING_CLASSIFIER), - hparams=text_classifier.HParams( - epochs=1, batch_size=1, learning_rate=0))) + options = text_classifier.TextClassifierOptions( + supported_model=( + text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER + ), + hparams=text_classifier.AverageWordEmbeddingHParams( + epochs=1, batch_size=1, learning_rate=0 + ), + ) average_word_embedding_classifier = ( text_classifier.TextClassifier.create(train_data, validation_data, options)) - _, accuracy = average_word_embedding_classifier.evaluate(validation_data) - self.assertGreaterEqual(accuracy, 0.0) + metrics = average_word_embedding_classifier.evaluate(validation_data) + self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy # Test export_model average_word_embedding_classifier.export_model() @@ -96,24 +98,36 @@ class TextClassifierTest(tf.test.TestCase): filecmp.cmp( output_metadata_file, self._AVERAGE_WORD_EMBEDDING_JSON_FILE, - shallow=False)) + shallow=False, + ) + ) - def test_create_and_train_bert(self): + @parameterized.named_parameters( + # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089 + dict( + testcase_name='mobilebert', + 
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, + ), + ) + def test_create_and_train_bert(self, supported_model): train_data, validation_data = self._get_data() options = text_classifier.TextClassifierOptions( - supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, + supported_model=supported_model, model_options=text_classifier.BertModelOptions( - do_fine_tuning=False, seq_len=2), - hparams=text_classifier.HParams( + do_fine_tuning=False, seq_len=2 + ), + hparams=text_classifier.BertHParams( epochs=1, batch_size=1, learning_rate=3e-5, - distribution_strategy='off')) + distribution_strategy='off', + ), + ) bert_classifier = text_classifier.TextClassifier.create( train_data, validation_data, options) - _, accuracy = bert_classifier.evaluate(validation_data) - self.assertGreaterEqual(accuracy, 0.0) + metrics = bert_classifier.evaluate(validation_data) + self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy # Test export_model bert_classifier.export_model() @@ -137,45 +151,93 @@ class TextClassifierTest(tf.test.TestCase): ) def test_label_mismatch(self): - options = ( - text_classifier.TextClassifierOptions( - supported_model=( - text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER))) + options = text_classifier.TextClassifierOptions( + supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER) + ) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) - train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo']) + train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1) validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) - validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar']) + validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1) with self.assertRaisesRegex( ValueError, - 'Training data label names .* not equal to validation data label names' + 'Training data label names .* not equal to validation data label names', ): - text_classifier.TextClassifier.create(train_data, validation_data, - options) + text_classifier.TextClassifier.create( + train_data, validation_data, options + ) def test_options_mismatch(self): train_data, validation_data = self._get_data() - avg_options = ( - text_classifier.TextClassifierOptions( - supported_model=( - text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), - model_options=text_classifier.AverageWordEmbeddingModelOptions())) - with self.assertRaisesRegex( - ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' - ' SupportedModels.MOBILEBERT_CLASSIFIER'): - text_classifier.TextClassifier.create(train_data, validation_data, - avg_options) + avg_options = text_classifier.TextClassifierOptions( + supported_model=(text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), + model_options=text_classifier.AverageWordEmbeddingModelOptions(), + ) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' + ' SupportedModels.MOBILEBERT_CLASSIFIER', + ): + text_classifier.TextClassifier.create( + train_data, validation_data, avg_options + ) - bert_options = ( - text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels - .AVERAGE_WORD_EMBEDDING_CLASSIFIER), - model_options=text_classifier.BertModelOptions())) - with self.assertRaisesRegex( - ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' - ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): - text_classifier.TextClassifier.create(train_data, validation_data, - bert_options) + 
bert_options = text_classifier.TextClassifierOptions( + supported_model=( + text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER + ), + model_options=text_classifier.BertModelOptions(), + ) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Expected a Bert Classifier, got' + ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER', + ): + text_classifier.TextClassifier.create( + train_data, validation_data, bert_options + ) + + def test_bert_loss_and_metrics_creation(self): + train_data, validation_data = self._get_data() + supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER + hparams = text_classifier.BertHParams( + desired_recalls=[0.2], + desired_precisions=[0.9], + epochs=1, + batch_size=1, + learning_rate=3e-5, + distribution_strategy='off', + gamma=3.5, + ) + options = text_classifier.TextClassifierOptions( + supported_model=supported_model, hparams=hparams + ) + bert_classifier = text_classifier.TextClassifier.create( + train_data, validation_data, options + ) + loss_fn = bert_classifier._loss_function + self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss) + self.assertEqual(loss_fn._gamma, 3.5) + self.assertEqual(loss_fn._num_classes, 2) + metric_names = [m.name for m in bert_classifier._metric_functions] + expected_metric_names = [ + 'accuracy', + 'recall', + 'precision', + 'precision_at_recall_0.2', + 'recall_at_precision_0.9', + ] + self.assertCountEqual(metric_names, expected_metric_names) + + # Non-binary data + tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) + data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'desired_recalls and desired_precisions parameters are binary metrics' + ' and not supported for num_classes > 2. Found num_classes: 3', + ): + text_classifier.TextClassifier.create(data, data, options) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/vision/BUILD b/mediapipe/model_maker/python/vision/BUILD index 10aef8c33..b7d0d13a6 100644 --- a/mediapipe/model_maker/python/vision/BUILD +++ b/mediapipe/model_maker/python/vision/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/__init__.py b/mediapipe/model_maker/python/vision/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/vision/__init__.py +++ b/mediapipe/model_maker/python/vision/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/core/BUILD b/mediapipe/model_maker/python/vision/core/BUILD index 6dd547ff1..a5d4b08f8 100644 --- a/mediapipe/model_maker/python/vision/core/BUILD +++ b/mediapipe/model_maker/python/vision/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
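For reference on what `test_bert_loss_and_metrics_creation` above asserts: focal loss scales cross entropy by (1 - p_t)^gamma, down-weighting examples the model already classifies confidently. A self-contained sketch of that standard formulation (the textbook loss from Lin et al., not necessarily MediaPipe's exact `SparseFocalLoss` internals, which this diff does not show):

```python
import tensorflow as tf


def sparse_focal_loss(labels: tf.Tensor, probs: tf.Tensor, gamma: float = 2.0) -> tf.Tensor:
  """Standard sparse focal loss over softmax probabilities.

  labels: [batch] int class ids; probs: [batch, num_classes] softmax outputs.
  gamma=0 recovers ordinary sparse categorical cross entropy.
  """
  p_t = tf.gather(probs, labels, batch_dims=1)  # probability of the true class
  p_t = tf.clip_by_value(p_t, 1e-7, 1.0)        # avoid log(0)
  return tf.reduce_mean(-((1.0 - p_t) ** gamma) * tf.math.log(p_t))


# gamma=3.5 matches the hparams the test constructs above.
print(sparse_focal_loss(tf.constant([1, 0]),
                        tf.constant([[0.2, 0.8], [0.9, 0.1]]), gamma=3.5))
```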
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,9 +17,9 @@ licenses(["notice"]) -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = [ + "//mediapipe:__subpackages__", +]) py_library( name = "image_preprocessing", diff --git a/mediapipe/model_maker/python/vision/core/__init__.py b/mediapipe/model_maker/python/vision/core/__init__.py index 7ca2f9216..5b1a4244c 100644 --- a/mediapipe/model_maker/python/vision/core/__init__.py +++ b/mediapipe/model_maker/python/vision/core/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/core/image_preprocessing.py b/mediapipe/model_maker/python/vision/core/image_preprocessing.py index 104ccd9ca..84d486347 100644 --- a/mediapipe/model_maker/python/vision/core/image_preprocessing.py +++ b/mediapipe/model_maker/python/vision/core/image_preprocessing.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py b/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py index 0594b4376..563b19cfc 100644 --- a/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py +++ b/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/core/image_utils.py b/mediapipe/model_maker/python/vision/core/image_utils.py index 80d0616e5..0562da44e 100644 --- a/mediapipe/model_maker/python/vision/core/image_utils.py +++ b/mediapipe/model_maker/python/vision/core/image_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/core/image_utils_test.py b/mediapipe/model_maker/python/vision/core/image_utils_test.py index 84101113c..a89ff20a2 100644 --- a/mediapipe/model_maker/python/vision/core/image_utils_test.py +++ b/mediapipe/model_maker/python/vision/core/image_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/core/test_utils.py b/mediapipe/model_maker/python/vision/core/test_utils.py index 528b2ca7b..8cfe30811 100644 --- a/mediapipe/model_maker/python/vision/core/test_utils.py +++ b/mediapipe/model_maker/python/vision/core/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/face_stylizer/BUILD b/mediapipe/model_maker/python/vision/face_stylizer/BUILD index b5e0399d1..5e4c22454 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/BUILD +++ b/mediapipe/model_maker/python/vision/face_stylizer/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,16 +14,16 @@ # Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python GPU test rule. licenses(["notice"]) package(default_visibility = ["//mediapipe:__subpackages__"]) -filegroup( - name = "testdata", - srcs = glob([ - "testdata/**", - ]), +py_library( + name = "constants", + srcs = ["constants.py"], + deps = ["//mediapipe/model_maker/python/core/utils:file_util"], ) py_library( @@ -37,6 +37,7 @@ py_library( py_library( name = "model_options", srcs = ["model_options.py"], + deps = ["//mediapipe/model_maker/python/core/utils:loss_functions"], ) py_library( @@ -64,19 +65,41 @@ py_library( name = "dataset", srcs = ["dataset.py"], deps = [ + ":constants", "//mediapipe/model_maker/python/core/data:classification_dataset", - "//mediapipe/model_maker/python/vision/core:image_utils", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/vision:face_aligner", ], ) -py_test( - name = "dataset_test", - srcs = ["dataset_test.py"], - data = [ - ":testdata", - ], +py_library( + name = "face_stylizer", + srcs = ["face_stylizer.py"], deps = [ - ":dataset", - "//mediapipe/tasks/python/test:test_utils", + ":constants", + ":face_stylizer_options", + ":hyperparameters", + ":model_options", + ":model_spec", + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/core/utils:loss_functions", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/vision/core:image_preprocessing", + "//mediapipe/tasks/python/metadata/metadata_writers:face_stylizer", + ], +) + +py_library( + name = "face_stylizer_import", + srcs = ["__init__.py"], + visibility = ["//visibility:public"], + deps = [ + ":dataset", + ":face_stylizer", + ":face_stylizer_options", + ":hyperparameters", + ":model_options", + ":model_spec", ], ) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/__init__.py b/mediapipe/model_maker/python/vision/face_stylizer/__init__.py index e935c0c76..cc5bd2ae8 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/__init__.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""MediaPipe Model Maker Python Public API For Face Stylization.""" + +from mediapipe.model_maker.python.vision.face_stylizer import dataset +from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer +from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options +from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters +from mediapipe.model_maker.python.vision.face_stylizer import model_options +from mediapipe.model_maker.python.vision.face_stylizer import model_spec + +FaceStylizer = face_stylizer.FaceStylizer +SupportedModels = model_spec.SupportedModels +ModelOptions = model_options.FaceStylizerModelOptions +HParams = hyperparameters.HParams +Dataset = dataset.Dataset +FaceStylizerOptions = face_stylizer_options.FaceStylizerOptions diff --git a/mediapipe/model_maker/python/vision/face_stylizer/constants.py b/mediapipe/model_maker/python/vision/face_stylizer/constants.py new file mode 100644 index 000000000..ac7675232 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/constants.py @@ -0,0 +1,51 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Face stylizer model constants.""" + +from mediapipe.model_maker.python.core.utils import file_util + +# TODO: Move model files to GCS for downloading. +FACE_STYLIZER_ENCODER_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/encoder', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_encoder.tar.gz', + is_folder=True, +) +FACE_STYLIZER_DECODER_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/decoder', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_decoder.tar.gz', + is_folder=True, +) +FACE_STYLIZER_MAPPING_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/mapping', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_mapping.tar.gz', + is_folder=True, +) +FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES = file_util.DownloadedFiles( + 'face_stylizer/discriminator', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_discriminator.tar.gz', + is_folder=True, +) +FACE_STYLIZER_W_FILES = file_util.DownloadedFiles( + 'face_stylizer/w_avg.npy', + 'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy', +) + +FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles( + 'face_stylizer/face_landmarker_v2.task', + 'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task', + is_folder=False, +) + +# Dimension of the input style vector to the decoder +STYLE_DIM = 512 diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py index b6c85d6f3..eb324028a 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +13,43 @@ # limitations under the License. """Face stylizer dataset library.""" +from typing import Sequence import logging import os import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset -from mediapipe.model_maker.python.vision.core import image_utils +from mediapipe.model_maker.python.vision.face_stylizer import constants +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.vision import face_aligner + + +def _preprocess_face_dataset( + all_image_paths: Sequence[str], +) -> Sequence[tf.Tensor]: + """Preprocesses the face image dataset by aligning the faces.""" + path = constants.FACE_ALIGNER_TASK_FILES.get_path() + base_options = base_options_module.BaseOptions(model_asset_path=path) + options = face_aligner.FaceAlignerOptions(base_options=base_options) + aligner = face_aligner.FaceAligner.create_from_options(options) + + preprocessed_images = [] + for path in all_image_paths: + tf.compat.v1.logging.info('Preprocess image %s', path) + image = image_module.Image.create_from_file(path) + aligned_image = aligner.align(image) + if aligned_image is None: + raise ValueError( + 'ERROR: Invalid image. No face is detected and aligned. Please make' + ' sure the image has a single face that is facing forward and' + ' not significantly rotated.' + ) + aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view()) + preprocessed_images.append(aligned_image_tensor) + + return preprocessed_images # TODO: Change to an unlabeled dataset if it makes sense. @@ -27,72 +57,41 @@ class Dataset(classification_dataset.ClassificationDataset): """Dataset library for face stylizer fine tuning.""" @classmethod - def from_folder( - cls, dirname: str + def from_image( + cls, filename: str ) -> classification_dataset.ClassificationDataset: - """Loads images from the given directory. + """Creates a dataset from a single image. - The style image dataset directory is expected to contain one subdirectory - whose name represents the label of the style. There can be one or multiple - images of the same style in that subdirectory. Supported input image formats - include 'jpg', 'jpeg', 'png'. + Supported input image formats include 'jpg', 'jpeg', 'png'. Args: - dirname: Name of the directory containing the image files. + filename: Name of the image file. Returns: - Dataset containing images and labels and other related info. - Raises: - ValueError: if the input data directory is empty. + Dataset containing the image, its label, and other related info. """ - data_root = os.path.abspath(dirname) + file_path = os.path.abspath(filename) + image_filename = os.path.basename(filename) + image_name, ext_name = os.path.splitext(image_filename) - # Assumes the image data of the same label are in the same subdirectory, - # gets image path and label names.
- all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*')) - all_image_size = len(all_image_paths) - if all_image_size == 0: - raise ValueError('Invalid input data directory') - if not any( - fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths - ): - raise ValueError('No images found under given directory') + if not ext_name.endswith(('.jpg', '.jpeg', '.png')): + raise ValueError('Unsupported image formats: %s' % ext_name) - label_names = sorted( - name - for name in os.listdir(data_root) - if os.path.isdir(os.path.join(data_root, name)) - ) - all_label_size = len(label_names) - index_by_label = dict( - (name, index) for index, name in enumerate(label_names) - ) - # Get the style label from the subdirectory name. - all_image_labels = [ - index_by_label[os.path.basename(os.path.dirname(path))] - for path in all_image_paths - ] + image_data = _preprocess_face_dataset([file_path]) + label_names = [image_name] - path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) - - image_ds = path_ds.map( - image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE - ) + image_ds = tf.data.Dataset.from_tensor_slices(image_data) # Load label - label_ds = tf.data.Dataset.from_tensor_slices( - tf.cast(all_image_labels, tf.int64) - ) + label_ds = tf.data.Dataset.from_tensor_slices(tf.cast([0], tf.int64)) # Create a dataset of (image, label) pairs image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) - logging.info( - 'Load images dataset with size: %d, num_label: %d, labels: %s.', - all_image_size, - all_label_size, - ', '.join(label_names), - ) + logging.info('Create dataset for style: %s.', image_name) + return Dataset( - dataset=image_label_ds, size=all_image_size, label_names=label_names + dataset=image_label_ds, + label_names=label_names, + size=1, ) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py index a8af222d4..242140811 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import tensorflow as tf +from mediapipe.model_maker.python.vision.core import image_utils from mediapipe.model_maker.python.vision.face_stylizer import dataset from mediapipe.tasks.python.test import test_utils @@ -22,26 +24,24 @@ class DatasetTest(tf.test.TestCase): def setUp(self): super().setUp() - # TODO: Replace the stylize image dataset with licensed images. 
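To make the new single-image flow above concrete, here is a minimal usage sketch built from the public API re-exported in `__init__.py` and exercised by the tests below. The file path and epoch count are hypothetical placeholders; the class, enum, and method names come straight from this diff:

from mediapipe.model_maker.python.vision import face_stylizer

# One style image is enough: from_image aligns the face and uses the file
# name (minus extension) as the style label.
data = face_stylizer.Dataset.from_image(filename='/tmp/styles/cartoon.jpg')

options = face_stylizer.FaceStylizerOptions(
    model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256,
    hparams=face_stylizer.HParams(epochs=100, export_dir='/tmp/face_stylizer_out'),
)
model = face_stylizer.FaceStylizer.create(train_data=data, options=options)
model.export_model()  # writes face_stylizer.task and metadata.json to export_dir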
- self._test_data_dirname = 'testdata' - def test_from_folder(self): - input_data_dir = test_utils.get_test_data_path(self._test_data_dirname) - data = dataset.Dataset.from_folder(dirname=input_data_dir) - self.assertEqual(data.num_classes, 2) - self.assertEqual(data.label_names, ['cartoon', 'sketch']) - self.assertLen(data, 2) + def test_from_image(self): + test_image_file = 'input/style/cartoon/cartoon.jpg' + input_image_path = test_utils.get_test_data_path(test_image_file) + data = dataset.Dataset.from_image(filename=input_image_path) + self.assertEqual(data.num_classes, 1) + self.assertEqual(data.label_names, ['cartoon']) + self.assertLen(data, 1) - def test_from_folder_raise_value_error_for_invalid_path(self): - with self.assertRaisesRegex(ValueError, 'Invalid input data directory'): - dataset.Dataset.from_folder(dirname='invalid') + def test_from_image_raise_value_error_for_invalid_path(self): + with self.assertRaisesRegex(ValueError, 'Unsupported image formats: .zip'): + dataset.Dataset.from_image(filename='input/style/cartoon/cartoon.zip') - def test_from_folder_raise_value_error_for_valid_no_data_path(self): - input_data_dir = test_utils.get_test_data_path('face_stylizer') - with self.assertRaisesRegex( - ValueError, 'No images found under given directory' - ): - dataset.Dataset.from_folder(dirname=input_data_dir) + def test_from_image_raise_value_error_for_invalid_image(self): + with self.assertRaisesRegex(ValueError, 'Invalid image'): + test_image_file = 'input/style/sketch/boy-6030802_1280.jpg' + input_image_path = test_utils.get_test_data_path(test_image_file) + dataset.Dataset.from_image(filename=input_image_path) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py new file mode 100644 index 000000000..b23a6e498 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer.py @@ -0,0 +1,317 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""APIs to train face stylization model.""" + +import logging +import os +from typing import Any, Callable, Optional +import zipfile + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds +from mediapipe.model_maker.python.core.utils import loss_functions +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.vision.core import image_preprocessing +from mediapipe.model_maker.python.vision.face_stylizer import constants +from mediapipe.model_maker.python.vision.face_stylizer import face_stylizer_options +from mediapipe.model_maker.python.vision.face_stylizer import hyperparameters as hp +from mediapipe.model_maker.python.vision.face_stylizer import model_options as model_opt +from mediapipe.model_maker.python.vision.face_stylizer import model_spec as ms +from mediapipe.tasks.python.metadata.metadata_writers import face_stylizer as metadata_writer + +# Face detector model and face landmarks detector file names. +_FACE_DETECTOR_MODEL = 'face_detector.tflite' +_FACE_LANDMARKS_DETECTOR_MODEL = 'face_landmarks_detector.tflite' + +# The mean value used in the input tensor normalization for the face stylizer +# model. +_NORM_MEAN = 0.0 +_NORM_STD = 255.0 + + +class FaceStylizer(object): + """FaceStylizer for building face stylization model. + + Attributes: + w_avg: An average face latent code to regularize face generation in face + stylization. + """ + + def __init__( + self, + model_spec: ms.ModelSpec, + model_options: model_opt.FaceStylizerModelOptions, + hparams: hp.HParams, + ): + """Initializes face stylizer. + + Args: + model_spec: Specification for the model. + model_options: Model options for creating face stylizer. + hparams: The hyperparameters for training face stylizer. + """ + self._model_spec = model_spec + self._model_options = model_options + self._hparams = hparams + self._preprocessor = image_preprocessing.Preprocessor( + input_shape=self._model_spec.input_image_shape, + num_classes=1, + mean_rgb=self._model_spec.mean_rgb, + stddev_rgb=self._model_spec.stddev_rgb, + ) + + @classmethod + def create( + cls, + train_data: classification_ds.ClassificationDataset, + options: face_stylizer_options.FaceStylizerOptions, + ) -> 'FaceStylizer': + """Creates and trains a face stylizer with input datasets. + + Args: + train_data: The input style image dataset for training the face stylizer. + options: The options to configure face stylizer. + + Returns: + A FaceStylizer instant with the trained model. + """ + if options.model_options is None: + options.model_options = model_opt.FaceStylizerModelOptions() + + if options.hparams is None: + options.hparams = hp.HParams() + + spec = ms.SupportedModels.get(options.model) + + face_stylizer = cls( + model_spec=spec, + model_options=options.model_options, + hparams=options.hparams, + ) + face_stylizer._create_and_train_model(train_data) + return face_stylizer + + def stylize( + self, data: classification_ds.ClassificationDataset + ) -> classification_ds.ClassificationDataset: + """Stylizes the images represented by the input dataset. + + Args: + data: Dataset of input images, can contain multiple images. 
+ + Returns: + A dataset containing the stylized images. + """ + input_dataset = data.gen_tf_dataset(preprocess=self._preprocessor) + output_img_list = [] + for sample in input_dataset: + image = sample[0] + w = self._encoder(image, training=True) + x = self._decoder({'inputs': w + self.w_avg}, training=True) + output_batch = x['image'][-1] + output_img_tensor = (tf.squeeze(output_batch).numpy() + 1.0) * 127.5 + output_img_list.append(output_img_tensor) + + image_ds = tf.data.Dataset.from_tensor_slices(output_img_list) + + logging.info('Stylized %s images.', len(output_img_list)) + + return classification_ds.ClassificationDataset( + dataset=image_ds, + label_names=['stylized'], + size=len(output_img_list), + ) + + def _create_and_train_model( + self, train_data: classification_ds.ClassificationDataset + ): + """Creates and trains the face stylizer model. + + Args: + train_data: Training data. + """ + self._create_model() + self._train_model(train_data=train_data, preprocessor=self._preprocessor) + + def _create_model(self): + """Creates the components of the face stylizer.""" + self._encoder = model_util.load_keras_model( + constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path() + ) + self._decoder = model_util.load_keras_model( + constants.FACE_STYLIZER_DECODER_MODEL_FILES.get_path() + ) + self._mapping_network = model_util.load_keras_model( + constants.FACE_STYLIZER_MAPPING_MODEL_FILES.get_path() + ) + self._discriminator = model_util.load_keras_model( + constants.FACE_STYLIZER_DISCRIMINATOR_MODEL_FILES.get_path() + ) + with tf.io.gfile.GFile( + constants.FACE_STYLIZER_W_FILES.get_path(), 'rb' + ) as f: + w_avg = np.load(f) + + self.w_avg = w_avg[: self._model_spec.style_block_num][np.newaxis] + + def _train_model( + self, + train_data: classification_ds.ClassificationDataset, + preprocessor: Optional[Callable[..., Any]] = None, + ): + """Trains the face stylizer model. + + Args: + train_data: The data for training the model. + preprocessor: The image preprocessor. + """ + train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor) + + # TODO: Support processing multiple input style images. The + # input style images are expected to have similar style. + # style_sample represents a tuple of (style_image, style_label).
+ style_sample = next(iter(train_dataset)) + style_img = style_sample[0] + + batch_size = self._hparams.batch_size + label_in = tf.zeros(shape=[batch_size, 0]) + + style_encoding = self._encoder(style_img, training=True) + self.w_avg + + optimizer = tf.keras.optimizers.Adam( + learning_rate=self._hparams.learning_rate, + beta_1=self._hparams.beta_1, + beta_2=self._hparams.beta_2, + ) + + image_perceptual_quality_loss = loss_functions.ImagePerceptualQualityLoss( + loss_weight=self._model_options.perception_loss_weight + ) + + for i in range(self._hparams.epochs): + noise = tf.random.normal(shape=[batch_size, constants.STYLE_DIM]) + + mean_w = self._mapping_network([noise, label_in], training=False)[ + :, : self._model_spec.style_block_num + ] + style_encodings = tf.tile(style_encoding, [batch_size, 1, 1]) + + in_latent = tf.Variable(tf.identity(style_encodings)) + + alpha = self._model_options.alpha + for swap_layer in self._model_options.swap_layers: + in_latent = in_latent[:, swap_layer].assign( + alpha * style_encodings[:, swap_layer] + + (1 - alpha) * mean_w[:, swap_layer] + ) + + with tf.GradientTape() as tape: + outputs = self._decoder({'inputs': in_latent.numpy()}, training=True) + gen_img = outputs['image'][-1] + + real_feature = self._discriminator( + [tf.transpose(style_img, [0, 3, 1, 2]), label_in] + ) + gen_feature = self._discriminator( + [tf.transpose(gen_img, [0, 3, 1, 2]), label_in] + ) + + style_loss = image_perceptual_quality_loss(gen_img, style_img) + style_loss += ( + tf.keras.losses.MeanAbsoluteError()(real_feature, gen_feature) + * self._model_options.adv_loss_weight + ) + print(f'Iteration {i} loss: {style_loss.numpy()}') + + tvars = self._decoder.trainable_variables + grads = tape.gradient(style_loss, tvars) + optimizer.apply_gradients(list(zip(grads, tvars))) + + def export_model(self, model_name: str = 'face_stylizer.task'): + """Converts the model to TFLite and exports as a model bundle file. + + Saves a model bundle file and metadata json file to hparams.export_dir. The + resulting model bundle file will contain necessary models for face + detection, face landmarks detection, and customized face stylization. Only + the model bundle file is needed for the downstream face stylization task. + The metadata.json file is saved only to interpret the contents of the model + bundle file. The face detection model and face landmarks detection model are + from https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task + and the customized face stylization model is trained in this library. + + Args: + model_name: Face stylizer model bundle file name. The full export path is + {self._hparams.export_dir}/{model_name}. + """ + if not tf.io.gfile.exists(self._hparams.export_dir): + tf.io.gfile.makedirs(self._hparams.export_dir) + model_bundle_file = os.path.join(self._hparams.export_dir, model_name) + metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json') + + # Create an end-to-end model by concatenating encoder and decoder + inputs = tf.keras.Input(shape=(256, 256, 3)) + x = self._encoder(inputs, training=True) + x = self._decoder({'inputs': x + self.w_avg}, training=True) + x = x['image'][-1] + # Scale the data range from [-1, 1] to [0, 1] to support running inference + # on both CPU and GPU. 
+ outputs = (x + 1.0) / 2.0 + model = tf.keras.Model(inputs=inputs, outputs=outputs) + + face_stylizer_model_buffer = model_util.convert_to_tflite( + model=model, + quantization_config=None, + supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,), + preprocess=self._preprocessor, + allow_custom_ops=True, + ) + + face_aligner_task_file_path = constants.FACE_ALIGNER_TASK_FILES.get_path() + + with zipfile.ZipFile(face_aligner_task_file_path, 'r') as zf: + file_list = zf.namelist() + if _FACE_DETECTOR_MODEL not in file_list: + raise ValueError( + '{0} is not packed in face aligner task file'.format( + _FACE_DETECTOR_MODEL + ) + ) + if _FACE_LANDMARKS_DETECTOR_MODEL not in file_list: + raise ValueError( + '{0} is not packed in face aligner task file'.format( + _FACE_LANDMARKS_DETECTOR_MODEL + ) + ) + + with zf.open(_FACE_DETECTOR_MODEL) as f: + face_detector_model_buffer = f.read() + + with zf.open(_FACE_LANDMARKS_DETECTOR_MODEL) as f: + face_landmarks_detector_model_buffer = f.read() + + writer = metadata_writer.MetadataWriter.create( + bytearray(face_stylizer_model_buffer), + bytearray(face_detector_model_buffer), + bytearray(face_landmarks_detector_model_buffer), + input_norm_mean=[_NORM_MEAN], + input_norm_std=[_NORM_STD], + ) + + model_bundle_content, metadata_json = writer.populate() + with open(model_bundle_file, 'wb') as f: + f.write(model_bundle_content) + with open(metadata_file, 'w') as f: + f.write(metadata_json) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py index e0e2441d1..90fd8eb38 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_options.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py new file mode 100644 index 000000000..bd44fe7f2 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/face_stylizer_test.py @@ -0,0 +1,124 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
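The core of `_train_model` above is the layer-swap interpolation: for every index in `swap_layers`, the latent fed to the decoder is a convex blend of the style encoding and the mean latent produced by the mapping network. Below is a standalone NumPy sketch of just that blend; the batch size and style block count are assumed shapes for illustration, and only the 512 latent width is fixed by `constants.STYLE_DIM`:

import numpy as np

alpha = 0.1                    # default FaceStylizerModelOptions.alpha
swap_layers = [4, 5, 10, 11]   # default FaceStylizerModelOptions.swap_layers
batch, style_blocks, style_dim = 4, 12, 512  # assumed shapes; 512 matches STYLE_DIM

style_encodings = np.random.randn(batch, style_blocks, style_dim)  # encoder output stand-in
mean_w = np.random.randn(batch, style_blocks, style_dim)           # mapping network stand-in

# Non-swapped layers keep the style encoding; swapped layers lean toward the
# mean latent, with alpha controlling how much style survives there.
in_latent = style_encodings.copy()
for k in swap_layers:
    in_latent[:, k] = alpha * style_encodings[:, k] + (1 - alpha) * mean_w[:, k]

With the default alpha of 0.1 the swapped layers are dominated by the mean latent, which plausibly explains why small alpha values keep the generated face realistic while the remaining layers carry the style.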
+ +import os +import zipfile + +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import test_util as mm_test_util +from mediapipe.model_maker.python.vision import face_stylizer +from mediapipe.tasks.python.test import test_utils + + +class FaceStylizerTest(tf.test.TestCase): + + def _create_training_dataset(self): + """Creates training dataset.""" + input_style_image_file = test_utils.get_test_data_path( + 'input/style/cartoon/cartoon.jpg' + ) + + data = face_stylizer.Dataset.from_image(filename=input_style_image_file) + return data + + def _create_eval_dataset(self): + """Create evaluation dataset.""" + input_test_image_file = test_utils.get_test_data_path( + 'input/raw/face/portrait.jpg' + ) + + data = face_stylizer.Dataset.from_image(filename=input_test_image_file) + return data + + def _evaluate_saved_model(self, model: face_stylizer.FaceStylizer): + """Evaluates the fine-tuned face stylizer model.""" + test_image = tf.ones(shape=(256, 256, 3), dtype=tf.float32) + test_image_batch = test_image[tf.newaxis] + in_latent = model._encoder(test_image_batch) + output = model._decoder({'inputs': in_latent + model.w_avg}) + self.assertEqual(output['image'][-1].shape, (1, 256, 256, 3)) + + def setUp(self): + super().setUp() + self._train_data = self._create_training_dataset() + self._eval_data = self._create_eval_dataset() + + def test_finetuning_face_stylizer_with_single_input_style_image(self): + with self.test_session(use_gpu=True): + face_stylizer_options = face_stylizer.FaceStylizerOptions( + model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256, + hparams=face_stylizer.HParams(epochs=1), + ) + model = face_stylizer.FaceStylizer.create( + train_data=self._train_data, options=face_stylizer_options + ) + self._evaluate_saved_model(model) + + def test_evaluate_face_stylizer(self): + with self.test_session(use_gpu=True): + face_stylizer_options = face_stylizer.FaceStylizerOptions( + model=face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256, + hparams=face_stylizer.HParams(epochs=1), + ) + model = face_stylizer.FaceStylizer.create( + train_data=self._train_data, options=face_stylizer_options + ) + eval_output = model.stylize(self._eval_data) + self.assertLen(eval_output, 1) + eval_output_data = eval_output.gen_tf_dataset() + iterator = iter(eval_output_data) + self.assertEqual(iterator.get_next().shape, (1, 256, 256, 3)) + + def test_export_face_stylizer_tflite_model(self): + with self.test_session(use_gpu=True): + model_enum = face_stylizer.SupportedModels.BLAZE_FACE_STYLIZER_256 + face_stylizer_options = face_stylizer.FaceStylizerOptions( + model=model_enum, + hparams=face_stylizer.HParams( + epochs=0, export_dir=self.get_temp_dir() + ), + ) + model = face_stylizer.FaceStylizer.create( + train_data=self._train_data, options=face_stylizer_options + ) + model.export_model() + model_bundle_file = os.path.join( + self.get_temp_dir(), 'face_stylizer.task' + ) + with zipfile.ZipFile(model_bundle_file) as zf: + self.assertEqual( + set(zf.namelist()), + set([ + 'face_detector.tflite', + 'face_landmarks_detector.tflite', + 'face_stylizer.tflite', + ]), + ) + zf.extractall(self.get_temp_dir()) + + face_stylizer_tflite_file = os.path.join( + self.get_temp_dir(), 'face_stylizer.tflite' + ) + spec = face_stylizer.SupportedModels.get(model_enum) + input_image_shape = spec.input_image_shape + input_tensor_shape = [1] + list(input_image_shape) + [3] + input_tensor = mm_test_util.create_random_sample(size=input_tensor_shape) + output = 
mm_test_util.run_tflite(face_stylizer_tflite_file, input_tensor) + self.assertTrue((output >= 0.0).all()) + self.assertTrue((output <= 1.0).all()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py b/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py index 0a129a721..6a2999a41 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/hyperparameters.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class HParams(hp.BaseHParams): """ # Parameters from BaseHParams class. - learning_rate: float = 5e-5 + learning_rate: float = 8e-4 batch_size: int = 4 epochs: int = 100 # Parameters for face stylizer. diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_options.py b/mediapipe/model_maker/python/vision/face_stylizer/model_options.py index 064e2d027..54cb1a8cd 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/model_options.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,15 @@ # limitations under the License. """Configurable model options for face stylizer models.""" +from typing import Sequence import dataclasses -from typing import List + +from mediapipe.model_maker.python.core.utils import loss_functions + + +def _default_perceptual_quality_loss_weight(): + """Default perceptual quality loss weight for face stylizer.""" + return loss_functions.PerceptualLossWeight(l1=0.5, content=4.0, style=1.0) # TODO: Add more detailed instructions about hyperparameter tuning. @@ -25,13 +32,21 @@ class FaceStylizerModelOptions: Attributes: swap_layers: The layers of feature to be interpolated between encoding features and StyleGAN input features. - alpha: Weighting coefficient for swapping layer interpolation. - adv_loss_weight: Weighting coeffcieint of adversarial loss versus perceptual + alpha: Weighting coefficient of style latent for swapping layer + interpolation. Its valid range is [0, 1]. A greater weight means + a stronger style is applied to the output image. Expect to set it to a small + value, e.g. < 0.1. + perception_loss_weight: Weighting coefficients of image perception quality loss. + adv_loss_weight: Weighting coefficient of adversarial loss versus image + perceptual quality loss. It expects a small value, e.g. < 0.2.
""" - swap_layers: List[int] = dataclasses.field( - default_factory=lambda: [4, 5, 6, 7, 8, 9, 10, 11] + swap_layers: Sequence[int] = dataclasses.field( + default_factory=lambda: [4, 5, 10, 11] ) - alpha: float = 1.0 - adv_loss_weight: float = 1.0 + alpha: float = 0.1 + perception_loss_weight: loss_functions.PerceptualLossWeight = ( + dataclasses.field(default_factory=_default_perceptual_quality_loss_weight) + ) + adv_loss_weight: float = 0.2 diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py b/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py index 6f5126f0b..1c528a63d 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_spec.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py b/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py index 8be3242ac..aef684d5b 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/model_spec_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png b/mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png deleted file mode 100644 index 87e9d3d8d..000000000 Binary files a/mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png and /dev/null differ diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/cartoon/cartoon.jpg b/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/cartoon/cartoon.jpg new file mode 100644 index 000000000..34f2d4c4a Binary files /dev/null and b/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/cartoon/cartoon.jpg differ diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/sketch/boy-6030802_1280.jpg b/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/sketch/boy-6030802_1280.jpg new file mode 100644 index 000000000..042f2c9d1 Binary files /dev/null and b/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/sketch/boy-6030802_1280.jpg differ diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/sketch/sketch.jpg b/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/sketch/sketch.jpg new file mode 100644 index 000000000..477be0bdb Binary files /dev/null and b/mediapipe/model_maker/python/vision/face_stylizer/testdata/input/style/sketch/sketch.jpg differ diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png b/mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png deleted file mode 100644 index 169192c96..000000000 Binary files a/mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png and /dev/null differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD 
index e96421593..969887e64 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +13,11 @@ # limitations under the License. # Placeholder for internal Python strict test compatibility macro. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. licenses(["notice"]) -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) # TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory. @@ -58,6 +56,7 @@ py_test( srcs = ["dataset_test.py"], data = [":testdata"], tags = [ + "not_run:arm", "notsan", "requires-net:external", ], @@ -143,6 +142,7 @@ py_test( data = [":testdata"], shard_count = 2, tags = [ + "not_run:arm", "notsan", "requires-net:external", ], diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py index a302e8d79..5d5c54813 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py b/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py index acd569d0e..71689bd7c 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/constants.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 70a363f1a..8e2095a33 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset): len(valid_hand_data), len(label_names), ','.join(label_names))) return Dataset( dataset=hand_embedding_label_ds, + label_names=label_names, size=len(valid_hand_data), - label_names=label_names) + ) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py index e9e7ddd06..a32905597 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved.s +# Copyright 2022 The MediaPipe Authors.s # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py index f009ef281..8335968b7 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier): self._model_options = model_options self._hparams = hparams self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma) - self._metric_function = 'categorical_accuracy' + self._metric_functions = ['categorical_accuracy'] self._optimizer = 'adam' self._callbacks = self._get_callbacks() self._history = None diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py index 1cf9f0619..0c1d57d2b 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py index da9e2d647..2d80da8e2 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
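The `_metric_function` to `_metric_functions` rename above suggests the shared classifier base now hands Keras a list of metrics rather than a single string, which would allow tracking several metrics at once. A hedged sketch of that assumed downstream wiring, not taken from this diff:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3, activation='softmax')])
metric_functions = ['categorical_accuracy']  # the new plural attribute's value
# Keras compile() accepts a list here, so more metrics could be appended later.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=metric_functions,
)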
diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 11b4f9759..41799af97 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py b/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py index fed62453b..f7cf9cf05 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/hyperparameters.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py index d6dc3ec2c..8ccfcb186 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index fd26b274d..794e7678b 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py index 1870437d4..a607fd6b6 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/model_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index f88616690..a9d91e845 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. # Placeholder for internal Python library rule. licenses(["notice"]) -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) ###################################################################### # Public target of the MediaPipe Model Maker ImageClassifier APIs. diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py index 4cde9e7e3..c9ab6faec 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/__init__.py +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset.py b/mediapipe/model_maker/python/vision/image_classifier/dataset.py index bf4bbc4b6..f627dfecc 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,28 +15,12 @@ import os import random - -from typing import List, Optional import tensorflow as tf -import tensorflow_datasets as tfds from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.vision.core import image_utils -def _create_data( - name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo, - label_names: List[str] -) -> Optional[classification_dataset.ClassificationDataset]: - """Creates a Dataset object from tfds data.""" - if name not in data: - return None - data = data[name] - data = data.map(lambda a: (a['image'], a['label'])) - size = info.splits[name].num_examples - return Dataset(data, size, label_names) - - class Dataset(classification_dataset.ClassificationDataset): """Dataset library for image classifier.""" @@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset): 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, all_label_size, ', '.join(label_names)) return Dataset( - dataset=image_label_ds, size=all_image_size, label_names=label_names) + dataset=image_label_ds, label_names=label_names, size=all_image_size + ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py index 1f290b327..33101382f 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase): def test_split(self): ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) - data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg']) + data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4) train_data, test_data = data.split(fraction=0.5) self.assertLen(train_data, 2) diff --git a/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py b/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py index 1d3bfdad2..5092ed370 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index c2181121c..8acf59f66 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier): self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._loss_function = tf.keras.losses.CategoricalCrossentropy( label_smoothing=self._hparams.label_smoothing) - self._metric_function = 'accuracy' + self._metric_functions = ['accuracy'] self._history = None # Training history returned from `keras_model.fit`. @classmethod diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py index f382e28aa..31b6e5876 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_options.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_options.py index d3566a9fa..bf3034c62 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_options.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index afda8643b..71a47d9eb 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. @@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): ds = tf.data.Dataset.from_generator( self._gen, (tf.uint8, tf.int64), (tf.TensorShape( [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([]))) - data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3, - ['cyan', 'magenta', 'yellow']) + data = image_classifier.Dataset( + ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3 + ) return data def setUp(self): diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_options.py b/mediapipe/model_maker/python/vision/image_classifier/model_options.py index a8f89b577..2e7c32df6 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_options.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_options.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py index d46cafe6b..7bc6aca8b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py index 63f360ab9..6b388cd13 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD b/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD index 37730ea91..3d778836a 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
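Several hunks above reorder the shared `ClassificationDataset` constructor arguments so that `label_names` now precedes `size`. A small sketch of the updated calling convention, mirroring the dataset test changes (the tensors and labels are illustrative):

import tensorflow as tf

from mediapipe.model_maker.python.vision.image_classifier import dataset

ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
# label_names is now the second parameter; size moved to third, so positional
# callers (as in image_classifier_test.py below) must swap the last two args.
data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4)
train_data, test_data = data.split(fraction=0.5)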
diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD index f3d4407d8..14d378a19 100644 --- a/mediapipe/model_maker/python/vision/object_detector/BUILD +++ b/mediapipe/model_maker/python/vision/object_detector/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. # Placeholder for internal Python strict test compatibility macro. licenses(["notice"]) -package( - default_visibility = ["//mediapipe:__subpackages__"], -) +package(default_visibility = ["//mediapipe:__subpackages__"]) py_library( name = "object_detector_import", @@ -56,6 +54,7 @@ py_library( srcs = ["dataset.py"], deps = [ ":dataset_util", + "//mediapipe/model_maker/python/core/data:cache_files", "//mediapipe/model_maker/python/core/data:classification_dataset", ], ) @@ -75,6 +74,7 @@ py_test( py_library( name = "dataset_util", srcs = ["dataset_util.py"], + deps = ["//mediapipe/model_maker/python/core/data:cache_files"], ) py_test( @@ -88,6 +88,17 @@ py_test( ], ) +py_library( + name = "detection", + srcs = ["detection.py"], +) + +py_test( + name = "detection_test", + srcs = ["detection_test.py"], + deps = [":detection"], +) + py_library( name = "hyperparameters", srcs = ["hyperparameters.py"], @@ -116,6 +127,7 @@ py_library( name = "model", srcs = ["model.py"], deps = [ + ":detection", ":model_options", ":model_spec", ], @@ -163,6 +175,7 @@ py_library( "//mediapipe/model_maker/python/core/tasks:classifier", "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_info", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/metadata/metadata_writers:object_detector", ], @@ -175,11 +188,7 @@ py_test( data = [":testdata"], tags = ["requires-net:external"], deps = [ - ":dataset", - ":hyperparameters", - ":model_spec", - ":object_detector", - ":object_detector_options", + ":object_detector_import", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/object_detector/__init__.py b/mediapipe/model_maker/python/vision/object_detector/__init__.py index ef7a92010..3e0a62bf8 100644 --- a/mediapipe/model_maker/python/vision/object_detector/__init__.py +++ b/mediapipe/model_maker/python/vision/object_detector/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -32,6 +32,7 @@ ObjectDetectorOptions = object_detector_options.ObjectDetectorOptions # Remove duplicated and non-public API del dataset del dataset_util # pylint: disable=undefined-variable +del detection # pylint: disable=undefined-variable del hyperparameters del model # pylint: disable=undefined-variable del model_options diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset.py b/mediapipe/model_maker/python/vision/object_detector/dataset.py index f260c82c5..f7751915e 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ from typing import Optional import tensorflow as tf -import yaml +from mediapipe.model_maker.python.core.data import cache_files from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.vision.object_detector import dataset_util from official.vision.dataloaders import tf_example_decoder @@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset): ValueError: If the label_name for id 0 is set to something other than the 'background' class. """ - cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir) - if not dataset_util.is_cached(cache_files): + tfrecord_cache_files = dataset_util.get_cache_files_coco( + data_dir, cache_dir + ) + if not tfrecord_cache_files.is_cached(): label_map = dataset_util.get_label_map_coco(data_dir) cache_writer = dataset_util.COCOCacheFilesWriter( label_map=label_map, max_num_images=max_num_images ) - cache_writer.write_files(cache_files, data_dir) - return cls.from_cache(cache_files.cache_prefix) + cache_writer.write_files(tfrecord_cache_files, data_dir) + return cls.from_cache(tfrecord_cache_files) @classmethod def from_pascal_voc_folder( @@ -106,7 +108,7 @@ class Dataset(classification_dataset.ClassificationDataset): ... Each .xml annotation file should have the following format: <annotation> - <filename>file0.jpg</filename> + <filename>file0.jpg</filename> <object> <name>kangaroo</name> <bndbox> @@ -114,6 +116,7 @@ <ymin>89</ymin> <xmax>386</xmax> <ymax>262</ymax> </bndbox> </object> + <object>...</object> </annotation> @@ -133,47 +136,48 @@ Raises: ValueError: if the input data directory is empty. """ - cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir) - if not dataset_util.is_cached(cache_files): + tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc( + data_dir, cache_dir + ) + if not tfrecord_cache_files.is_cached(): label_map = dataset_util.get_label_map_pascal_voc(data_dir) cache_writer = dataset_util.PascalVocCacheFilesWriter( label_map=label_map, max_num_images=max_num_images ) - cache_writer.write_files(cache_files, data_dir) + cache_writer.write_files(tfrecord_cache_files, data_dir) - return cls.from_cache(cache_files.cache_prefix) + return cls.from_cache(tfrecord_cache_files) @classmethod - def from_cache(cls, cache_prefix: str) -> 'Dataset': + def from_cache( + cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles + ) -> 'Dataset': """Loads the TFRecord data from cache. Args: - cache_prefix: The cache prefix including the cache directory and the cache - prefix filename, e.g: '/tmp/cache/train'.
+ tfrecord_cache_files: The TFRecordCacheFiles object containing the already + cached TFRecord and metadata files. Returns: ObjectDetectorDataset object. + + Raises: + ValueError if tfrecord_cache_files are not already cached. """ - # Get TFRecord Files - tfrecord_file_pattern = cache_prefix + '*.tfrecord' - matched_files = tf.io.gfile.glob(tfrecord_file_pattern) - if not matched_files: - raise ValueError('TFRecord files are empty.') + if not tfrecord_cache_files.is_cached(): + raise ValueError( + 'Cache files must be already cached to use the from_cache method.' + ) - # Load meta_data. - meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX - if not tf.io.gfile.exists(meta_data_file): - raise ValueError("Metadata file %s doesn't exist." % meta_data_file) - with tf.io.gfile.GFile(meta_data_file, 'r') as f: - meta_data = yaml.load(f, Loader=yaml.FullLoader) + metadata = tfrecord_cache_files.load_metadata() - dataset = tf.data.TFRecordDataset(matched_files) + dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files) decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False) dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE) - label_map = meta_data['label_map'] + label_map = metadata['label_map'] label_names = [label_map[k] for k in sorted(label_map.keys())] return Dataset( - dataset=dataset, size=meta_data['size'], label_names=label_names + dataset=dataset, label_names=label_names, size=metadata['size'] ) diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_test.py b/mediapipe/model_maker/python/vision/object_detector/dataset_test.py index 46cce68dc..91ae273be 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py index 440d45945..fbb821b3b 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,25 +15,20 @@ import abc import collections -import dataclasses import hashlib import json import math import os import tempfile -from typing import Any, Dict, List, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional import xml.etree.ElementTree as ET import tensorflow as tf -import yaml +from mediapipe.model_maker.python.core.data import cache_files from official.vision.data import tfrecord_lib -# Suffix of the meta data file name. -META_DATA_FILE_SUFFIX = '_meta_data.yaml' - - def _xml_get(node: ET.Element, name: str) -> ET.Element: """Gets a named child from an XML Element node. 
@@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str: return os.path.basename(os.path.abspath(data_dir)) -@dataclasses.dataclass(frozen=True) -class CacheFiles: - """Cache files for object detection.""" - - cache_prefix: str - tfrecord_files: Sequence[str] - meta_data_file: str - - def _get_cache_files( cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10 -) -> CacheFiles: +) -> cache_files.TFRecordCacheFiles: """Creates an object of CacheFiles class. Args: @@ -96,28 +82,16 @@ def _get_cache_files( An object of CacheFiles class. """ cache_dir = _get_cache_dir_or_create(cache_dir) - # The cache prefix including the cache directory and the cache prefix - # filename, e.g: '/tmp/cache/train'. - cache_prefix = os.path.join(cache_dir, cache_prefix_filename) - tf.compat.v1.logging.info( - 'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s' - % (cache_dir, cache_prefix_filename, cache_prefix) - ) - - # Cached files including the TFRecord files and the meta data file. - tfrecord_files = [ - cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards) - for i in range(num_shards) - ] - meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX - return CacheFiles( - cache_prefix=cache_prefix, - tfrecord_files=tuple(tfrecord_files), - meta_data_file=meta_data_file, + return cache_files.TFRecordCacheFiles( + cache_prefix_filename=cache_prefix_filename, + cache_dir=cache_dir, + num_shards=num_shards, ) -def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: +def get_cache_files_coco( + data_dir: str, cache_dir: str +) -> cache_files.TFRecordCacheFiles: """Creates an object of CacheFiles class using a COCO formatted dataset. Args: @@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) -def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: +def get_cache_files_pascal_voc( + data_dir: str, cache_dir: str +) -> cache_files.TFRecordCacheFiles: """Gets an object of CacheFiles using a PASCAL VOC formatted dataset. Args: @@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) -def is_cached(cache_files: CacheFiles) -> bool: - """Checks whether cache files are already cached.""" - all_cached_files = list(cache_files.tfrecord_files) + [ - cache_files.meta_data_file - ] - return all(tf.io.gfile.exists(path) for path in all_cached_files) - - class CacheFilesWriter(abc.ABC): """CacheFilesWriter class to write the cached files.""" @@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC): self.label_map = label_map self.max_num_images = max_num_images - def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None: - """Writes TFRecord and meta_data files. + def write_files( + self, + tfrecord_cache_files: cache_files.TFRecordCacheFiles, + *args, + **kwargs, + ) -> None: + """Writes TFRecord and metadata files. Args: - cache_files: CacheFiles object including a list of TFRecord files and the - meta data yaml file to save the meta_data including data size and - label_map. + tfrecord_cache_files: TFRecordCacheFiles object including a list of + TFRecord files and the meta data yaml file to save the metadata + including data size and label_map. *args: Non-keyword of parameters used in the `_get_example` method. **kwargs: Keyword parameters used in the `_get_example` method. 
""" - writers = [ - tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files - ] + writers = tfrecord_cache_files.get_writers() # Writes tf.Example into TFRecord files. size = 0 @@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC): for writer in writers: writer.close() - # Writes meta_data into meta_data_file. - meta_data = {'size': size, 'label_map': self.label_map} - with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f: - yaml.dump(meta_data, f) + # Writes metadata into metadata_file. + metadata = {'size': size, 'label_map': self.label_map} + tfrecord_cache_files.save_metadata(metadata) @abc.abstractmethod def _get_example(self, *args, **kwargs): diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py index 7a2ef95f5..250c5d45e 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ import shutil from unittest import mock as unittest_mock import tensorflow as tf -import yaml from mediapipe.model_maker.python.vision.core import test_utils from mediapipe.model_maker.python.vision.object_detector import dataset_util @@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase): def _assert_cache_files_equal(self, cf1, cf2): self.assertEqual(cf1.cache_prefix, cf2.cache_prefix) - self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files) - self.assertEqual(cf1.meta_data_file, cf2.meta_data_file) + self.assertEqual(cf1.num_shards, cf2.num_shards) def _assert_cache_files_not_equal(self, cf1, cf2): self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix) - self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files) - self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file) def _get_cache_files_and_assert_neq_fn(self, cache_files_fn): def get_cache_files_and_assert_neq(cf, data_dir, cache_dir): @@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase): self.assertEqual( cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' ) - self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') + self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml') def test_matching_get_cache_files_coco(self): cache_dir = self.create_tempdir() @@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase): self.assertEqual( cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' ) - self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') + self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml') def test_matching_get_cache_files_pascal_voc(self): cache_dir = self.create_tempdir() @@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase): cache_files = dataset_util.get_cache_files_coco( tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir ) - self.assertFalse(dataset_util.is_cached(cache_files)) + self.assertFalse(cache_files.is_cached()) with open(cache_files.tfrecord_files[0], 'w') as f: f.write('test') - self.assertFalse(dataset_util.is_cached(cache_files)) - with open(cache_files.meta_data_file, 'w') as f: + self.assertFalse(cache_files.is_cached()) + with open(cache_files.metadata_file, 'w') as f: f.write('test') - 
self.assertTrue(dataset_util.is_cached(cache_files)) + self.assertTrue(cache_files.is_cached()) def test_get_label_map_coco(self): coco_dir = tasks_test_utils.get_test_data_path('coco_data') @@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase): self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0])) self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0) - # Checks the meta_data file - self.assertTrue(os.path.isfile(cache_files.meta_data_file)) - self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0) - with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f: - meta_data_dict = yaml.load(f, Loader=yaml.FullLoader) - # Size is 3 because some examples are skipped for having poor bboxes - self.assertEqual(meta_data_dict['size'], expected_size) + # Checks the metadata file + self.assertTrue(os.path.isfile(cache_files.metadata_file)) + self.assertGreater(os.path.getsize(cache_files.metadata_file), 0) + metadata_dict = cache_files.load_metadata() + self.assertEqual(metadata_dict['size'], expected_size) def test_coco_cache_files_writer(self): tempdir = self.create_tempdir() diff --git a/mediapipe/model_maker/python/vision/object_detector/detection.py b/mediapipe/model_maker/python/vision/object_detector/detection.py new file mode 100644 index 000000000..769189b24 --- /dev/null +++ b/mediapipe/model_maker/python/vision/object_detector/detection.py @@ -0,0 +1,34 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom Detection export module for Object Detection.""" + +from typing import Any, Mapping + +from official.vision.serving import detection + + +class DetectionModule(detection.DetectionModule): + """A serving detection module for exporting the model. + + This module overrides the tensorflow_models DetectionModule by only outputting + the pre-nms detection_boxes and detection_scores. + """ + + def serve(self, images) -> Mapping[str, Any]: + result = super().serve(images) + final_outputs = { + 'detection_boxes': result['detection_boxes'], + 'detection_scores': result['detection_scores'], + } + return final_outputs diff --git a/mediapipe/model_maker/python/vision/object_detector/detection_test.py b/mediapipe/model_maker/python/vision/object_detector/detection_test.py new file mode 100644 index 000000000..34f16c21c --- /dev/null +++ b/mediapipe/model_maker/python/vision/object_detector/detection_test.py @@ -0,0 +1,73 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest import mock +import tensorflow as tf + +from mediapipe.model_maker.python.vision.object_detector import detection +from official.core import config_definitions as cfg +from official.vision import configs +from official.vision.serving import detection as detection_module + + +class ObjectDetectorTest(tf.test.TestCase): + + @mock.patch.object(detection_module.DetectionModule, 'serve', autospec=True) + def test_detection_module(self, mock_serve): + mock_serve.return_value = { + 'detection_boxes': 1, + 'detection_scores': 2, + 'detection_classes': 3, + 'num_detections': 4, + } + model_config = configs.retinanet.RetinaNet( + min_level=3, + max_level=7, + num_classes=10, + input_size=[256, 256, 3], + anchor=configs.retinanet.Anchor( + num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3 + ), + backbone=configs.backbones.Backbone( + type='mobilenet', mobilenet=configs.backbones.MobileNet() + ), + decoder=configs.decoders.Decoder( + type='fpn', + fpn=configs.decoders.FPN( + num_filters=128, use_separable_conv=True, use_keras_layer=True + ), + ), + head=configs.retinanet.RetinaNetHead( + num_filters=128, use_separable_conv=True + ), + detection_generator=configs.retinanet.DetectionGenerator(), + norm_activation=configs.common.NormActivation(activation='relu6'), + ) + task_config = configs.retinanet.RetinaNetTask(model=model_config) + params = cfg.ExperimentConfig( + task=task_config, + ) + detection_instance = detection.DetectionModule( + params=params, batch_size=1, input_image_size=[256, 256] + ) + outputs = detection_instance.serve(0) + expected_outputs = { + 'detection_boxes': 1, + 'detection_scores': 2, + } + self.assertAllEqual(outputs, expected_outputs) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py index 241104cf8..35fb630ae 100644 --- a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ """Hyperparameters for training object detection models.""" import dataclasses -from typing import List +from typing import Optional from mediapipe.model_maker.python.core import hyperparameters as hp @@ -27,56 +27,23 @@ class HParams(hp.BaseHParams): learning_rate: Learning rate to use for gradient descent training. batch_size: Batch size for training. epochs: Number of training iterations over the dataset. - do_fine_tuning: If true, the base module is trained together with the - classification layer on top. - learning_rate_epoch_boundaries: List of epoch boundaries where - learning_rate_epoch_boundaries[i] is the epoch where the learning rate - will decay to learning_rate * learning_rate_decay_multipliers[i]. - learning_rate_decay_multipliers: List of learning rate multipliers which - calculates the learning rate at the ith boundary as learning_rate * - learning_rate_decay_multipliers[i]. + cosine_decay_epochs: The number of epochs for cosine decay learning rate. + See + https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay + for more info. + cosine_decay_alpha: The alpha value for cosine decay learning rate. 
See + https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay + for more info. """ # Parameters from BaseHParams class. - learning_rate: float = 0.003 - batch_size: int = 32 - epochs: int = 10 + learning_rate: float = 0.3 + batch_size: int = 8 + epochs: int = 30 - # Parameters for learning rate decay - learning_rate_epoch_boundaries: List[int] = dataclasses.field( - default_factory=lambda: [] - ) - learning_rate_decay_multipliers: List[float] = dataclasses.field( - default_factory=lambda: [] - ) - - def __post_init__(self): - # Validate stepwise learning rate parameters - lr_boundary_len = len(self.learning_rate_epoch_boundaries) - lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers) - if lr_boundary_len != lr_decay_multipliers_len: - raise ValueError( - "Length of learning_rate_epoch_boundaries and ", - "learning_rate_decay_multipliers do not match: ", - f"{lr_boundary_len}!={lr_decay_multipliers_len}", - ) - # Validate learning_rate_epoch_boundaries - if ( - sorted(self.learning_rate_epoch_boundaries) - != self.learning_rate_epoch_boundaries - ): - raise ValueError( - "learning_rate_epoch_boundaries is not in ascending order: ", - self.learning_rate_epoch_boundaries, - ) - if ( - self.learning_rate_epoch_boundaries - and self.learning_rate_epoch_boundaries[-1] > self.epochs - ): - raise ValueError( - "Values in learning_rate_epoch_boundaries cannot be greater ", - "than epochs", - ) + # Parameters for cosine learning rate decay + cosine_decay_epochs: Optional[int] = None + cosine_decay_alpha: float = 1.0 @dataclasses.dataclass @@ -98,8 +65,8 @@ class QATHParams: for more information. """ - learning_rate: float = 0.03 - batch_size: int = 32 - epochs: int = 10 - decay_steps: int = 231 + learning_rate: float = 0.3 + batch_size: int = 8 + epochs: int = 15 + decay_steps: int = 8 decay_rate: float = 0.96 diff --git a/mediapipe/model_maker/python/vision/object_detector/model.py b/mediapipe/model_maker/python/vision/object_detector/model.py index 26e0d036c..ea78ca8c6 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model.py +++ b/mediapipe/model_maker/python/vision/object_detector/model.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
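The hyperparameters rewrite above swaps the piecewise step-decay fields for a cosine schedule and changes the defaults (learning_rate 0.3, batch_size 8, epochs 30, with the QAT defaults shifting similarly). A short sketch of the new knobs; the specific values below are illustrative, not recommendations:

from mediapipe.model_maker.python.vision.object_detector import hyperparameters

hparams = hyperparameters.HParams(
    learning_rate=0.3,
    batch_size=8,
    epochs=30,
    # Cosine-decay the learning rate over 25 epochs; alpha sets the floor as
    # a fraction of the initial rate (the new default alpha is 1.0).
    cosine_decay_epochs=25,
    cosine_decay_alpha=0.1,
)

qat_hparams = hyperparameters.QATHParams(
    learning_rate=0.3, batch_size=8, epochs=15, decay_steps=8, decay_rate=0.96
)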
@@ -18,6 +18,7 @@ from typing import Mapping, Optional, Sequence, Union import tensorflow as tf +from mediapipe.model_maker.python.vision.object_detector import detection from mediapipe.model_maker.python.vision.object_detector import model_options as model_opt from mediapipe.model_maker.python.vision.object_detector import model_spec as ms from official.core import config_definitions as cfg @@ -29,7 +30,6 @@ from official.vision.losses import loss_utils from official.vision.modeling import factory from official.vision.modeling import retinanet_model from official.vision.modeling.layers import detection_generator -from official.vision.serving import detection class ObjectDetectorModel(tf.keras.Model): @@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model): self._num_classes = num_classes self._model = self._build_model() checkpoint_folder = self._model_spec.downloaded_files.get_path() - checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200') + checkpoint_file = os.path.join( + checkpoint_folder, self._model_spec.checkpoint_name + ) self.load_checkpoint(checkpoint_file) self._model.summary() self.loss_trackers = [ @@ -72,15 +74,18 @@ class ObjectDetectorModel(tf.keras.Model): generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(), ) -> configs.retinanet.RetinaNet: model_config = configs.retinanet.RetinaNet( - min_level=3, - max_level=7, + min_level=self._model_spec.min_level, + max_level=self._model_spec.max_level, num_classes=self._num_classes, input_size=self._model_spec.input_image_shape, anchor=configs.retinanet.Anchor( num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3 ), backbone=configs.backbones.Backbone( - type='mobilenet', mobilenet=configs.backbones.MobileNet() + type='mobilenet', + mobilenet=configs.backbones.MobileNet( + model_id=self._model_spec.model_id + ), ), decoder=configs.decoders.Decoder( type='fpn', @@ -96,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model): ) return model_config - def _build_model(self) -> tf.keras.Model: + def _build_model(self, omit_l2=False) -> tf.keras.Model: """Builds a RetinaNet object detector model.""" input_specs = tf.keras.layers.InputSpec( shape=[None] + self._model_spec.input_image_shape ) - l2_regularizer = tf.keras.regularizers.l2( - self._model_options.l2_weight_decay / 2.0 - ) + if omit_l2: + l2_regularizer = None + else: + l2_regularizer = tf.keras.regularizers.l2( + self._model_options.l2_weight_decay / 2.0 + ) model_config = self._get_model_config() return factory.build_retinanet(input_specs, model_config, l2_regularizer) @@ -162,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model): def convert_to_qat(self) -> None: """Converts the model to a QAT RetinaNet model.""" - model = self._build_model() + model = self._build_model(omit_l2=True) dummy_input = tf.zeros([1] + self._model_spec.input_image_shape) model(dummy_input, training=True) model.set_weights(self._model.get_weights()) @@ -194,6 +202,7 @@ class ObjectDetectorModel(tf.keras.Model): max_detections=10, max_classes_per_detection=1, normalize_anchor_coordinates=True, + omit_nms=True, ), ) tflite_post_processing_config = ( diff --git a/mediapipe/model_maker/python/vision/object_detector/model_options.py b/mediapipe/model_maker/python/vision/object_detector/model_options.py index 64042aa0f..a332804b6 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model_options.py +++ b/mediapipe/model_maker/python/vision/object_detector/model_options.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. 
All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/object_detector/model_spec.py b/mediapipe/model_maker/python/vision/object_detector/model_spec.py index 7d284b432..ad043e872 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model_spec.py +++ b/mediapipe/model_maker/python/vision/object_detector/model_spec.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,12 +20,30 @@ from typing import List from mediapipe.model_maker.python.core.utils import file_util -MOBILENET_V2_FILES = file_util.DownloadedFiles( - 'object_detector/mobilenetv2', +MOBILENET_V2_I256_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetv2_i256', 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz', is_folder=True, ) +MOBILENET_V2_I320_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetv2_i320', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz', + is_folder=True, +) + +MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetmultiavg', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz', + is_folder=True, +) + +MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetmultiavg_i384', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz', + is_folder=True, +) + @dataclasses.dataclass class ModelSpec(object): @@ -38,21 +56,70 @@ class ModelSpec(object): stddev_rgb = (127.5,) downloaded_files: file_util.DownloadedFiles + checkpoint_name: str input_image_shape: List[int] + model_id: str + + # Model Config values + min_level: int + max_level: int -mobilenet_v2_spec = functools.partial( +mobilenet_v2_i256_spec = functools.partial( ModelSpec, - downloaded_files=MOBILENET_V2_FILES, + downloaded_files=MOBILENET_V2_I256_FILES, + checkpoint_name='ckpt-277200', input_image_shape=[256, 256, 3], + model_id='MobileNetV2', + min_level=3, + max_level=7, +) + +mobilenet_v2_i320_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_V2_I320_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[320, 320, 3], + model_id='MobileNetV2', + min_level=3, + max_level=6, +) + +mobilenet_multi_avg_i256_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_MULTI_AVG_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[256, 256, 3], + model_id='MobileNetMultiAVG', + min_level=3, + max_level=7, +) + +mobilenet_multi_avg_i384_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_MULTI_AVG_I384_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[384, 384, 3], + model_id='MobileNetMultiAVG', + min_level=3, + max_level=7, ) @enum.unique class SupportedModels(enum.Enum): - """Predefined object detector model specs supported by Model Maker.""" + """Predefined object detector model specs supported by Model Maker. 
- MOBILENET_V2 = mobilenet_v2_spec + Supported models include the following: + - MOBILENET_V2: MobileNetV2 256x256 input + - MOBILENET_V2_I320: MobileNetV2 320x320 input + - MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input + - MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input + """ + MOBILENET_V2 = mobilenet_v2_i256_spec + MOBILENET_V2_I320 = mobilenet_v2_i320_spec + MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec + MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec @classmethod def get(cls, spec: 'SupportedModels') -> 'ModelSpec': diff --git a/mediapipe/model_maker/python/vision/object_detector/model_test.py b/mediapipe/model_maker/python/vision/object_detector/model_test.py index 66401f8d6..3ccee4d04 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/model_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector.py b/mediapipe/model_maker/python/vision/object_detector/object_detector.py index 2d1d92ef3..6c7b9811c 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ from mediapipe.model_maker.python.vision.object_detector import model_options as from mediapipe.model_maker.python.vision.object_detector import model_spec as ms from mediapipe.model_maker.python.vision.object_detector import object_detector_options from mediapipe.model_maker.python.vision.object_detector import preprocessor +from mediapipe.tasks.python.metadata.metadata_writers import metadata_info from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import object_detector as object_detector_writer from official.vision.evaluation import coco_evaluator @@ -264,6 +265,27 @@ class ObjectDetector(classifier.Classifier): coco_metrics = coco_eval.result() return losses, coco_metrics + def _create_fixed_anchor( + self, anchor_box: List[float] + ) -> object_detector_writer.FixedAnchor: + """Helper function to create FixedAnchor objects from an anchor box array. + + Args: + anchor_box: List of anchor box coordinates in the format of [y_min, x_min, + y_max, x_max]. + + Returns: + A FixedAnchor object representing the anchor_box.
+ """ + image_shape = self._model_spec.input_image_shape[:2] + y_center_norm = (anchor_box[0] + anchor_box[2]) / (2 * image_shape[0]) + x_center_norm = (anchor_box[1] + anchor_box[3]) / (2 * image_shape[1]) + height_norm = (anchor_box[2] - anchor_box[0]) / image_shape[0] + width_norm = (anchor_box[3] - anchor_box[1]) / image_shape[1] + return object_detector_writer.FixedAnchor( + x_center_norm, y_center_norm, width_norm, height_norm + ) + def export_model( self, model_name: str = 'model.tflite', @@ -328,11 +350,40 @@ class ObjectDetector(classifier.Classifier): converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,) tflite_model = converter.convert() - writer = object_detector_writer.MetadataWriter.create( + # Build anchors + raw_anchor_boxes = self._preprocessor.anchor_boxes + anchors = [] + for _, anchor_boxes in raw_anchor_boxes.items(): + anchor_boxes_reshaped = anchor_boxes.numpy().reshape((-1, 4)) + for ab in anchor_boxes_reshaped: + anchors.append(self._create_fixed_anchor(ab)) + + ssd_anchors_options = object_detector_writer.SsdAnchorsOptions( + object_detector_writer.FixedAnchorsSchema(anchors) + ) + + tensor_decoding_options = object_detector_writer.TensorsDecodingOptions( + num_classes=self._num_classes, + num_boxes=len(anchors), + num_coords=4, + keypoint_coord_offset=0, + num_keypoints=0, + num_values_per_keypoint=2, + x_scale=1, + y_scale=1, + w_scale=1, + h_scale=1, + apply_exponential_on_box_size=True, + sigmoid_score=False, + ) + writer = object_detector_writer.MetadataWriter.create_for_models_without_nms( tflite_model, self._model_spec.mean_rgb, self._model_spec.stddev_rgb, labels=metadata_writer.Labels().add(list(self._label_names)), + ssd_anchors_options=ssd_anchors_options, + tensors_decoding_options=tensor_decoding_options, + output_tensors_order=metadata_info.RawDetectionOutputTensorsOrder.LOCATION_SCORE, ) tflite_model_with_metadata, metadata_json = writer.populate() model_util.save_tflite(tflite_model_with_metadata, tflite_file) @@ -344,7 +395,7 @@ class ObjectDetector(classifier.Classifier): ) -> tf.keras.optimizers.Optimizer: """Creates an optimizer with learning rate schedule for regular training. - Uses Keras PiecewiseConstantDecay schedule by default. + Uses Keras CosineDecay schedule by default. Args: steps_per_epoch: Steps per epoch to calculate the step boundaries from the @@ -353,20 +404,24 @@ class ObjectDetector(classifier.Classifier): Returns: A tf.keras.optimizer.Optimizer for model training. 
""" + total_steps = steps_per_epoch * self._hparams.epochs + warmup_steps = int(total_steps * 0.1) init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 - if self._hparams.learning_rate_epoch_boundaries: - lr_values = [init_lr] + [ - init_lr * m for m in self._hparams.learning_rate_decay_multipliers - ] - lr_step_boundaries = [ - steps_per_epoch * epoch_boundary - for epoch_boundary in self._hparams.learning_rate_epoch_boundaries - ] - learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay( - lr_step_boundaries, lr_values - ) - else: - learning_rate = init_lr + decay_epochs = ( + self._hparams.cosine_decay_epochs + if self._hparams.cosine_decay_epochs + else self._hparams.epochs + ) + learning_rate = tf.keras.optimizers.schedules.CosineDecay( + init_lr, + steps_per_epoch * decay_epochs, + self._hparams.cosine_decay_alpha, + ) + learning_rate = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate, + warmup_steps=warmup_steps, + ) return tf.keras.optimizers.experimental.SGD( learning_rate=learning_rate, momentum=0.9 ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_demo.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_demo.py index 3bbac5d8b..04820796f 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_demo.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_demo.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_options.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_options.py index e1647cd50..6333eb81a 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_options.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_options.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py index df6b58a07..268a926fd 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. 
@@ -19,11 +19,7 @@ from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf -from mediapipe.model_maker.python.vision.object_detector import dataset -from mediapipe.model_maker.python.vision.object_detector import hyperparameters -from mediapipe.model_maker.python.vision.object_detector import model_spec as ms -from mediapipe.model_maker.python.vision.object_detector import object_detector -from mediapipe.model_maker.python.vision.object_detector import object_detector_options +from mediapipe.model_maker.python.vision import object_detector from mediapipe.tasks.python.test import test_utils as task_test_utils @@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): super().setUp() dataset_folder = task_test_utils.get_test_data_path('coco_data') cache_dir = self.create_tempdir() - self.data = dataset.Dataset.from_coco_folder( + self.data = object_detector.Dataset.from_coco_folder( dataset_folder, cache_dir=cache_dir ) # Mock tempfile.gettempdir() to be unique for each test to avoid race @@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.addCleanup(mock_gettempdir.stop) def test_object_detector(self): - hparams = hyperparameters.HParams( + hparams = object_detector.HParams( epochs=1, batch_size=2, learning_rate=0.9, shuffle=False, export_dir=self.create_tempdir(), ) - options = object_detector_options.ObjectDetectorOptions( - supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams + options = object_detector.ObjectDetectorOptions( + supported_model=object_detector.SupportedModels.MOBILENET_V2, + hparams=hparams, ) # Test `create`` model = object_detector.ObjectDetector.create( @@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(os.path.getsize(output_metadata_file), 0) # Test `quantization_aware_training` - qat_hparams = hyperparameters.QATHParams( + qat_hparams = object_detector.QATHParams( learning_rate=0.9, batch_size=2, epochs=1, diff --git a/mediapipe/model_maker/python/vision/object_detector/preprocessor.py b/mediapipe/model_maker/python/vision/object_detector/preprocessor.py index b0d2afb74..1388cc7df 100644 --- a/mediapipe/model_maker/python/vision/object_detector/preprocessor.py +++ b/mediapipe/model_maker/python/vision/object_detector/preprocessor.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
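The test cleanup above drops the per-module imports in favor of the names re-exported through the package `__init__`. A sketch of the consolidated surface, mirroring the test's flow (the paths and hyperparameter values are illustrative, and the `create` keyword names follow the test's call):

from mediapipe.model_maker.python.vision import object_detector

data = object_detector.Dataset.from_coco_folder(
    '/path/to/coco_data', cache_dir='/tmp/od_cache'  # illustrative paths
)
hparams = object_detector.HParams(epochs=1, batch_size=2, learning_rate=0.9)
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
    hparams=hparams,
)
model = object_detector.ObjectDetector.create(
    train_data=data, validation_data=data, options=options
)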
@@ -32,8 +32,8 @@ class Preprocessor(object): self._mean_norm = model_spec.mean_norm self._stddev_norm = model_spec.stddev_norm self._output_size = model_spec.input_image_shape[:2] - self._min_level = 3 - self._max_level = 7 + self._min_level = model_spec.min_level + self._max_level = model_spec.max_level self._num_scales = 3 self._aspect_ratios = [0.5, 1, 2] self._anchor_size = 3 @@ -44,6 +44,26 @@ class Preprocessor(object): self._aug_scale_max = 2.0 self._max_num_instances = 100 + self._padded_size = preprocess_ops.compute_padded_size( + self._output_size, 2**self._max_level + ) + + input_anchor = anchor.build_anchor_generator( + min_level=self._min_level, + max_level=self._max_level, + num_scales=self._num_scales, + aspect_ratios=self._aspect_ratios, + anchor_size=self._anchor_size, + ) + self._anchor_boxes = input_anchor(image_size=self._output_size) + self._anchor_labeler = anchor.AnchorLabeler( + self._match_threshold, self._unmatched_threshold + ) + + @property + def anchor_boxes(self): + return self._anchor_boxes + def __call__( self, data: Mapping[str, Any], is_training: bool = True ) -> Tuple[tf.Tensor, Mapping[str, Any]]: @@ -90,13 +110,10 @@ class Preprocessor(object): image, image_info = preprocess_ops.resize_and_crop_image( image, self._output_size, - padded_size=preprocess_ops.compute_padded_size( - self._output_size, 2**self._max_level - ), + padded_size=self._padded_size, aug_scale_min=(self._aug_scale_min if is_training else 1.0), aug_scale_max=(self._aug_scale_max if is_training else 1.0), ) - image_height, image_width, _ = image.get_shape().as_list() # Resize and crop boxes. image_scale = image_info[2, :] @@ -110,20 +127,9 @@ class Preprocessor(object): classes = tf.gather(classes, indices) # Assign anchors. - input_anchor = anchor.build_anchor_generator( - min_level=self._min_level, - max_level=self._max_level, - num_scales=self._num_scales, - aspect_ratios=self._aspect_ratios, - anchor_size=self._anchor_size, - ) - anchor_boxes = input_anchor(image_size=(image_height, image_width)) - anchor_labeler = anchor.AnchorLabeler( - self._match_threshold, self._unmatched_threshold - ) (cls_targets, box_targets, _, cls_weights, box_weights) = ( - anchor_labeler.label_anchors( - anchor_boxes, boxes, tf.expand_dims(classes, axis=1) + self._anchor_labeler.label_anchors( + self.anchor_boxes, boxes, tf.expand_dims(classes, axis=1) ) ) @@ -134,7 +140,7 @@ class Preprocessor(object): labels = { 'cls_targets': cls_targets, 'box_targets': box_targets, - 'anchor_boxes': anchor_boxes, + 'anchor_boxes': self.anchor_boxes, 'cls_weights': cls_weights, 'box_weights': box_weights, 'image_info': image_info, diff --git a/mediapipe/model_maker/python/vision/object_detector/preprocessor_test.py b/mediapipe/model_maker/python/vision/object_detector/preprocessor_test.py index d8ea63cd8..30db6bdff 100644 --- a/mediapipe/model_maker/python/vision/object_detector/preprocessor_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/preprocessor_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
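The preprocessor hunks above hoist anchor generation and the `AnchorLabeler` out of `__call__` into `__init__`, so anchors are built once per model spec and surfaced through the new `anchor_boxes` property that `export_model` consumes. A sketch of the hoisted construction using the `official.vision` helpers named in the diff (the threshold values are assumptions, since the diff only shows the stored attributes):

from official.vision.ops import anchor

min_level, max_level = 3, 7   # now taken from the model spec
output_size = [256, 256]      # input_image_shape[:2]

input_anchor = anchor.build_anchor_generator(
    min_level=min_level,
    max_level=max_level,
    num_scales=3,
    aspect_ratios=[0.5, 1, 2],
    anchor_size=3,
)
# Mapping of pyramid level -> anchor box tensor, reused on every __call__.
anchor_boxes = input_anchor(image_size=output_size)
anchor_labeler = anchor.AnchorLabeler(0.5, 0.5)  # match/unmatched thresholds assumed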
diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 82474dba8..a1c975c1e 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,8 +1,9 @@ absl-py -mediapipe==0.9.2.1 +mediapipe>=0.10.0 numpy opencv-python tensorflow>=2.10 +tensorflow-addons tensorflow-datasets tensorflow-hub -tf-models-official>=2.11.5 +tf-models-official>=2.13.1 diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index ccf633909..d80e6ebe4 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -1,4 +1,4 @@ -"""Copyright 2020-2022 The MediaPipe Authors. All Rights Reserved. +"""Copyright 2020-2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt index 2ee8316d5..1d322665e 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt @@ -56,7 +56,7 @@ output_stream: "HANDEDNESS:multi_handedness" output_stream: "PALM_DETECTIONS:palm_detections" # Regions of interest calculated based on landmarks. # (std::vector) -output_stream: "HAND_ROIS_FROM_LANDMARKS:hand_rects" +output_stream: "HAND_ROIS_FROM_LANDMARKS:hand_rects_from_landmarks" # Regions of interest calculated based on palm detections. # (std::vector) output_stream: "HAND_ROIS_FROM_PALM_DETECTIONS:hand_rects_from_palm_detections" diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index 14cea526f..05b254753 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -135,6 +135,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/util/tracking:box_tracker_cc_proto", + "@com_google_absl//absl/log:absl_check", ], ) @@ -146,10 +147,11 @@ cc_library( ":annotation_cc_proto", ":box_util", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework/port:logging", "//mediapipe/util/tracking:box_tracker_cc_proto", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", ], ) @@ -163,6 +165,7 @@ cc_library( ], deps = [ "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", @@ -182,10 +185,11 @@ cc_library( ":belief_decoder_config_cc_proto", ":box", ":epnp", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@eigen_archive//:eigen3", ], @@ -203,6 +207,7 @@ cc_library( "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", + "@com_google_absl//absl/log:absl_check", "@org_tensorflow//tensorflow/lite:framework", ], ) @@ -223,6 +228,7 @@ cc_library( ":annotation_cc_proto", ":object_cc_proto", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", 
"@eigen_archive//:eigen3", ], ) @@ -277,6 +283,8 @@ cc_library( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -298,8 +306,10 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -321,6 +331,7 @@ cc_library( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -368,11 +379,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location_data_cc_proto", - "//mediapipe/framework/port:logging", "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:re2", "//mediapipe/framework/port:status", "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings", ], alwayslink = 1, @@ -416,5 +427,6 @@ cc_test( "//mediapipe/framework/port:logging", "//mediapipe/util/tracking:box_tracker_cc_proto", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", ], ) diff --git a/mediapipe/modules/objectron/calculators/box.cc b/mediapipe/modules/objectron/calculators/box.cc index bd2ce57f9..9b3e43484 100644 --- a/mediapipe/modules/objectron/calculators/box.cc +++ b/mediapipe/modules/objectron/calculators/box.cc @@ -15,6 +15,7 @@ #include "mediapipe/modules/objectron/calculators/box.h" #include "Eigen/Core" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -107,12 +108,12 @@ void Box::Adjust(const std::vector& variables) { } float* Box::GetVertex(size_t vertex_id) { - CHECK_LT(vertex_id, kNumKeypoints); + ABSL_CHECK_LT(vertex_id, kNumKeypoints); return bounding_box_[vertex_id].data(); } const float* Box::GetVertex(size_t vertex_id) const { - CHECK_LT(vertex_id, kNumKeypoints); + ABSL_CHECK_LT(vertex_id, kNumKeypoints); return bounding_box_[vertex_id].data(); } @@ -135,7 +136,7 @@ bool Box::InsideTest(const Eigen::Vector3f& point, int check_axis) const { } void Box::Deserialize(const Object& obj) { - CHECK_EQ(obj.keypoints_size(), kNumKeypoints); + ABSL_CHECK_EQ(obj.keypoints_size(), kNumKeypoints); Model::Deserialize(obj); } @@ -222,7 +223,7 @@ std::pair Box::GetGroundPlane() const { template void Box::Fit(const std::vector& vertices) { - CHECK_EQ(vertices.size(), kNumKeypoints); + ABSL_CHECK_EQ(vertices.size(), kNumKeypoints); scale_.setZero(); // The scale would remain invariant under rotation and translation. // We can safely estimate the scale from the oriented box. 
diff --git a/mediapipe/modules/objectron/calculators/box_util.cc b/mediapipe/modules/objectron/calculators/box_util.cc index 0663b5bdb..c19fa5be2 100644 --- a/mediapipe/modules/objectron/calculators/box_util.cc +++ b/mediapipe/modules/objectron/calculators/box_util.cc @@ -16,6 +16,7 @@ #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" @@ -24,7 +25,7 @@ namespace mediapipe { void ComputeBoundingRect(const std::vector& points, mediapipe::TimedBoxProto* box) { - CHECK(box != nullptr); + ABSL_CHECK(box != nullptr); float top = 1.0f; float bottom = 0.0f; float left = 1.0f; diff --git a/mediapipe/modules/objectron/calculators/decoder.cc b/mediapipe/modules/objectron/calculators/decoder.cc index 0af34585b..b823490d7 100644 --- a/mediapipe/modules/objectron/calculators/decoder.cc +++ b/mediapipe/modules/objectron/calculators/decoder.cc @@ -19,9 +19,10 @@ #include "Eigen/Core" #include "Eigen/Dense" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "mediapipe/framework/port/canonical_errors.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/modules/objectron/calculators/annotation_data.pb.h" @@ -46,10 +47,10 @@ inline void SetPoint3d(const Eigen::Vector3f& point_vec, Point3D* point_3d) { FrameAnnotation Decoder::DecodeBoundingBoxKeypoints( const cv::Mat& heatmap, const cv::Mat& offsetmap) const { - CHECK_EQ(1, heatmap.channels()); - CHECK_EQ(kNumOffsetmaps, offsetmap.channels()); - CHECK_EQ(heatmap.cols, offsetmap.cols); - CHECK_EQ(heatmap.rows, offsetmap.rows); + ABSL_CHECK_EQ(1, heatmap.channels()); + ABSL_CHECK_EQ(kNumOffsetmaps, offsetmap.channels()); + ABSL_CHECK_EQ(heatmap.cols, offsetmap.cols); + ABSL_CHECK_EQ(heatmap.rows, offsetmap.rows); const float offset_scale = std::min(offsetmap.cols, offsetmap.rows); const std::vector center_points = ExtractCenterKeypoints(heatmap); @@ -201,10 +202,10 @@ std::vector Decoder::ExtractCenterKeypoints( absl::Status Decoder::Lift2DTo3D( const Eigen::Matrix& projection_matrix, bool portrait, FrameAnnotation* estimated_box) const { - CHECK(estimated_box != nullptr); + ABSL_CHECK(estimated_box != nullptr); for (auto& annotation : *estimated_box->mutable_annotations()) { - CHECK_EQ(kNumKeypoints, annotation.keypoints_size()); + ABSL_CHECK_EQ(kNumKeypoints, annotation.keypoints_size()); // Fill input 2D Points; std::vector input_points_2d; @@ -220,7 +221,7 @@ absl::Status Decoder::Lift2DTo3D( auto status = SolveEpnp(projection_matrix, portrait, input_points_2d, &output_points_3d); if (!status.ok()) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status; return status; } diff --git a/mediapipe/modules/objectron/calculators/epnp.cc b/mediapipe/modules/objectron/calculators/epnp.cc index 8bd7151fa..03b78c728 100644 --- a/mediapipe/modules/objectron/calculators/epnp.cc +++ b/mediapipe/modules/objectron/calculators/epnp.cc @@ -14,6 +14,8 @@ #include "mediapipe/modules/objectron/calculators/epnp.h" +#include "absl/log/absl_check.h" + namespace mediapipe { namespace { @@ -126,7 +128,7 @@ absl::Status SolveEpnp(const float focal_x, const float focal_y, if (eigen_solver.info() != Eigen::Success) { return absl::AbortedError("Eigen decomposition failed."); } - CHECK_EQ(12, eigen_solver.eigenvalues().size()); + ABSL_CHECK_EQ(12, 
eigen_solver.eigenvalues().size()); // Eigenvalues are sorted in increasing order for SelfAdjointEigenSolver // only! If you use other Eigen Solvers, it's not guaranteed to be in diff --git a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc index 29f4c79d2..3ac91c7c8 100644 --- a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc +++ b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc @@ -17,13 +17,13 @@ #include #include "absl/container/node_hash_set.h" +#include "absl/log/absl_log.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/location_data.pb.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/re2.h" #include "mediapipe/framework/port/status.h" @@ -264,11 +264,11 @@ bool FilterDetectionCalculator::IsValidLabel(const std::string& label) { bool FilterDetectionCalculator::IsValidScore(float score) { if (options_.has_min_score() && score < options_.min_score()) { - LOG(ERROR) << "Filter out detection with low score " << score; + ABSL_LOG(ERROR) << "Filter out detection with low score " << score; return false; } if (options_.has_max_score() && score > options_.max_score()) { - LOG(ERROR) << "Filter out detection with high score " << score; + ABSL_LOG(ERROR) << "Filter out detection with high score " << score; return false; } return true; diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc index 74678804f..c2bc413c5 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc @@ -91,8 +91,8 @@ absl::Status FrameAnnotationToTimedBoxListCalculator::Process( TimedBoxProto* added_box = output_objects->add_box(); ComputeBoundingRect(key_points, added_box); added_box->set_id(annotation.object_id()); - const int64 time_msec = - static_cast(std::round(frame_annotation.timestamp() / 1000)); + const int64_t time_msec = + static_cast(std::round(frame_annotation.timestamp() / 1000)); added_box->set_time_msec(time_msec); } diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc b/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc index eebf88579..d060af355 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc @@ -15,7 +15,8 @@ #include "mediapipe/modules/objectron/calculators/frame_annotation_tracker.h" #include "absl/container/flat_hash_set.h" -#include "mediapipe/framework/port/logging.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "mediapipe/modules/objectron/calculators/annotation_data.pb.h" #include "mediapipe/modules/objectron/calculators/box_util.h" #include "mediapipe/util/tracking/box_tracker.pb.h" @@ -24,8 +25,8 @@ namespace mediapipe { void FrameAnnotationTracker::AddDetectionResult( const FrameAnnotation& frame_annotation) { - const int64 time_us = - static_cast(std::round(frame_annotation.timestamp())); + const int64_t time_us = + 
static_cast(std::round(frame_annotation.timestamp())); for (const auto& object_annotation : frame_annotation.annotations()) { detected_objects_[time_us + object_annotation.object_id()] = object_annotation; @@ -35,9 +36,9 @@ void FrameAnnotationTracker::AddDetectionResult( FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult( const TimedBoxProtoList& tracked_boxes, absl::flat_hash_set* cancel_object_ids) { - CHECK(cancel_object_ids != nullptr); + ABSL_CHECK(cancel_object_ids != nullptr); FrameAnnotation frame_annotation; - std::vector keys_to_be_deleted; + std::vector keys_to_be_deleted; for (const auto& detected_obj : detected_objects_) { const int object_id = detected_obj.second.object_id(); if (cancel_object_ids->contains(object_id)) { @@ -53,8 +54,8 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult( } } if (!ref_box.has_id() || ref_box.id() < 0) { - LOG(ERROR) << "Can't find matching tracked box for object id: " - << object_id << ". Likely lost tracking of it."; + ABSL_LOG(ERROR) << "Can't find matching tracked box for object id: " + << object_id << ". Likely lost tracking of it."; keys_to_be_deleted.push_back(detected_obj.first); continue; } diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_tracker_test.cc b/mediapipe/modules/objectron/calculators/frame_annotation_tracker_test.cc index d155f8e73..df6ffd40b 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_tracker_test.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_tracker_test.cc @@ -15,6 +15,7 @@ #include "mediapipe/modules/objectron/calculators/frame_annotation_tracker.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" @@ -53,7 +54,7 @@ ObjectAnnotation ConstructFixedObject( ObjectAnnotation obj; for (const auto& point : points) { auto* keypoint = obj.add_keypoints(); - CHECK_EQ(2, point.size()); + ABSL_CHECK_EQ(2, point.size()); keypoint->mutable_point_2d()->set_x(point[0]); keypoint->mutable_point_2d()->set_y(point[1]); } diff --git a/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc b/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc index 5e5df78b9..652c51030 100644 --- a/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc +++ b/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc @@ -17,6 +17,7 @@ #include #include "Eigen/Dense" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -137,7 +138,7 @@ absl::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU( auto status = decoder_->Lift2DTo3D(projection_matrix_, /*portrait*/ false, output_objects); if (!status.ok()) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status; return status; } AssignObjectIdAndTimestamp(cc->InputTimestamp().Microseconds(), diff --git a/mediapipe/modules/objectron/calculators/model.cc b/mediapipe/modules/objectron/calculators/model.cc index 40aca39d9..d6fe9ed6c 100644 --- a/mediapipe/modules/objectron/calculators/model.cc +++ b/mediapipe/modules/objectron/calculators/model.cc @@ -14,6 +14,7 @@ #include "mediapipe/modules/objectron/calculators/model.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { @@ -66,9 +67,9 @@ const Eigen::Ref 
Model::GetRotation() const { const std::string& Model::GetCategory() const { return category_; } void Model::Deserialize(const Object& obj) { - CHECK_EQ(obj.rotation_size(), 9); - CHECK_EQ(obj.translation_size(), 3); - CHECK_EQ(obj.scale_size(), 3); + ABSL_CHECK_EQ(obj.rotation_size(), 9); + ABSL_CHECK_EQ(obj.translation_size(), 3); + ABSL_CHECK_EQ(obj.scale_size(), 3); category_ = obj.category(); using RotationMatrix = Eigen::Matrix; diff --git a/mediapipe/modules/objectron/calculators/tensor_util.cc b/mediapipe/modules/objectron/calculators/tensor_util.cc index 0004edd80..c6fa74b2c 100644 --- a/mediapipe/modules/objectron/calculators/tensor_util.cc +++ b/mediapipe/modules/objectron/calculators/tensor_util.cc @@ -14,14 +14,16 @@ #include "mediapipe/modules/objectron/calculators/tensor_util.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { cv::Mat ConvertTfliteTensorToCvMat(const TfLiteTensor& tensor) { // Check tensor is BxCxWxH (size = 4) and the batch size is one(data[0] = 1) - CHECK(tensor.dims->size == 4 && tensor.dims->data[0] == 1); - CHECK_EQ(kTfLiteFloat32, tensor.type) << "tflite_tensor type is not float"; + ABSL_CHECK(tensor.dims->size == 4 && tensor.dims->data[0] == 1); + ABSL_CHECK_EQ(kTfLiteFloat32, tensor.type) + << "tflite_tensor type is not float"; const size_t num_output_channels = tensor.dims->data[3]; const int dims = 2; @@ -32,9 +34,9 @@ cv::Mat ConvertTfliteTensorToCvMat(const TfLiteTensor& tensor) { cv::Mat ConvertTensorToCvMat(const mediapipe::Tensor& tensor) { // Check tensor is BxCxWxH (size = 4) and the batch size is one(data[0] = 1) - CHECK(tensor.shape().dims.size() == 4 && tensor.shape().dims[0] == 1); - CHECK_EQ(mediapipe::Tensor::ElementType::kFloat32 == tensor.element_type(), - true) + ABSL_CHECK(tensor.shape().dims.size() == 4 && tensor.shape().dims[0] == 1); + ABSL_CHECK_EQ( + mediapipe::Tensor::ElementType::kFloat32 == tensor.element_type(), true) << "tensor type is not float"; const size_t num_output_channels = tensor.shape().dims[3]; diff --git a/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc b/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc index 6989c34ce..c5ccf1d12 100644 --- a/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc +++ b/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc @@ -17,6 +17,8 @@ #include #include "Eigen/Dense" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -148,7 +150,7 @@ absl::Status TensorsToObjectsCalculator::ProcessCPU( auto status = decoder_->Lift2DTo3D(projection_matrix_, /*portrait*/ true, output_objects); if (!status.ok()) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status; return status; } Project3DTo2D(/*portrait*/ true, output_objects); @@ -170,7 +172,7 @@ absl::Status TensorsToObjectsCalculator::LoadOptions(CalculatorContext* cc) { num_keypoints_ = options_.num_keypoints(); // Currently only support 2D when num_values_per_keypoint equals to 2. 
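The hunks above and below all follow the same mechanical migration: the glog-style CHECK/LOG macros pulled in through mediapipe/framework/port/logging.h are replaced by Abseil's prefixed equivalents from absl/log. A minimal self-contained sketch of the new pattern (the function and values are made up for illustration, not code from this PR):

    #include "absl/log/absl_check.h"
    #include "absl/log/absl_log.h"

    void ValidateKeypointCount(int num_values_per_keypoint) {
      // ABSL_CHECK_EQ aborts with a diagnostic when the check fails, just like
      // the legacy CHECK_EQ; only the include and the macro prefix change.
      ABSL_CHECK_EQ(num_values_per_keypoint, 2);
      // ABSL_LOG replaces LOG with the same severities (INFO, ERROR, FATAL).
      ABSL_LOG(ERROR) << "unexpected keypoint layout";
    }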
- CHECK_EQ(options_.num_values_per_keypoint(), 2); + ABSL_CHECK_EQ(options_.num_values_per_keypoint(), 2); return absl::OkStatus(); } diff --git a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc index e3686f65e..1aefd4672 100644 --- a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc +++ b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc @@ -17,6 +17,8 @@ #include #include "Eigen/Dense" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -76,7 +78,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase { // In a single MediaPipe session, the IDs are unique. // Also assign timestamp for the FrameAnnotation to be the input packet // timestamp. - void AssignObjectIdAndTimestamp(int64 timestamp_us, + void AssignObjectIdAndTimestamp(int64_t timestamp_us, FrameAnnotation* annotation); int num_classes_ = 0; @@ -154,7 +156,7 @@ absl::Status TfLiteTensorsToObjectsCalculator::ProcessCPU( auto status = decoder_->Lift2DTo3D(projection_matrix_, /*portrait*/ true, output_objects); if (!status.ok()) { - LOG(ERROR) << status; + ABSL_LOG(ERROR) << status; return status; } Project3DTo2D(/*portrait*/ true, output_objects); @@ -178,7 +180,7 @@ absl::Status TfLiteTensorsToObjectsCalculator::LoadOptions( num_keypoints_ = options_.num_keypoints(); // Currently only support 2D when num_values_per_keypoint equals to 2. - CHECK_EQ(options_.num_values_per_keypoint(), 2); + ABSL_CHECK_EQ(options_.num_values_per_keypoint(), 2); return absl::OkStatus(); } @@ -207,7 +209,7 @@ void TfLiteTensorsToObjectsCalculator::Project3DTo2D( } void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp( - int64 timestamp_us, FrameAnnotation* annotation) { + int64_t timestamp_us, FrameAnnotation* annotation) { for (auto& ann : *annotation->mutable_annotations()) { ann.set_object_id(GetNextObjectId()); } diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 83567a4d8..df6c8db08 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -39,6 +39,8 @@ cc_library( "//mediapipe/framework/port:source_location", "//mediapipe/framework/port:status", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", ], ) diff --git a/mediapipe/objc/DrishtiAudioUtil.h b/mediapipe/objc/DrishtiAudioUtil.h index 40e6ded0d..8d598123f 100644 --- a/mediapipe/objc/DrishtiAudioUtil.h +++ b/mediapipe/objc/DrishtiAudioUtil.h @@ -26,10 +26,9 @@ NS_ASSUME_NONNULL_BEGIN // Converts an audio sample buffer list into a `mediapipe::Matrix`. // Returns an error status on failure. 
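A second recurring change in these calculators is the move from MediaPipe's legacy int64 alias to the standard fixed-width type. A one-line illustration (not code from this change):

    #include <cstdint>

    // int64_t is guaranteed to be exactly 64 bits on every platform, so
    // timestamp arithmetic no longer depends on the legacy int64 alias.
    int64_t timestamp_us = 0;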
-absl::StatusOr> -MediaPipeConvertAudioBufferListToAudioMatrix( - const AudioBufferList* audioBufferList, - const AudioStreamBasicDescription* streamHeader, CMItemCount numFrames); +absl::StatusOr> MediaPipeConvertAudioBufferListToAudioMatrix( + const AudioBufferList* audioBufferList, const AudioStreamBasicDescription* streamHeader, + CMItemCount numFrames); NS_ASSUME_NONNULL_END diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc index 36ad4e195..684dc181c 100644 --- a/mediapipe/objc/util.cc +++ b/mediapipe/objc/util.cc @@ -15,6 +15,8 @@ #include "mediapipe/objc/util.h" #include "absl/base/macros.h" +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/memory/memory.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" @@ -504,7 +506,7 @@ absl::Status CreateCGImageFromCVPixelBuffer(CVPixelBufferRef image_buffer, break; default: - LOG(FATAL) << "Unsupported pixelFormat " << pixel_format; + ABSL_LOG(FATAL) << "Unsupported pixelFormat " << pixel_format; break; } @@ -571,7 +573,7 @@ std::unique_ptr CreateImageFrameForCVPixelBuffer( CVPixelBufferRef image_buffer, bool can_overwrite, bool bgr_as_rgb) { CVReturn status = CVPixelBufferLockBaseAddress(image_buffer, kCVPixelBufferLock_ReadOnly); - CHECK_EQ(status, kCVReturnSuccess) + ABSL_CHECK_EQ(status, kCVReturnSuccess) << "CVPixelBufferLockBaseAddress failed: " << status; void* base_address = CVPixelBufferGetBaseAddress(image_buffer); @@ -601,7 +603,7 @@ std::unique_ptr CreateImageFrameForCVPixelBuffer( const uint8_t permute_map[4] = {2, 1, 0, 3}; vImage_Error vError = vImagePermuteChannels_ARGB8888( &v_image, &v_dest, permute_map, kvImageNoFlags); - CHECK(vError == kvImageNoError) + ABSL_CHECK(vError == kvImageNoError) << "vImagePermuteChannels failed: " << vError; } } break; @@ -623,7 +625,7 @@ std::unique_ptr CreateImageFrameForCVPixelBuffer( static_cast(pixel_format >> 16 & 0xFF), static_cast(pixel_format >> 8 & 0xFF), static_cast(pixel_format & 0xFF), 0}; - LOG(FATAL) << "unsupported pixel format: " << format_str; + ABSL_LOG(FATAL) << "unsupported pixel format: " << format_str; } break; } @@ -631,7 +633,7 @@ std::unique_ptr CreateImageFrameForCVPixelBuffer( // We have already created a new frame that does not reference the buffer. status = CVPixelBufferUnlockBaseAddress(image_buffer, kCVPixelBufferLock_ReadOnly); - CHECK_EQ(status, kCVReturnSuccess) + ABSL_CHECK_EQ(status, kCVReturnSuccess) << "CVPixelBufferUnlockBaseAddress failed: " << status; CVPixelBufferRelease(image_buffer); } else { diff --git a/mediapipe/platforms.bzl b/mediapipe/platforms.bzl new file mode 100644 index 000000000..fe2cbbd66 --- /dev/null +++ b/mediapipe/platforms.bzl @@ -0,0 +1,38 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
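The new platforms.bzl (continued just below) defines a config_setting_and_platform macro. A hypothetical BUILD usage sketch, with the target name and constraints invented for illustration:

    load("//mediapipe:platforms.bzl", "config_setting_and_platform")

    # Generates both ":android_arm64" (a config_setting usable in select())
    # and ":android_arm64_platform" (a platform usable with --platforms).
    config_setting_and_platform(
        name = "android_arm64",
        constraint_values = [
            "@platforms//os:android",
            "@platforms//cpu:arm64",
        ],
        visibility = ["//visibility:public"],
    )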
+ +"""Build rule to generate 'config_setting' and 'platform' with the same constraints.""" + +def config_setting_and_platform( + name, + constraint_values = [], + visibility = None): + """Defines a 'config_setting' and 'platform' with the same constraints. + + Args: + name: the name for the 'config_setting'. The platform will be suffixed with '_platform'. + constraint_values: the constraints to meet. + visibility: the target visibility. + """ + native.config_setting( + name = name, + constraint_values = constraint_values, + visibility = visibility, + ) + + native.platform( + name = name + "_platform", + constraint_values = constraint_values, + visibility = visibility, + ) diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 2fdc08149..085fbc96b 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -99,6 +99,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarker_graph", ] + select({ # TODO: Build text_classifier_graph and text_embedder_graph on Windows. "//mediapipe:windows": [], diff --git a/mediapipe/python/pybind/packet_getter.cc b/mediapipe/python/pybind/packet_getter.cc index 93adfa018..a576c8b3c 100644 --- a/mediapipe/python/pybind/packet_getter.cc +++ b/mediapipe/python/pybind/packet_getter.cc @@ -14,6 +14,8 @@ #include "mediapipe/python/pybind/packet_getter.h" +#include + #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/matrix.h" diff --git a/mediapipe/python/solutions/drawing_styles.py b/mediapipe/python/solutions/drawing_styles.py index 5d75d5b30..a4c39b37f 100644 --- a/mediapipe/python/solutions/drawing_styles.py +++ b/mediapipe/python/solutions/drawing_styles.py @@ -30,6 +30,8 @@ _GRAY = (128, 128, 128) _PURPLE = (128, 64, 128) _PEACH = (180, 229, 255) _WHITE = (224, 224, 224) +_CYAN = (192, 255, 48) +_MAGENTA = (192, 48, 255) # Hands _THICKNESS_WRIST_MCP = 3 @@ -109,6 +111,23 @@ _FACEMESH_CONTOURS_CONNECTION_STYLE = { DrawingSpec(color=_WHITE, thickness=_THICKNESS_CONTOURS) } +_FACEMESH_CONTOURS_CONNECTION_STYLE_1 = { + face_mesh_connections.FACEMESH_LIPS: + DrawingSpec(color=_BLUE, thickness=_THICKNESS_CONTOURS), + face_mesh_connections.FACEMESH_LEFT_EYE: + DrawingSpec(color=_CYAN, thickness=_THICKNESS_CONTOURS), + face_mesh_connections.FACEMESH_LEFT_EYEBROW: + DrawingSpec(color=_GREEN, thickness=_THICKNESS_CONTOURS), + face_mesh_connections.FACEMESH_RIGHT_EYE: + DrawingSpec(color=_MAGENTA, thickness=_THICKNESS_CONTOURS), + face_mesh_connections.FACEMESH_RIGHT_EYEBROW: + DrawingSpec(color=_RED, thickness=_THICKNESS_CONTOURS), + face_mesh_connections.FACEMESH_FACE_OVAL: + DrawingSpec(color=_WHITE, thickness=_THICKNESS_CONTOURS), + face_mesh_connections.FACEMESH_NOSE: + DrawingSpec(color=_YELLOW, thickness=_THICKNESS_CONTOURS) +} + # Pose _THICKNESS_POSE_LANDMARKS = 2 _POSE_LANDMARKS_LEFT = frozenset([ @@ -161,15 +180,24 @@ def get_default_hand_connections_style( def get_default_face_mesh_contours_style( + i: int = 0, ) -> Mapping[Tuple[int, int], DrawingSpec]: """Returns the default face mesh contours drawing style. + Args: + i: The id for default style. Currently there are two default styles. + Returns: A mapping from each face mesh contours connection to its default drawing spec. 
""" + default_style = ( + _FACEMESH_CONTOURS_CONNECTION_STYLE_1 + if i == 1 + else _FACEMESH_CONTOURS_CONNECTION_STYLE + ) face_mesh_contours_connection_style = {} - for k, v in _FACEMESH_CONTOURS_CONNECTION_STYLE.items(): + for k, v in default_style.items(): for connection in k: face_mesh_contours_connection_style[connection] = v return face_mesh_contours_connection_style diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index 1b8b173f7..a1acc0be2 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -13,17 +13,17 @@ # limitations under the License. """MediaPipe solution drawing utils.""" +import dataclasses import math from typing import List, Mapping, Optional, Tuple, Union import cv2 -import dataclasses import matplotlib.pyplot as plt import numpy as np from mediapipe.framework.formats import detection_pb2 -from mediapipe.framework.formats import location_data_pb2 from mediapipe.framework.formats import landmark_pb2 +from mediapipe.framework.formats import location_data_pb2 _PRESENCE_THRESHOLD = 0.5 _VISIBILITY_THRESHOLD = 0.5 diff --git a/mediapipe/python/solutions/drawing_utils_test.py b/mediapipe/python/solutions/drawing_utils_test.py index 0039f9a90..8943a0581 100644 --- a/mediapipe/python/solutions/drawing_utils_test.py +++ b/mediapipe/python/solutions/drawing_utils_test.py @@ -20,7 +20,6 @@ import cv2 import numpy as np from google.protobuf import text_format - from mediapipe.framework.formats import detection_pb2 from mediapipe.framework.formats import landmark_pb2 from mediapipe.python.solutions import drawing_utils diff --git a/mediapipe/python/solutions/face_mesh.py b/mediapipe/python/solutions/face_mesh.py index 997c0661d..e56122191 100644 --- a/mediapipe/python/solutions/face_mesh.py +++ b/mediapipe/python/solutions/face_mesh.py @@ -45,6 +45,7 @@ from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LEFT_EYE from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LEFT_EYEBROW from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LEFT_IRIS from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LIPS +from mediapipe.python.solutions.face_mesh_connections import FACEMESH_NOSE from mediapipe.python.solutions.face_mesh_connections import FACEMESH_RIGHT_EYE from mediapipe.python.solutions.face_mesh_connections import FACEMESH_RIGHT_EYEBROW from mediapipe.python.solutions.face_mesh_connections import FACEMESH_RIGHT_IRIS diff --git a/mediapipe/python/solutions/face_mesh_connections.py b/mediapipe/python/solutions/face_mesh_connections.py index 1ebd541df..d44fb79be 100644 --- a/mediapipe/python/solutions/face_mesh_connections.py +++ b/mediapipe/python/solutions/face_mesh_connections.py @@ -57,6 +57,13 @@ FACEMESH_FACE_OVAL = frozenset([(10, 338), (338, 297), (297, 332), (332, 284), (234, 127), (127, 162), (162, 21), (21, 54), (54, 103), (103, 67), (67, 109), (109, 10)]) +FACEMESH_NOSE = frozenset([(168, 6), (6, 197), (197, 195), (195, 5), + (5, 4), (4, 1), (1, 19), (19, 94), (94, 2), (98, 97), + (97, 2), (2, 326), (326, 327), (327, 294), + (294, 278), (278, 344), (344, 440), (440, 275), + (275, 4), (4, 45), (45, 220), (220, 115), (115, 48), + (48, 64), (64, 98)]) + FACEMESH_CONTOURS = frozenset().union(*[ FACEMESH_LIPS, FACEMESH_LEFT_EYE, FACEMESH_LEFT_EYEBROW, FACEMESH_RIGHT_EYE, FACEMESH_RIGHT_EYEBROW, FACEMESH_FACE_OVAL diff --git a/mediapipe/tasks/BUILD b/mediapipe/tasks/BUILD index 98ddd5777..582fc4c30 100644 
--- a/mediapipe/tasks/BUILD +++ b/mediapipe/tasks/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/__init__.py b/mediapipe/tasks/__init__.py index ad7f0fd95..701d72379 100644 --- a/mediapipe/tasks/__init__.py +++ b/mediapipe/tasks/__init__.py @@ -1,4 +1,4 @@ -"""Copyright 2022 The MediaPipe Authors. All Rights Reserved. +"""Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD new file mode 100644 index 000000000..4b1841ef8 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -0,0 +1,49 @@ +# Copyright 2022 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "category", + hdrs = ["category.h"], +) + +cc_library( + name = "category_converter", + srcs = ["category_converter.cc"], + hdrs = ["category_converter.h"], + deps = [ + ":category", + "//mediapipe/tasks/cc/components/containers:category", + ], +) + +cc_library( + name = "classification_result", + hdrs = ["classification_result.h"], +) + +cc_library( + name = "classification_result_converter", + srcs = ["classification_result_converter.cc"], + hdrs = ["classification_result_converter.h"], + deps = [ + ":category", + ":category_converter", + ":classification_result", + "//mediapipe/tasks/cc/components/containers:classification_result", + ], +) diff --git a/mediapipe/tasks/c/components/containers/category.h b/mediapipe/tasks/c/components/containers/category.h new file mode 100644 index 000000000..b6eede40c --- /dev/null +++ b/mediapipe/tasks/c/components/containers/category.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// Defines a single classification result. 
+//
+// The label maps packed into the TFLite Model Metadata [1] are used to populate
+// the 'category_name' and 'display_name' fields.
+//
+// [1]: https://www.tensorflow.org/lite/convert/metadata
+struct Category {
+  // The index of the category in the classification model output.
+  int index;
+
+  // The score for this category, e.g. (but not necessarily) a probability in
+  // [0,1].
+  float score;
+
+  // The optional ID for the category, read from the label map packed in the
+  // TFLite Model Metadata if present. Not necessarily human-readable.
+  const char* category_name;
+
+  // The optional human-readable name for the category, read from the label map
+  // packed in the TFLite Model Metadata if present.
+  const char* display_name;
+};
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_
diff --git a/mediapipe/tasks/c/components/containers/category_converter.cc b/mediapipe/tasks/c/components/containers/category_converter.cc
new file mode 100644
index 000000000..b819c83f9
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/category_converter.cc
@@ -0,0 +1,30 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/components/containers/category_converter.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+// `in` is taken by const reference: the C struct keeps pointers into the
+// C++ Category's strings, so the source object must outlive `out`.
+void CppConvertToCategory(
+    const mediapipe::tasks::components::containers::Category& in,
+    Category* out) {
+  out->index = in.index;
+  out->score = in.score;
+  out->category_name =
+      in.category_name.has_value() ? in.category_name->c_str() : nullptr;
+  out->display_name =
+      in.display_name.has_value() ? in.display_name->c_str() : nullptr;
+}
+
+}  // namespace mediapipe::tasks::c::components::containers
diff --git a/mediapipe/tasks/c/components/containers/category_converter.h b/mediapipe/tasks/c/components/containers/category_converter.h
new file mode 100644
index 000000000..a8b2b6a0f
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/category_converter.h
@@ -0,0 +1,29 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
+
+#include "mediapipe/tasks/c/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+void CppConvertToCategory(
+    const mediapipe::tasks::components::containers::Category& in,
+    Category* out);
+
+}  // namespace mediapipe::tasks::c::components::containers
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_CONVERTER_H_
diff --git a/mediapipe/tasks/c/components/containers/classification_result.h b/mediapipe/tasks/c/components/containers/classification_result.h
new file mode 100644
index 000000000..ef2914e5d
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/classification_result.h
@@ -0,0 +1,68 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Defines classification results for a given classifier head.
+struct Classifications {
+  // The array of predicted categories, usually sorted by descending scores,
+  // e.g. from high to low probability.
+  struct Category* categories;
+  // The number of elements in the categories array.
+  uint32_t categories_count;
+
+  // The index of the classifier head (i.e. output tensor) these categories
+  // refer to. This is useful for multi-head models.
+  int head_index;
+
+  // The optional name of the classifier head, as provided in the TFLite Model
+  // Metadata [1] if present. This is useful for multi-head models.
+  //
+  // [1]: https://www.tensorflow.org/lite/convert/metadata
+  const char* head_name;
+};
+
+// Defines classification results of a model.
+struct ClassificationResult {
+  // The classification results for each head of the model.
+  struct Classifications* classifications;
+  // The number of classifications in the classifications array.
+  uint32_t classifications_count;
+
+  // The optional timestamp (in milliseconds) of the start of the chunk of data
+  // corresponding to these results.
+  //
+  // This is only used for classification on time series (e.g. audio
+  // classification). In these use cases, the amount of data to process might
+  // exceed the maximum size that the model can process: to solve this, the
+  // input data is split into multiple chunks starting at different timestamps.
+  int64_t timestamp_ms;
+  // Specifies whether the timestamp contains a valid value.
+  bool has_timestamp_ms;
+};
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter.cc b/mediapipe/tasks/c/components/containers/classification_result_converter.cc
new file mode 100644
index 000000000..676955ab2
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/classification_result_converter.cc
@@ -0,0 +1,58 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
+
+#include "mediapipe/tasks/c/components/containers/category.h"
+#include "mediapipe/tasks/c/components/containers/category_converter.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+namespace {
+using mediapipe::tasks::c::components::containers::CppConvertToCategory;
+}  // namespace
+
+void CppConvertToClassificationResult(
+    const mediapipe::tasks::components::containers::ClassificationResult& in,
+    ClassificationResult* out) {
+  out->has_timestamp_ms = in.timestamp_ms.has_value();
+  if (out->has_timestamp_ms) {
+    out->timestamp_ms = in.timestamp_ms.value();
+  }
+
+  out->classifications_count = in.classifications.size();
+  out->classifications = new Classifications[out->classifications_count];
+
+  // Note: the loops below must use `<` (not `<=`) to stay in bounds, and the
+  // per-element views must be references so the writes reach the output array.
+  for (uint32_t i = 0; i < out->classifications_count; ++i) {
+    const auto& classification_in = in.classifications[i];
+    auto& classification_out = out->classifications[i];
+
+    classification_out.categories_count = classification_in.categories.size();
+    classification_out.categories =
+        new Category[classification_out.categories_count];
+    for (uint32_t j = 0; j < classification_out.categories_count; ++j) {
+      CppConvertToCategory(classification_in.categories[j],
+                           &(classification_out.categories[j]));
+    }
+
+    classification_out.head_index = classification_in.head_index;
+    classification_out.head_name =
+        classification_in.head_name.has_value()
+            ? classification_in.head_name.value().c_str()
+            : nullptr;
+  }
+}
+
+}  // namespace mediapipe::tasks::c::components::containers
diff --git a/mediapipe/tasks/c/components/containers/classification_result_converter.h b/mediapipe/tasks/c/components/containers/classification_result_converter.h
new file mode 100644
index 000000000..a81d76e82
--- /dev/null
+++ b/mediapipe/tasks/c/components/containers/classification_result_converter.h
@@ -0,0 +1,30 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
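One caveat on the converter above: it allocates the classifications and categories arrays with new[], and nothing in this patch releases them, so ownership stays with the caller. A hypothetical cleanup helper (not part of this PR) would mirror the allocations exactly:

    void CppCloseClassificationResult(ClassificationResult* result) {
      for (uint32_t i = 0; i < result->classifications_count; ++i) {
        delete[] result->classifications[i].categories;
      }
      delete[] result->classifications;
      result->classifications = nullptr;
      result->classifications_count = 0;
    }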
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
+
+#include "mediapipe/tasks/c/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+void CppConvertToClassificationResult(
+    const mediapipe::tasks::components::containers::ClassificationResult& in,
+    ClassificationResult* out);
+
+}  // namespace mediapipe::tasks::c::components::containers
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_CONVERTER_H_
diff --git a/mediapipe/tasks/c/components/processors/BUILD b/mediapipe/tasks/c/components/processors/BUILD
new file mode 100644
index 000000000..e90437d59
--- /dev/null
+++ b/mediapipe/tasks/c/components/processors/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "classifier_options",
+    hdrs = ["classifier_options.h"],
+)
+
+cc_library(
+    name = "classifier_options_converter",
+    srcs = ["classifier_options_converter.cc"],
+    hdrs = ["classifier_options_converter.h"],
+    deps = [
+        ":classifier_options",
+        "//mediapipe/tasks/cc/components/processors:classifier_options",
+    ],
+)
diff --git a/mediapipe/tasks/c/components/processors/classifier_options.h b/mediapipe/tasks/c/components/processors/classifier_options.h
new file mode 100644
index 000000000..4658fb42b
--- /dev/null
+++ b/mediapipe/tasks/c/components/processors/classifier_options.h
@@ -0,0 +1,59 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Classifier options for MediaPipe C classification Tasks.
+struct ClassifierOptions {
+  // The locale to use for display names specified through the TFLite Model
+  // Metadata, if any. Defaults to English.
+  char* display_names_locale;
+
+  // The maximum number of top-scored classification results to return. If < 0,
+  // all available results will be returned. If 0, an invalid argument error is
+  // returned.
+  int max_results;
+
+  // Score threshold to override the one provided in the model metadata (if
+  // any). Results below this value are rejected.
+  float score_threshold;
+
+  // The allowlist of category names. If non-empty, detection results whose
+  // category name is not in this set will be filtered out. Duplicate or unknown
+  // category names are ignored. Mutually exclusive with category_denylist.
+  char** category_allowlist;
+  // The number of elements in the category allowlist.
+  uint32_t category_allowlist_count;
+
+  // The denylist of category names. If non-empty, detection results whose
+  // category name is in this set will be filtered out. Duplicate or unknown
+  // category names are ignored. Mutually exclusive with category_allowlist.
+  char** category_denylist;
+  // The number of elements in the category denylist.
+  uint32_t category_denylist_count;
+};
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
diff --git a/mediapipe/tasks/c/components/processors/classifier_options_converter.cc b/mediapipe/tasks/c/components/processors/classifier_options_converter.cc
new file mode 100644
index 000000000..e421a7832
--- /dev/null
+++ b/mediapipe/tasks/c/components/processors/classifier_options_converter.cc
@@ -0,0 +1,41 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <vector>
+
+#include "mediapipe/tasks/c/components/processors/classifier_options.h"
+#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
+
+// The namespace below matches the declaration in
+// classifier_options_converter.h.
+namespace mediapipe::tasks::c::components::processors {
+
+void CppConvertToClassifierOptions(
+    ClassifierOptions in,
+    mediapipe::tasks::components::processors::ClassifierOptions* out) {
+  // Keep the C++ default locale ("en") when the caller provides none.
+  if (in.display_names_locale) {
+    out->display_names_locale = in.display_names_locale;
+  }
+  out->max_results = in.max_results;
+  out->score_threshold = in.score_threshold;
+  out->category_allowlist =
+      std::vector<std::string>(in.category_allowlist_count);
+  for (uint32_t i = 0; i < in.category_allowlist_count; ++i) {
+    out->category_allowlist[i] = in.category_allowlist[i];
+  }
+  out->category_denylist = std::vector<std::string>(in.category_denylist_count);
+  for (uint32_t i = 0; i < in.category_denylist_count; ++i) {
+    out->category_denylist[i] = in.category_denylist[i];
+  }
+}
+
+}  // namespace mediapipe::tasks::c::components::processors
diff --git a/mediapipe/tasks/c/components/processors/classifier_options_converter.h b/mediapipe/tasks/c/components/processors/classifier_options_converter.h
new file mode 100644
index 000000000..7f8019f04
--- /dev/null
+++ b/mediapipe/tasks/c/components/processors/classifier_options_converter.h
@@ -0,0 +1,30 @@
+/* Copyright 2023 The MediaPipe Authors.
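For reference, populating the C ClassifierOptions defined above might look like this sketch (labels and thresholds invented for illustration):

    char cat_label[] = "cat";
    char* allowlist[] = {cat_label};

    struct ClassifierOptions options = {0};
    options.max_results = 3;           /* keep the top three categories */
    options.score_threshold = 0.5f;    /* drop low-confidence results */
    options.category_allowlist = allowlist;
    options.category_allowlist_count = 1;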
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" + +namespace mediapipe::tasks::c::components::processors { + +void CppConvertToClassifierOptions( + ClassifierOptions in, + mediapipe::tasks::components::processors::ClassifierOptions* out); + +} // namespace mediapipe::tasks::c::components::processors + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_CONVERTER_H_ diff --git a/mediapipe/tasks/examples/android/BUILD b/mediapipe/tasks/c/core/BUILD similarity index 58% rename from mediapipe/tasks/examples/android/BUILD rename to mediapipe/tasks/c/core/BUILD index c07af2d2c..9a360404e 100644 --- a/mediapipe/tasks/examples/android/BUILD +++ b/mediapipe/tasks/c/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +package(default_visibility = ["//mediapipe/tasks:internal"]) + licenses(["notice"]) -filegroup( - name = "resource_files", - srcs = glob(["res/**"]), - visibility = ["//mediapipe/tasks/examples/android:__subpackages__"], +cc_library( + name = "base_options", + hdrs = ["base_options.h"], +) + +cc_library( + name = "base_options_converter", + srcs = ["base_options_converter.cc"], + hdrs = ["base_options_converter.h"], + deps = [ + ":base_options", + "//mediapipe/tasks/cc/core:base_options", + ], ) diff --git a/mediapipe/tasks/c/core/base_options.h b/mediapipe/tasks/c/core/base_options.h new file mode 100644 index 000000000..d23b6884c --- /dev/null +++ b/mediapipe/tasks/c/core/base_options.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ +#define MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// Base options for MediaPipe C Tasks. 
+struct BaseOptions {
+  // The model asset file contents as a string.
+  char* model_asset_buffer;
+
+  // The path to the model asset to open and mmap in memory.
+  char* model_asset_path;
+};
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+
+#endif  // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
diff --git a/mediapipe/tasks/c/core/base_options_converter.cc b/mediapipe/tasks/c/core/base_options_converter.cc
new file mode 100644
index 000000000..c0bdf1539
--- /dev/null
+++ b/mediapipe/tasks/c/core/base_options_converter.cc
@@ -0,0 +1,29 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/c/core/base_options_converter.h"
+
+#include <memory>
+#include <string>
+
+#include "mediapipe/tasks/cc/core/base_options.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+void CppConvertToBaseOptions(BaseOptions in,
+                             mediapipe::tasks::core::BaseOptions* out) {
+  // Either field may be left null by the caller, so guard before copying.
+  if (in.model_asset_buffer) {
+    out->model_asset_buffer =
+        std::make_unique<std::string>(in.model_asset_buffer);
+  }
+  if (in.model_asset_path) {
+    out->model_asset_path = in.model_asset_path;
+  }
+}
+
+}  // namespace mediapipe::tasks::c::components::containers
diff --git a/mediapipe/tasks/c/core/base_options_converter.h b/mediapipe/tasks/c/core/base_options_converter.h
new file mode 100644
index 000000000..0890857fc
--- /dev/null
+++ b/mediapipe/tasks/c/core/base_options_converter.h
@@ -0,0 +1,29 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_CONVERTER_H_
+#define MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_CONVERTER_H_
+
+#include "mediapipe/tasks/c/core/base_options.h"
+#include "mediapipe/tasks/cc/core/base_options.h"
+
+namespace mediapipe::tasks::c::components::containers {
+
+void CppConvertToBaseOptions(BaseOptions in,
+                             mediapipe::tasks::core::BaseOptions* out);
+
+}  // namespace mediapipe::tasks::c::components::containers
+
+#endif  // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_CONVERTER_H_
diff --git a/mediapipe/tasks/c/text/text_classifier/BUILD b/mediapipe/tasks/c/text/text_classifier/BUILD
new file mode 100644
index 000000000..e095e2680
--- /dev/null
+++ b/mediapipe/tasks/c/text/text_classifier/BUILD
@@ -0,0 +1,34 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
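Filling the C BaseOptions above is analogous; a tiny sketch with a made-up model path:

    char model_path[] = "/tmp/classifier.tflite";  /* hypothetical path */

    struct BaseOptions base = {0};
    base.model_asset_path = model_path;
    base.model_asset_buffer = 0;  /* unused when loading from a file path */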
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/c/components/containers:classification_result", + "//mediapipe/tasks/c/components/containers:classification_result_converter", + "//mediapipe/tasks/c/components/processors:classifier_options", + "//mediapipe/tasks/c/components/processors:classifier_options_converter", + "//mediapipe/tasks/c/core:base_options", + "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/cc/text/text_classifier", + "@com_google_absl//absl/log:absl_log", + ], +) diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc new file mode 100644 index 000000000..388d03b94 --- /dev/null +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc @@ -0,0 +1,99 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h" + +#include + +#include "absl/log/absl_log.h" +#include "mediapipe/tasks/c/components/containers/classification_result.h" +#include "mediapipe/tasks/c/components/containers/classification_result_converter.h" +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" +#include "mediapipe/tasks/c/core/base_options.h" +#include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +namespace mediapipe::tasks::c::text::text_classifier { + +namespace { + +using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions; +using ::mediapipe::tasks::c::components::containers:: + CppConvertToClassificationResult; +using ::mediapipe::tasks::c::components::processors:: + CppConvertToClassifierOptions; +using ::mediapipe::tasks::text::text_classifier::TextClassifier; +} // namespace + +TextClassifier* CppTextClassifierCreate(TextClassifierOptions options) { + auto cpp_options = std::make_unique< + ::mediapipe::tasks::text::text_classifier::TextClassifierOptions>(); + + CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); + CppConvertToClassifierOptions(options.classifier_options, + &cpp_options->classifier_options); + + auto classifier = TextClassifier::Create(std::move(cpp_options)); + if (!classifier.ok()) { + ABSL_LOG(ERROR) << "Failed to create TextClassifier: " + << classifier.status(); + return nullptr; + } + return classifier->release(); +} + +bool CppTextClassifierClassify(void* classifier, char* utf8_str, + TextClassifierResult* result) { + auto cpp_classifier = static_cast(classifier); + auto cpp_result = cpp_classifier->Classify(utf8_str); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status(); + return false; + } + CppConvertToClassificationResult(*cpp_result, result); + return true; +} + +void CppTextClassifierClose(void* classifier) { + auto cpp_classifier = static_cast(classifier); + auto result = cpp_classifier->Close(); + if (!result.ok()) { + ABSL_LOG(ERROR) << "Failed to close TextClassifier: " << result; + } + delete cpp_classifier; +} + +} // namespace mediapipe::tasks::c::text::text_classifier + +extern "C" { + +void* text_classifier_create(struct TextClassifierOptions options) { + return mediapipe::tasks::c::text::text_classifier::CppTextClassifierCreate( + options); +} + +bool text_classifier_classify(void* classifier, char* utf8_str, + TextClassifierResult* result) { + return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify( + classifier, utf8_str, result); +} + +void text_classifier_close(void* classifier) { + mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( + classifier); +} + +} // extern "C" diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.h b/mediapipe/tasks/c/text/text_classifier/text_classifier.h new file mode 100644 index 000000000..2c084bbed --- /dev/null +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
+#define MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
+
+#include "mediapipe/tasks/c/components/containers/classification_result.h"
+#include "mediapipe/tasks/c/components/processors/classifier_options.h"
+#include "mediapipe/tasks/c/core/base_options.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// The struct tag is spelled out so the typedef also works when this header
+// is consumed as plain C.
+typedef struct ClassificationResult TextClassifierResult;
+
+// The options for configuring a MediaPipe text classifier task.
+struct TextClassifierOptions {
+  // Base options for configuring MediaPipe Tasks, such as specifying the model
+  // file with metadata, accelerator options, op resolver, etc.
+  struct BaseOptions base_options;
+
+  // Options for configuring the classifier behavior, such as score threshold,
+  // number of results, etc.
+  struct ClassifierOptions classifier_options;
+};
+
+// Creates a TextClassifier from the provided `options`.
+void* text_classifier_create(struct TextClassifierOptions options);
+
+// Performs classification on the input `utf8_str`.
+bool text_classifier_classify(void* classifier, char* utf8_str,
+                              TextClassifierResult* result);
+
+// Shuts down the TextClassifier when all the work is done. Frees all memory.
+void text_classifier_close(void* classifier);
+
+#ifdef __cplusplus
+}  // extern C
+#endif
+
+#endif  // MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
diff --git a/mediapipe/tasks/cc/BUILD b/mediapipe/tasks/cc/BUILD
index f49657af3..39df9c55d 100644
--- a/mediapipe/tasks/cc/BUILD
+++ b/mediapipe/tasks/cc/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD
index c575caabe..50f587545 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD
+++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc
index 822c3a22f..0ebdea108 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc
+++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
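Putting the new C API together, an end-to-end sketch of classifying a string (model path and input text are made up; error handling is minimal):

    #include <stdio.h>

    #include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"

    int main(void) {
      char model_path[] = "/tmp/bert_text_classifier.tflite";  /* hypothetical */
      char input_text[] = "an imperfect but entertaining mystery";

      struct TextClassifierOptions options = {0};
      options.base_options.model_asset_path = model_path;
      options.classifier_options.max_results = 3;

      void* classifier = text_classifier_create(options);
      if (!classifier) return 1;  /* creation failures are logged via ABSL_LOG */

      TextClassifierResult result;
      if (text_classifier_classify(classifier, input_text, &result)) {
        for (uint32_t i = 0; i < result.classifications_count; ++i) {
          const struct Classifications* head = &result.classifications[i];
          for (uint32_t j = 0; j < head->categories_count; ++j) {
            /* category_name may be null if the model ships no label map */
            printf("%s: %.3f\n",
                   head->categories[j].category_name
                       ? head->categories[j].category_name
                       : "<unnamed>",
                   head->categories[j].score);
          }
        }
        /* the result arrays are heap-allocated by the converter; see the
           ownership note above */
      }
      text_classifier_close(classifier);
      return 0;
    }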
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h
index dd611ec81..373b31519 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h
+++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc
index b232afc72..b15a23f32 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc
+++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc
index 5f5f8da6c..30b55b8de 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc
+++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD
index bfe37ec01..1b3783d51 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD
+++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto
index cc26b3070..78eb2cf86 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto
+++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD
index d79a6f01e..7d5a49f8a 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD
+++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
index 15bf1fb87..08f08bba1 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
index c5f548a60..1035fa0aa 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
index 187f11f7f..a9654947c 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
index 81ecb1237..a297e7c45 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
index 38df8fb44..3b26138f5 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
+++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto
index 367a1bf26..a6f7275e5 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto
+++ b/mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/core/BUILD b/mediapipe/tasks/cc/audio/core/BUILD
index 016faa10f..4f821f6d5 100644
--- a/mediapipe/tasks/cc/audio/core/BUILD
+++ b/mediapipe/tasks/cc/audio/core/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ cc_library(
         ":base_audio_task_api",
         "//mediapipe/calculators/core:flow_limiter_calculator",
         "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/tasks/cc/core:task_api_factory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
diff --git a/mediapipe/tasks/cc/audio/core/audio_task_api_factory.h b/mediapipe/tasks/cc/audio/core/audio_task_api_factory.h
index 6f5c4ff67..901419a57 100644
--- a/mediapipe/tasks/cc/audio/core/audio_task_api_factory.h
+++ b/mediapipe/tasks/cc/audio/core/audio_task_api_factory.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
+#include "mediapipe/tasks/cc/core/task_api_factory.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
 
 namespace mediapipe {
@@ -60,13 +61,8 @@ class AudioTaskApiFactory {
             "Task graph config should only contain one task subgraph node.",
             MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
       } else {
-        if (!node.options().HasExtension(Options::ext)) {
-          return CreateStatusWithPayload(
-              absl::StatusCode::kInvalidArgument,
-              absl::StrCat(node.calculator(),
-                           " is missing the required task options field."),
-              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
-        }
+        MP_RETURN_IF_ERROR(
+            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
         found_task_subgraph = true;
       }
     }
diff --git a/mediapipe/tasks/cc/audio/core/base_audio_task_api.h b/mediapipe/tasks/cc/audio/core/base_audio_task_api.h
index c04b3cf32..ef8254d17 100644
--- a/mediapipe/tasks/cc/audio/core/base_audio_task_api.h
+++ b/mediapipe/tasks/cc/audio/core/base_audio_task_api.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/core/running_mode.h b/mediapipe/tasks/cc/audio/core/running_mode.h
index 332454f9f..b7e857a58 100644
--- a/mediapipe/tasks/cc/audio/core/running_mode.h
+++ b/mediapipe/tasks/cc/audio/core/running_mode.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/utils/BUILD b/mediapipe/tasks/cc/audio/utils/BUILD
index 29d88d33d..a25bbe8ac 100644
--- a/mediapipe/tasks/cc/audio/utils/BUILD
+++ b/mediapipe/tasks/cc/audio/utils/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.cc b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.cc
index 8efd94741..6765bc63c 100644
--- a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.cc
+++ b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h
index 69393a10a..09bfa264f 100644
--- a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h
+++ b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc
index 4f7a5000e..32816a92f 100644
--- a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc
+++ b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/common.cc b/mediapipe/tasks/cc/common.cc
index e7102edc3..cdb069a72 100644
--- a/mediapipe/tasks/cc/common.cc
+++ b/mediapipe/tasks/cc/common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/common.h b/mediapipe/tasks/cc/common.h
index 70892c5cd..bee410957 100644
--- a/mediapipe/tasks/cc/common.h
+++ b/mediapipe/tasks/cc/common.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD
index e447f5d72..9046a280d 100644
--- a/mediapipe/tasks/cc/components/calculators/BUILD
+++ b/mediapipe/tasks/cc/components/calculators/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -133,6 +133,7 @@ cc_test(
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto
index e2ed1788e..fba146f74 100644
--- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto
+++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc
index c824919df..4fb9eead5 100644
--- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc
+++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc
index c4b635d24..040a803a1 100644
--- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc
+++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc
index 10eb962dd..883ada6cb 100644
--- a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc
+++ b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto
index 11d944c93..e614c6207 100644
--- a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto
+++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto
index fd87383b4..7dd4a6058 100644
--- a/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto
+++ b/mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD
index 816e3c766..b19a178b9 100644
--- a/mediapipe/tasks/cc/components/containers/BUILD
+++ b/mediapipe/tasks/cc/components/containers/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/category.cc b/mediapipe/tasks/cc/components/containers/category.cc
index e07333a7b..65b553842 100644
--- a/mediapipe/tasks/cc/components/containers/category.cc
+++ b/mediapipe/tasks/cc/components/containers/category.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/category.h b/mediapipe/tasks/cc/components/containers/category.h
index 57b18e7ea..1bff5601d 100644
--- a/mediapipe/tasks/cc/components/containers/category.h
+++ b/mediapipe/tasks/cc/components/containers/category.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/classification_result.cc b/mediapipe/tasks/cc/components/containers/classification_result.cc
index f2d88406d..bbedc273d 100644
--- a/mediapipe/tasks/cc/components/containers/classification_result.cc
+++ b/mediapipe/tasks/cc/components/containers/classification_result.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/classification_result.h b/mediapipe/tasks/cc/components/containers/classification_result.h
index e359fb33e..5b736cefd 100644
--- a/mediapipe/tasks/cc/components/containers/classification_result.h
+++ b/mediapipe/tasks/cc/components/containers/classification_result.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/detection_result.cc b/mediapipe/tasks/cc/components/containers/detection_result.cc
index c5e4cde41..1be7950fd 100644
--- a/mediapipe/tasks/cc/components/containers/detection_result.cc
+++ b/mediapipe/tasks/cc/components/containers/detection_result.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/detection_result.h b/mediapipe/tasks/cc/components/containers/detection_result.h
index cfddfdb00..c267fc5a7 100644
--- a/mediapipe/tasks/cc/components/containers/detection_result.h
+++ b/mediapipe/tasks/cc/components/containers/detection_result.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/embedding_result.cc b/mediapipe/tasks/cc/components/containers/embedding_result.cc
index 9de55911b..15e762e38 100644
--- a/mediapipe/tasks/cc/components/containers/embedding_result.cc
+++ b/mediapipe/tasks/cc/components/containers/embedding_result.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/embedding_result.h b/mediapipe/tasks/cc/components/containers/embedding_result.h
index 2d01d2f2a..4fdb8d24e 100644
--- a/mediapipe/tasks/cc/components/containers/embedding_result.h
+++ b/mediapipe/tasks/cc/components/containers/embedding_result.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/keypoint.h b/mediapipe/tasks/cc/components/containers/keypoint.h
index dd01037c8..89de36601 100644
--- a/mediapipe/tasks/cc/components/containers/keypoint.h
+++ b/mediapipe/tasks/cc/components/containers/keypoint.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2023 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/landmark.cc b/mediapipe/tasks/cc/components/containers/landmark.cc
index 6d80cb835..1cd0b4c3f 100644
--- a/mediapipe/tasks/cc/components/containers/landmark.cc
+++ b/mediapipe/tasks/cc/components/containers/landmark.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2023 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h
index 5cb57bfb3..63760ae8b 100644
--- a/mediapipe/tasks/cc/components/containers/landmark.h
+++ b/mediapipe/tasks/cc/components/containers/landmark.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2023 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD
index 27d2357b5..66255aed7 100644
--- a/mediapipe/tasks/cc/components/containers/proto/BUILD
+++ b/mediapipe/tasks/cc/components/containers/proto/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto
index 2b2306829..e4a6ec29d 100644
--- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto
+++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto
index 4f888c699..0b55a1a95 100644
--- a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto
+++ b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto b/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto
index ac44f9b58..2e28d3c02 100644
--- a/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto
+++ b/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/rect.cc b/mediapipe/tasks/cc/components/containers/rect.cc
index 4a94832a6..3a733e81e 100644
--- a/mediapipe/tasks/cc/components/containers/rect.cc
+++ b/mediapipe/tasks/cc/components/containers/rect.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h
index 72c7a8acb..80af898cf 100644
--- a/mediapipe/tasks/cc/components/containers/rect.h
+++ b/mediapipe/tasks/cc/components/containers/rect.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD
index dfa18e806..dc5aca48a 100644
--- a/mediapipe/tasks/cc/components/processors/BUILD
+++ b/mediapipe/tasks/cc/components/processors/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -161,3 +161,50 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "detection_postprocessing_graph",
+    srcs = ["detection_postprocessing_graph.cc"],
+    hdrs = ["detection_postprocessing_graph.h"],
+    deps = [
+        "//mediapipe/calculators/core:split_vector_calculator",
+        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
+        "//mediapipe/calculators/util:non_max_suppression_calculator",
+        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
+        "//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
+        "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "//mediapipe/tasks/metadata:object_detector_metadata_schema_cc",
+        "//mediapipe/util:label_map_cc_proto",
+        "//mediapipe/util:label_map_util",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:absl_log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+    alwayslink = 1,
+)
diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc
index cfb3b02cf..525b3d4e5 100644
--- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc
+++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -66,7 +66,7 @@ using ::mediapipe::tasks::core::ModelResources;
 using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
 using ::tflite::ProcessUnit;
 using ::tflite::TensorMetadata;
-using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
+using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
 using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
 
 constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@@ -296,7 +296,7 @@ void ConfigureClassificationAggregationCalculator(
   if (output_tensors_metadata == nullptr) {
     return;
   }
-  for (const auto& metadata : *output_tensors_metadata) {
+  for (const auto metadata : *output_tensors_metadata) {
     options->add_head_names(metadata->name()->str());
   }
 }
diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h
index 03ae91130..e29aef19c 100644
--- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h
+++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc
index 6dff64b1b..014053fa0 100644
--- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc
+++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -520,7 +520,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
       auto poller,
       BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
   // Build input tensors.
-  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
+  std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
   tensor[1] = 18;
   tensor[2] = 16;
@@ -552,7 +552,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
   // Build input tensors.
-  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
+  std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
   tensor[1] = 12;
   tensor[2] = 14;
   tensor[3] = 16;
@@ -589,7 +589,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
       auto poller,
      BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
   // Build input tensors.
-  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
+  std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
   tensor[1] = 12;
   tensor[2] = 14;
   tensor[3] = 16;
@@ -677,11 +677,11 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
       auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
                               /*connect_timestamps=*/true));
   // Build input tensors.
-  std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
+  std::vector<uint8_t> tensor_0(kMobileNetNumClasses, 0);
   tensor_0[1] = 12;
   tensor_0[2] = 14;
   tensor_0[3] = 16;
-  std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
+  std::vector<uint8_t> tensor_1(kMobileNetNumClasses, 0);
   tensor_1[5] = 12;
   tensor_1[6] = 14;
   tensor_1[7] = 16;
diff --git a/mediapipe/tasks/cc/components/processors/classifier_options.cc b/mediapipe/tasks/cc/components/processors/classifier_options.cc
index 349bb569d..0343db2ec 100644
--- a/mediapipe/tasks/cc/components/processors/classifier_options.cc
+++ b/mediapipe/tasks/cc/components/processors/classifier_options.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/classifier_options.h b/mediapipe/tasks/cc/components/processors/classifier_options.h
index 189b42e60..598e3dab6 100644
--- a/mediapipe/tasks/cc/components/processors/classifier_options.h
+++ b/mediapipe/tasks/cc/components/processors/classifier_options.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
new file mode 100644
index 000000000..813a23aeb
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
@@ -0,0 +1,887 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
+
+#include <limits>
+#include <optional>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/absl_log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
+#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
+#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
+#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
+#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
+#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
+#include "mediapipe/tasks/metadata/object_detector_metadata_schema_generated.h"
+#include "mediapipe/util/label_map.pb.h"
+#include "mediapipe/util/label_map_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace components {
+namespace processors {
+
+namespace {
+
+using ::flatbuffers::Offset;
+using ::flatbuffers::Vector;
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+using ::tflite::BoundingBoxProperties;
+using ::tflite::ContentProperties;
+using ::tflite::ContentProperties_BoundingBoxProperties;
+using ::tflite::EnumNameContentProperties;
+using ::tflite::ProcessUnit;
+using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
+using ::tflite::TensorMetadata;
+using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
+using TensorsSource =
+    mediapipe::api2::builder::Source<std::vector<Tensor>>;
+
+constexpr int kInModelNmsDefaultLocationsIndex = 0;
+constexpr int kInModelNmsDefaultCategoriesIndex = 1;
+constexpr int kInModelNmsDefaultScoresIndex = 2;
+constexpr int kInModelNmsDefaultNumResultsIndex = 3;
+
+constexpr int kOutModelNmsDefaultLocationsIndex = 0;
+constexpr int kOutModelNmsDefaultScoresIndex = 1;
+
+constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
+
+constexpr absl::string_view kLocationTensorName = "location";
+constexpr absl::string_view kCategoryTensorName = "category";
+constexpr absl::string_view kScoreTensorName = "score";
+constexpr absl::string_view kNumberOfDetectionsTensorName =
+    "number of detections";
+constexpr absl::string_view kDetectorMetadataName = "DETECTOR_METADATA";
+constexpr absl::string_view kCalibratedScoresTag = "CALIBRATED_SCORES";
+constexpr absl::string_view kDetectionsTag = "DETECTIONS";
+constexpr absl::string_view kIndicesTag = "INDICES";
+constexpr absl::string_view kScoresTag = "SCORES";
+constexpr absl::string_view kTensorsTag = "TENSORS";
+constexpr absl::string_view kAnchorsTag = "ANCHORS";
+
+// Struct holding the different output streams produced by the graph.
+struct DetectionPostprocessingOutputStreams {
+  Source<std::vector<Detection>> detections;
+};
+
+// Parameters used for configuring the post-processing calculators.
+struct PostProcessingSpecs {
+  // The maximum number of detection results to return.
+  int max_results;
+  // Indices of the output tensors to match the output tensors to the correct
+  // index order of the output tensors: [location, categories, scores,
+  // num_detections].
+  std::vector<int> output_tensor_indices;
+  // For each pack of 4 coordinates returned by the model, this denotes the
+  // order in which to get the left, top, right and bottom coordinates.
+  std::vector<int> bounding_box_corners_order;
+  // This is populated by reading the label files from the TFLite Model
+  // Metadata: if no such files are available, this is left empty and the
+  // ObjectDetector will only be able to populate the `index` field of the
+  // detection results.
+  LabelItems label_items;
+  // Score threshold. Detections with a confidence below this value are
+  // discarded. If none is provided via metadata or options, -FLT_MAX is set as
+  // default value.
+  float score_threshold;
+  // Set of category indices to be allowed/denied.
+  absl::flat_hash_set<int> allow_or_deny_categories;
+  // Indicates `allow_or_deny_categories` is an allowlist or a denylist.
+  bool is_allowlist;
+  // Score calibration options, if any.
+  std::optional<ScoreCalibrationCalculatorOptions> score_calibration_options;
+};
+
+absl::Status SanityCheckOptions(const proto::DetectorOptions& options) {
+  if (options.max_results() == 0) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Invalid `max_results` option: value must be != 0",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  if (options.category_allowlist_size() > 0 &&
+      options.category_denylist_size() > 0) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "`category_allowlist` and `category_denylist` are mutually "
+        "exclusive options.",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
+    const TensorMetadata& tensor_metadata) {
+  if (tensor_metadata.content() == nullptr ||
+      tensor_metadata.content()->content_properties() == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Expected BoundingBoxProperties for tensor %s, found none.",
+            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
+        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
+  }
+
+  ContentProperties type = tensor_metadata.content()->content_properties_type();
+  if (type != ContentProperties_BoundingBoxProperties) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Expected BoundingBoxProperties for tensor %s, found %s.",
+            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
+            EnumNameContentProperties(type)),
+        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
+  }
+
+  const BoundingBoxProperties* properties =
+      tensor_metadata.content()->content_properties_as_BoundingBoxProperties();
+
+  // Mobile SSD only supports "BOUNDARIES" bounding box type.
+  if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
+            tflite::EnumNameBoundingBoxType(properties->type())),
+        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
+  }
+
+  // Mobile SSD only supports "RATIO" coordinates type.
+  if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Mobile SSD only supports CoordinateType RATIO, found %s",
+            tflite::EnumNameCoordinateType(properties->coordinate_type())),
+        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
+  }
+
+  // Index is optional, but must contain 4 values if present.
+  if (properties->index() != nullptr && properties->index()->size() != 4) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Expected BoundingBoxProperties index to contain 4 values, found "
+            "%d",
+            properties->index()->size()),
+        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
+  }
+
+  return properties;
+}
+
+absl::StatusOr<LabelItems> GetLabelItemsIfAny(
+    const ModelMetadataExtractor& metadata_extractor,
+    const TensorMetadata& tensor_metadata,
+    tflite::AssociatedFileType associated_file_type, absl::string_view locale) {
+  const std::string labels_filename =
+      ModelMetadataExtractor::FindFirstAssociatedFileName(tensor_metadata,
+                                                          associated_file_type);
+  if (labels_filename.empty()) {
+    LabelItems empty_label_items;
+    return empty_label_items;
+  }
+  ASSIGN_OR_RETURN(absl::string_view labels_file,
+                   metadata_extractor.GetAssociatedFile(labels_filename));
+  const std::string display_names_filename =
+      ModelMetadataExtractor::FindFirstAssociatedFileName(
+          tensor_metadata, associated_file_type, locale);
+  absl::string_view display_names_file;
+  if (!display_names_filename.empty()) {
+    ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
+                                             display_names_filename));
+  }
+  return mediapipe::BuildLabelMapFromFiles(labels_file, display_names_file);
+}
+
+absl::StatusOr<float> GetScoreThreshold(
+    const ModelMetadataExtractor& metadata_extractor,
+    const TensorMetadata& tensor_metadata) {
+  ASSIGN_OR_RETURN(
+      const ProcessUnit* score_thresholding_process_unit,
+      metadata_extractor.FindFirstProcessUnit(
+          tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
+  if (score_thresholding_process_unit == nullptr) {
+    return kDefaultScoreThreshold;
+  }
+  return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
+      ->global_score_threshold();
+}
+
+absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
+    const proto::DetectorOptions& config, const LabelItems& label_items) {
+  absl::flat_hash_set<int> category_indices;
+  // Exit early if no denylist/allowlist.
+  if (config.category_denylist_size() == 0 &&
+      config.category_allowlist_size() == 0) {
+    return category_indices;
+  }
+  if (label_items.empty()) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Using `category_allowlist` or `category_denylist` requires "
+        "labels to be present in the TFLite Model Metadata but none was found.",
+        MediaPipeTasksStatus::kMetadataMissingLabelsError);
+  }
+  const auto& category_list = config.category_allowlist_size() > 0
+                                  ? config.category_allowlist()
+                                  : config.category_denylist();
+  for (const auto& category_name : category_list) {
+    int index = -1;
+    for (int i = 0; i < label_items.size(); ++i) {
+      if (label_items.at(i).name() == category_name) {
+        index = i;
+        break;
+      }
+    }
+    // Ignores duplicate or unknown categories.
+    if (index < 0) {
+      continue;
+    }
+    category_indices.insert(index);
+  }
+  return category_indices;
+}
+
+absl::StatusOr<std::optional<ScoreCalibrationCalculatorOptions>>
+GetScoreCalibrationOptionsIfAny(
+    const ModelMetadataExtractor& metadata_extractor,
+    const TensorMetadata& tensor_metadata) {
+  // Get ScoreCalibrationOptions, if any.
+  ASSIGN_OR_RETURN(
+      const ProcessUnit* score_calibration_process_unit,
+      metadata_extractor.FindFirstProcessUnit(
+          tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions));
+  if (score_calibration_process_unit == nullptr) {
+    return std::nullopt;
+  }
+  auto* score_calibration_options =
+      score_calibration_process_unit->options_as_ScoreCalibrationOptions();
+  // Get corresponding AssociatedFile.
+  auto score_calibration_filename =
+      metadata_extractor.FindFirstAssociatedFileName(
+          tensor_metadata,
+          tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
+  if (score_calibration_filename.empty()) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kNotFound,
+        "Found ScoreCalibrationOptions but missing required associated "
+        "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
+        MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
+  }
+  ASSIGN_OR_RETURN(
+      absl::string_view score_calibration_file,
+      metadata_extractor.GetAssociatedFile(score_calibration_filename));
+  ScoreCalibrationCalculatorOptions score_calibration_calculator_options;
+  MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
+      score_calibration_options->score_transformation(),
+      score_calibration_options->default_score(), score_calibration_file,
+      &score_calibration_calculator_options));
+  return score_calibration_calculator_options;
+}
+
+absl::StatusOr<std::vector<int>> GetOutputTensorIndices(
+    const Vector<Offset<TensorMetadata>>* tensor_metadatas) {
+  std::vector<int> output_indices;
+  if (tensor_metadatas->size() == 4) {
+    output_indices = {
+        core::FindTensorIndexByMetadataName(tensor_metadatas,
+                                            kLocationTensorName),
+        core::FindTensorIndexByMetadataName(tensor_metadatas,
+                                            kCategoryTensorName),
+        core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName),
+        core::FindTensorIndexByMetadataName(tensor_metadatas,
+                                            kNumberOfDetectionsTensorName)};
+    // locations, categories, scores, and number of detections
+    for (int i = 0; i < 4; i++) {
+      int output_index = output_indices[i];
+      // If tensor name is not found, set the default output indices.
+      if (output_index == -1) {
+        ABSL_LOG(WARNING) << absl::StrFormat(
+            "You don't seem to be matching tensor names in the metadata list. "
+            "The tensor name \"%s\" at index %d in the model metadata doesn't "
+            "match "
+            "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].",
+            tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
+            kCategoryTensorName, kScoreTensorName,
+            kNumberOfDetectionsTensorName);
+        output_indices = {
+            kInModelNmsDefaultLocationsIndex, kInModelNmsDefaultCategoriesIndex,
+            kInModelNmsDefaultScoresIndex, kInModelNmsDefaultNumResultsIndex};
+        return output_indices;
+      }
+    }
+  } else if (tensor_metadatas->size() == 2) {
+    output_indices = {core::FindTensorIndexByMetadataName(tensor_metadatas,
+                                                          kLocationTensorName),
+                      core::FindTensorIndexByMetadataName(tensor_metadatas,
+                                                          kScoreTensorName)};
+    // location, score
+    for (int i = 0; i < 2; i++) {
+      int output_index = output_indices[i];
+      // If tensor name is not found, set the default output indices.
+      if (output_index == -1) {
+        ABSL_LOG(WARNING) << absl::StrFormat(
+            "You don't seem to be matching tensor names in the metadata list. "
+            "The tensor name \"%s\" at index %d in the model metadata doesn't "
+            "match "
+            "the available output names: [\"%s\", \"%s\"].",
+            tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
+            kScoreTensorName);
+        output_indices = {kOutModelNmsDefaultLocationsIndex,
+                          kOutModelNmsDefaultScoresIndex};
+        return output_indices;
+      }
+    }
+  } else {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Expected a model with 2 or 4 output tensors metadata, found %d.",
+            tensor_metadatas->size()),
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  return output_indices;
+}
+
+// Builds PostProcessingSpecs from DetectorOptions and model metadata for
+// configuring the post-processing calculators.
+absl::StatusOr<PostProcessingSpecs> BuildPostProcessingSpecs(
+    const proto::DetectorOptions& options, bool in_model_nms,
+    const ModelMetadataExtractor* metadata_extractor) {
+  const auto* output_tensors_metadata =
+      metadata_extractor->GetOutputTensorMetadata();
+  PostProcessingSpecs specs;
+  specs.max_results = options.max_results();
+  ASSIGN_OR_RETURN(specs.output_tensor_indices,
+                   GetOutputTensorIndices(output_tensors_metadata));
+  // Extracts mandatory BoundingBoxProperties and performs sanity checks on the
+  // fly.
+  ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
+                   GetBoundingBoxProperties(*output_tensors_metadata->Get(
+                       specs.output_tensor_indices[0])));
+  if (bounding_box_properties->index() == nullptr) {
+    specs.bounding_box_corners_order = {0, 1, 2, 3};
+  } else {
+    auto bounding_box_index = bounding_box_properties->index();
+    specs.bounding_box_corners_order = {
+        bounding_box_index->Get(0),
+        bounding_box_index->Get(1),
+        bounding_box_index->Get(2),
+        bounding_box_index->Get(3),
+    };
+  }
+  // Builds label map (if available) from metadata.
+  // For models with in-model-nms, the label map is stored in the Category
+  // tensor which uses TENSOR_VALUE_LABELS. For models with out-of-model-nms,
+  // the label map is stored in the Score tensor which uses TENSOR_AXIS_LABELS.
+  ASSIGN_OR_RETURN(
+      specs.label_items,
+      GetLabelItemsIfAny(
+          *metadata_extractor,
+          *output_tensors_metadata->Get(specs.output_tensor_indices[1]),
+          in_model_nms ? tflite::AssociatedFileType_TENSOR_VALUE_LABELS
+                       : tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
+          options.display_names_locale()));
+  // Obtains allow/deny categories.
+  specs.is_allowlist = !options.category_allowlist().empty();
+  ASSIGN_OR_RETURN(
+      specs.allow_or_deny_categories,
+      GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items));
+
+  // Sets score threshold.
+  if (options.has_score_threshold()) {
+    specs.score_threshold = options.score_threshold();
+  } else {
+    ASSIGN_OR_RETURN(
+        specs.score_threshold,
+        GetScoreThreshold(
+            *metadata_extractor,
+            *output_tensors_metadata->Get(
+                specs.output_tensor_indices
+                    [in_model_nms ? kInModelNmsDefaultScoresIndex
+                                  : kOutModelNmsDefaultScoresIndex])));
+  }
+  if (in_model_nms) {
+    // Builds score calibration options (if available) from metadata.
+    ASSIGN_OR_RETURN(
+        specs.score_calibration_options,
+        GetScoreCalibrationOptionsIfAny(
+            *metadata_extractor,
+            *output_tensors_metadata->Get(
+                specs.output_tensor_indices[kInModelNmsDefaultScoresIndex])));
+  }
+  return specs;
+}
+
+// Builds PostProcessingSpecs from DetectorOptions and model metadata for
+// configuring the post-processing calculators for models with
+// non-maximum-suppression.
+absl::StatusOr<PostProcessingSpecs> BuildInModelNmsPostProcessingSpecs(
+    const proto::DetectorOptions& options,
+    const ModelMetadataExtractor* metadata_extractor) {
+  // Checks output tensor metadata is present and consistent with model.
+  auto* output_tensors_metadata =
+      metadata_extractor->GetOutputTensorMetadata();
+  if (output_tensors_metadata == nullptr ||
+      output_tensors_metadata->size() != 4) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Mismatch between number of output tensors (4) and "
+                        "output tensors metadata (%d).",
+                        output_tensors_metadata == nullptr
+                            ? 0
+                            : output_tensors_metadata->size()),
+        MediaPipeTasksStatus::kMetadataInconsistencyError);
+  }
+  return BuildPostProcessingSpecs(options, /*in_model_nms=*/true,
+                                  metadata_extractor);
+}
+
+// Fills in the TensorsToDetectionsCalculatorOptions based on
+// PostProcessingSpecs.
+void ConfigureInModelNmsTensorsToDetectionsCalculator(
+    const PostProcessingSpecs& specs,
+    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
+  options->set_num_classes(specs.label_items.size());
+  options->set_num_coords(4);
+  options->set_min_score_thresh(specs.score_threshold);
+  if (specs.max_results != -1) {
+    options->set_max_results(specs.max_results);
+  }
+  if (specs.is_allowlist) {
+    options->mutable_allow_classes()->Assign(
+        specs.allow_or_deny_categories.begin(),
+        specs.allow_or_deny_categories.end());
+  } else {
+    options->mutable_ignore_classes()->Assign(
+        specs.allow_or_deny_categories.begin(),
+        specs.allow_or_deny_categories.end());
+  }
+
+  const auto& output_indices = specs.output_tensor_indices;
+  // Assigns indices to each of the model output tensors.
+  auto* tensor_mapping = options->mutable_tensor_mapping();
+  tensor_mapping->set_detections_tensor_index(output_indices[0]);
+  tensor_mapping->set_classes_tensor_index(output_indices[1]);
+  tensor_mapping->set_scores_tensor_index(output_indices[2]);
+  tensor_mapping->set_num_detections_tensor_index(output_indices[3]);
+
+  // Assigns the bounding box corner order.
+  auto box_boundaries_indices = options->mutable_box_boundaries_indices();
+  box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]);
+  box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]);
+  box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]);
+  box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]);
+}
+
+// Builds PostProcessingSpecs from DetectorOptions and model metadata for
+// configuring the post-processing calculators for models without
+// non-maximum-suppression.
+absl::StatusOr<PostProcessingSpecs> BuildOutModelNmsPostProcessingSpecs(
+    const proto::DetectorOptions& options,
+    const ModelMetadataExtractor* metadata_extractor) {
+  // Checks output tensor metadata is present and consistent with model.
+  auto* output_tensors_metadata =
+      metadata_extractor->GetOutputTensorMetadata();
+  if (output_tensors_metadata == nullptr ||
+      output_tensors_metadata->size() != 2) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Mismatch between number of output tensors (2) and "
+                        "output tensors metadata (%d).",
+                        output_tensors_metadata == nullptr
+                            ? 0
+                            : output_tensors_metadata->size()),
+        MediaPipeTasksStatus::kMetadataInconsistencyError);
+  }
+  return BuildPostProcessingSpecs(options, /*in_model_nms=*/false,
+                                  metadata_extractor);
+}
+
+// Configures the TensorsToDetectionsCalculator for models without
+// non-maximum-suppression in the tflite model. The required config parameters
+// are extracted from the ObjectDetectorMetadata
+// (metadata/object_detector_metadata_schema.fbs).
+absl::Status ConfigureOutModelNmsTensorsToDetectionsCalculator(
+    const ModelMetadataExtractor* metadata_extractor,
+    const PostProcessingSpecs& specs,
+    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
+  bool found_detector_metadata = false;
+  if (metadata_extractor->GetCustomMetadataList() != nullptr &&
+      metadata_extractor->GetCustomMetadataList()->size() > 0) {
+    for (const auto* custom_metadata :
+         *metadata_extractor->GetCustomMetadataList()) {
+      if (custom_metadata->name()->str() == kDetectorMetadataName) {
+        found_detector_metadata = true;
+        const auto* tensors_decoding_options =
+            GetObjectDetectorOptions(custom_metadata->data()->data())
+                ->tensors_decoding_options();
+        // Here we don't set the max results for TensorsToDetectionsCalculator.
+        // For models without nms, the results are filtered by max_results in
+        // NonMaxSuppressionCalculator.
+        options->set_num_classes(tensors_decoding_options->num_classes());
+        options->set_num_boxes(tensors_decoding_options->num_boxes());
+        options->set_num_coords(tensors_decoding_options->num_coords());
+        options->set_keypoint_coord_offset(
+            tensors_decoding_options->keypoint_coord_offset());
+        options->set_num_keypoints(tensors_decoding_options->num_keypoints());
+        options->set_num_values_per_keypoint(
+            tensors_decoding_options->num_values_per_keypoint());
+        options->set_x_scale(tensors_decoding_options->x_scale());
+        options->set_y_scale(tensors_decoding_options->y_scale());
+        options->set_w_scale(tensors_decoding_options->w_scale());
+        options->set_h_scale(tensors_decoding_options->h_scale());
+        options->set_apply_exponential_on_box_size(
+            tensors_decoding_options->apply_exponential_on_box_size());
+        options->set_sigmoid_score(tensors_decoding_options->sigmoid_score());
+        break;
+      }
+    }
+  }
+  if (!found_detector_metadata) {
+    return absl::InvalidArgumentError(
+        "TensorsDecodingOptions is not found in the object detector "
+        "metadata.");
+  }
+  // Options not configured through metadata.
+  options->set_box_format(
+      mediapipe::TensorsToDetectionsCalculatorOptions::YXHW);
+  options->set_min_score_thresh(specs.score_threshold);
+  if (specs.is_allowlist) {
+    options->mutable_allow_classes()->Assign(
+        specs.allow_or_deny_categories.begin(),
+        specs.allow_or_deny_categories.end());
+  } else {
+    options->mutable_ignore_classes()->Assign(
+        specs.allow_or_deny_categories.begin(),
+        specs.allow_or_deny_categories.end());
+  }
+
+  const auto& output_indices = specs.output_tensor_indices;
+  // Assigns indices to each of the model output tensors.
+  auto* tensor_mapping = options->mutable_tensor_mapping();
+  tensor_mapping->set_detections_tensor_index(output_indices[0]);
+  tensor_mapping->set_scores_tensor_index(output_indices[1]);
+  return absl::OkStatus();
+}
+
+// Configures the SsdAnchorsCalculator for models without
+// non-maximum-suppression in the tflite model. The required config parameters
+// are extracted from the ObjectDetectorMetadata
+// (metadata/object_detector_metadata_schema.fbs).
+absl::Status ConfigureSsdAnchorsCalculator(
+    const ModelMetadataExtractor* metadata_extractor,
+    mediapipe::SsdAnchorsCalculatorOptions* options) {
+  bool found_detector_metadata = false;
+  if (metadata_extractor->GetCustomMetadataList() != nullptr &&
+      metadata_extractor->GetCustomMetadataList()->size() > 0) {
+    for (const auto* custom_metadata :
+         *metadata_extractor->GetCustomMetadataList()) {
+      if (custom_metadata->name()->str() == kDetectorMetadataName) {
+        found_detector_metadata = true;
+        const auto* ssd_anchors_options =
+            GetObjectDetectorOptions(custom_metadata->data()->data())
+                ->ssd_anchors_options();
+        for (const auto* ssd_anchor :
+             *ssd_anchors_options->fixed_anchors_schema()->anchors()) {
+          auto* fixed_anchor = options->add_fixed_anchors();
+          fixed_anchor->set_y_center(ssd_anchor->y_center());
+          fixed_anchor->set_x_center(ssd_anchor->x_center());
+          fixed_anchor->set_h(ssd_anchor->height());
+          fixed_anchor->set_w(ssd_anchor->width());
+        }
+        break;
+      }
+    }
+  }
+  if (!found_detector_metadata) {
+    return absl::InvalidArgumentError(
+        "SsdAnchorsOptions is not found in the object detector "
+        "metadata.");
+  }
+  return absl::OkStatus();
+}
+
+// Sets the default IoU-based non-maximum-suppression configs, and sets the
+// min_suppression_threshold and max_results for detection models without
+// non-maximum-suppression.
+void ConfigureNonMaxSuppressionCalculator(
+    const proto::DetectorOptions& detector_options,
+    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
+  options->set_min_suppression_threshold(
+      detector_options.min_suppression_threshold());
+  options->set_overlap_type(
+      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
+  options->set_algorithm(
+      mediapipe::NonMaxSuppressionCalculatorOptions::DEFAULT);
+  options->set_max_num_detections(detector_options.max_results());
+}
+
+// Sets the labels from the PostProcessingSpecs.
+void ConfigureDetectionLabelIdToTextCalculator(
+    PostProcessingSpecs& specs,
+    mediapipe::DetectionLabelIdToTextCalculatorOptions* options) {
+  *options->mutable_label_items() = std::move(specs.label_items);
+}
+
+// Splits the vector of 4 output tensors from model inference, calibrates the
+// score tensors according to the metadata, if any, and then concatenates the
+// tensors back into a vector of 4 tensors.
+absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
+    Source<std::vector<Tensor>> model_output_tensors,
+    const proto::DetectionPostprocessingGraphOptions& options, Graph& graph) {
+  // Split tensors.
+  auto* split_tensor_vector_node =
+      &graph.AddNode("SplitTensorVectorCalculator");
+  auto& split_tensor_vector_options =
+      split_tensor_vector_node
+          ->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
+  for (int i = 0; i < 4; ++i) {
+    auto* range = split_tensor_vector_options.add_ranges();
+    range->set_begin(i);
+    range->set_end(i + 1);
+  }
+  model_output_tensors >> split_tensor_vector_node->In(0);
+
+  // Add score calibration calculator.
+  auto* score_calibration_node = &graph.AddNode("ScoreCalibrationCalculator");
+  score_calibration_node
+      ->GetOptions<mediapipe::ScoreCalibrationCalculatorOptions>()
+      .CopyFrom(options.score_calibration_options());
+  const auto& tensor_mapping =
+      options.tensors_to_detections_options().tensor_mapping();
+  split_tensor_vector_node->Out(tensor_mapping.classes_tensor_index()) >>
+      score_calibration_node->In(kIndicesTag);
+  split_tensor_vector_node->Out(tensor_mapping.scores_tensor_index()) >>
+      score_calibration_node->In(kScoresTag);
+
+  // Re-concatenate tensors.
+  auto* concatenate_tensor_vector_node =
+      &graph.AddNode("ConcatenateTensorVectorCalculator");
+  for (int i = 0; i < 4; ++i) {
+    if (i == tensor_mapping.scores_tensor_index()) {
+      score_calibration_node->Out(kCalibratedScoresTag) >>
+          concatenate_tensor_vector_node->In(i);
+    } else {
+      split_tensor_vector_node->Out(i) >>
+          concatenate_tensor_vector_node->In(i);
+    }
+  }
+  model_output_tensors =
+      concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
+  return model_output_tensors;
+}
+
+}  // namespace
+
+absl::Status ConfigureDetectionPostprocessingGraph(
+    const tasks::core::ModelResources& model_resources,
+    const proto::DetectorOptions& detector_options,
+    proto::DetectionPostprocessingGraphOptions& options) {
+  MP_RETURN_IF_ERROR(SanityCheckOptions(detector_options));
+  const auto& model = *model_resources.GetTfLiteModel();
+  bool in_model_nms = false;
+  if (model.subgraphs()->size() != 1) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Expected a model with a single subgraph, found %d.",
+                        model.subgraphs()->size()),
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  if (model.subgraphs()->Get(0)->outputs()->size() == 2) {
+    in_model_nms = false;
+  } else if (model.subgraphs()->Get(0)->outputs()->size() == 4) {
+    in_model_nms = true;
+  } else {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "Expected a model with 2 or 4 output tensors, found %d.",
+            model.subgraphs()->Get(0)->outputs()->size()),
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+
+  const ModelMetadataExtractor* metadata_extractor =
+      model_resources.GetMetadataExtractor();
+  if (in_model_nms) {
+    ASSIGN_OR_RETURN(auto post_processing_specs,
+                     BuildInModelNmsPostProcessingSpecs(detector_options,
+                                                        metadata_extractor));
+    ConfigureInModelNmsTensorsToDetectionsCalculator(
+        post_processing_specs,
+        options.mutable_tensors_to_detections_options());
+    ConfigureDetectionLabelIdToTextCalculator(
+        post_processing_specs,
+        options.mutable_detection_label_ids_to_text_options());
+    if (post_processing_specs.score_calibration_options.has_value()) {
+      *options.mutable_score_calibration_options() =
+          std::move(*post_processing_specs.score_calibration_options);
+    }
+  } else {
+    ASSIGN_OR_RETURN(auto post_processing_specs,
+                     BuildOutModelNmsPostProcessingSpecs(detector_options,
+                                                         metadata_extractor));
+    MP_RETURN_IF_ERROR(ConfigureOutModelNmsTensorsToDetectionsCalculator(
+        metadata_extractor, post_processing_specs,
+        options.mutable_tensors_to_detections_options()));
+    MP_RETURN_IF_ERROR(ConfigureSsdAnchorsCalculator(
+        metadata_extractor, options.mutable_ssd_anchors_options()));
+    ConfigureNonMaxSuppressionCalculator(
+        detector_options, options.mutable_non_max_suppression_options());
+    ConfigureDetectionLabelIdToTextCalculator(
+        post_processing_specs,
+        options.mutable_detection_label_ids_to_text_options());
+  }
+
+  return absl::OkStatus();
+}
+
+// A DetectionPostprocessingGraph converts raw tensors into
+// std::vector<mediapipe::Detection>.
+//
+// Inputs:
+//  TENSORS - std::vector<Tensor>
+//    The output tensors of an InferenceCalculator. The tensors vector can be
+//    of size 4 or size 2. A tensors vector of size 4 expects the tensors from
+//    models with the DETECTION_POSTPROCESS op [1] in the tflite graph; a
+//    tensors vector of size 2 expects the tensors from models without that op.
+// [1]:
+// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
+// Outputs:
+//  DETECTIONS - std::vector<Detection>
+//    The postprocessed detection results.
+//
+// The recommended way of using this graph is through the GraphBuilder API
+// using the 'ConfigureDetectionPostprocessingGraph()' function. See header
+// file for more details.
+class DetectionPostprocessingGraph : public mediapipe::Subgraph {
+ public:
+  absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
+      mediapipe::SubgraphContext* sc) override {
+    Graph graph;
+    ASSIGN_OR_RETURN(
+        auto output_streams,
+        BuildDetectionPostprocessing(
+            *sc->MutableOptions<proto::DetectionPostprocessingGraphOptions>(),
+            graph.In(kTensorsTag).Cast<std::vector<Tensor>>(), graph));
+    output_streams.detections >>
+        graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
+    return graph.GetConfig();
+  }
+
+ private:
+  // Adds an on-device detection postprocessing graph into the provided
+  // builder::Graph instance. The detection postprocessing graph takes
+  // tensors (std::vector<mediapipe::Tensor>) as input and returns one output
+  // stream:
+  //  - Detection results as a std::vector<Detection>.
+  //
+  // graph_options: the on-device DetectionPostprocessingGraphOptions.
+  // tensors_in: (Source<std::vector<Tensor>>) tensors to postprocess.
+  // graph: the mediapipe builder::Graph instance to be updated.
+  absl::StatusOr<DetectionPostprocessingOutputStreams>
+  BuildDetectionPostprocessing(
+      proto::DetectionPostprocessingGraphOptions& graph_options,
+      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
+    std::optional<Source<std::vector<Detection>>> detections;
+    if (!graph_options.has_non_max_suppression_options()) {
+      // Calculators to perform score calibration, if specified in the options.
+      if (graph_options.has_score_calibration_options()) {
+        ASSIGN_OR_RETURN(tensors_in,
+                         CalibrateScores(tensors_in, graph_options, graph));
+      }
+      // Calculator to convert output tensors to a detection proto vector.
+      auto& tensors_to_detections =
+          graph.AddNode("TensorsToDetectionsCalculator");
+      tensors_to_detections
+          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
+          .Swap(graph_options.mutable_tensors_to_detections_options());
+      tensors_in >> tensors_to_detections.In(kTensorsTag);
+      detections = tensors_to_detections.Out(kDetectionsTag)
+                       .Cast<std::vector<Detection>>();
+    } else {
+      // Generates a single side packet containing a vector of SSD anchors.
+      auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
+      ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>().Swap(
+          graph_options.mutable_ssd_anchors_options());
+      auto anchors = ssd_anchor.SideOut("").Cast<std::vector<Anchor>>();
+      // Convert raw output tensors to detections.
+      auto& tensors_to_detections =
+          graph.AddNode("TensorsToDetectionsCalculator");
+      tensors_to_detections
+          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
+          .Swap(graph_options.mutable_tensors_to_detections_options());
+      anchors >> tensors_to_detections.SideIn(kAnchorsTag);
+      tensors_in >> tensors_to_detections.In(kTensorsTag);
+      detections = tensors_to_detections.Out(kDetectionsTag)
+                       .Cast<std::vector<Detection>>();
+      // Non-maximum suppression removes redundant object detections.
+      auto& non_maximum_suppression =
+          graph.AddNode("NonMaxSuppressionCalculator");
+      non_maximum_suppression
+          .GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>()
+          .Swap(graph_options.mutable_non_max_suppression_options());
+      *detections >> non_maximum_suppression.In("");
+      detections =
+          non_maximum_suppression.Out("").Cast<std::vector<Detection>>();
+    }
+
+    // Calculator to assign detection labels.
+    auto& detection_label_id_to_text =
+        graph.AddNode("DetectionLabelIdToTextCalculator");
+    detection_label_id_to_text
+        .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>()
+        .Swap(graph_options.mutable_detection_label_ids_to_text_options());
+    *detections >> detection_label_id_to_text.In("");
+    return {
+        {detection_label_id_to_text.Out("").Cast<std::vector<Detection>>()}};
+  }
+};
+
+// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
+// clang-format off
+REGISTER_MEDIAPIPE_GRAPH(
+  ::mediapipe::tasks::components::processors::DetectionPostprocessingGraph); // NOLINT
+// clang-format on
+
+}  // namespace processors
+}  // namespace components
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h
new file mode 100644
index 000000000..1696b844f
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h
@@ -0,0 +1,62 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
+
+namespace mediapipe {
+namespace tasks {
+namespace components {
+namespace processors {
+
+// Configures a DetectionPostprocessingGraph using the provided model
+// resources and DetectorOptions.
+//
+// Example usage:
+//
+//   auto& postprocessing =
+//       graph.AddNode("mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
+//   MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
+//       model_resources,
+//       detector_options,
+//       &postprocessing.GetOptions<proto::DetectionPostprocessingGraphOptions>()));
+//
+// The resulting DetectionPostprocessingGraph has the following I/O:
+// Inputs:
+//  TENSORS - std::vector<Tensor>
+//    The output tensors of an InferenceCalculator. The tensors vector can be
+//    of size 4 or size 2. A tensors vector of size 4 expects the tensors from
+//    models with the DETECTION_POSTPROCESS op [1] in the tflite graph; a
+//    tensors vector of size 2 expects the tensors from models without that op.
+// [1]:
+// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
+// Outputs:
+//  DETECTIONS - std::vector<Detection>
+//    The postprocessed detection results.
+absl::Status ConfigureDetectionPostprocessingGraph(
+    const tasks::core::ModelResources& model_resources,
+    const proto::DetectorOptions& detector_options,
+    proto::DetectionPostprocessingGraphOptions& options);
+
+}  // namespace processors
+}  // namespace components
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc
new file mode 100644
index 000000000..36aead0c1
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc
@@ -0,0 +1,570 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
+
+#include <memory>
+
+#include "absl/flags/flag.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/packet.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/detection.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/graph_runner.h"
+#include "mediapipe/framework/output_stream_poller.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "tensorflow/lite/test_util.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace components {
+namespace processors {
+namespace {
+
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::core::ModelResources;
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::Pointwise;
+using ::testing::proto::Approximately;
+using ::testing::proto::Partially;
+
+constexpr absl::string_view kTestDataDirectory =
+    "/mediapipe/tasks/testdata/vision";
+constexpr absl::string_view kMobileSsdWithMetadata =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; +constexpr absl::string_view kMobileSsdWithDummyScoreCalibration = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration." + "tflite"; +constexpr absl::string_view kEfficientDetWithoutNms = + "efficientdet_lite0_fp16_no_nms.tflite"; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; + +constexpr absl::string_view kTensorsTag = "TENSORS"; +constexpr absl::string_view kDetectionsTag = "DETECTIONS"; +constexpr absl::string_view kTensorsName = "tensors"; +constexpr absl::string_view kDetectionsName = "detections"; + +// Helper function to get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +class ConfigureTest : public tflite::testing::Test {}; + +TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.set_max_results(0); + + proto::DetectionPostprocessingGraphOptions options_out; + auto status = ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); +} + +TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.add_category_allowlist("foo"); + options_in.add_category_denylist("bar"); + + proto::DetectionPostprocessingGraphOptions options_out; + auto status = ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); +} + +TEST_F(ConfigureTest, SucceedsWithMaxResults) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.set_max_results(3); + + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + + EXPECT_THAT( + options_out, + Approximately(Partially(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_coords: 4 + max_results: 3 + tensor_mapping { + detections_tensor_index: 0 + classes_tensor_index: 1 + scores_tensor_index: 2 + num_detections_tensor_index: 3 + } + box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } + } + )pb")))); +} + +TEST_F(ConfigureTest, SucceedsWithMaxResultsWithoutModelNms) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, CreateModelResourcesForModel( + kEfficientDetWithoutNms)); + proto::DetectorOptions options_in; + options_in.set_max_results(3); + + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + EXPECT_THAT(options_out, Approximately(Partially(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_boxes: 19206 + num_coords: 4 + x_scale: 1 + y_scale: 1 + 
+                                      w_scale: 1
+                                      h_scale: 1
+                                      keypoint_coord_offset: 0
+                                      num_keypoints: 0
+                                      num_values_per_keypoint: 2
+                                      apply_exponential_on_box_size: true
+                                      sigmoid_score: false
+                                      tensor_mapping {
+                                        detections_tensor_index: 1
+                                        scores_tensor_index: 0
+                                      }
+                                      box_format: YXHW
+                                    }
+                                    non_max_suppression_options {
+                                      max_num_detections: 3
+                                      min_suppression_threshold: 0
+                                      overlap_type: INTERSECTION_OVER_UNION
+                                      algorithm: DEFAULT
+                                    }
+                               )pb"))));
+  EXPECT_THAT(
+      options_out.detection_label_ids_to_text_options().label_items_size(),
+      90);
+}
+
+TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
+                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
+  proto::DetectorOptions options_in;
+  options_in.set_score_threshold(0.5);
+
+  proto::DetectionPostprocessingGraphOptions options_out;
+  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
+                                                     options_in, options_out));
+  EXPECT_THAT(
+      options_out,
+      Approximately(Partially(EqualsProto(
+          R"pb(tensors_to_detections_options {
+                 min_score_thresh: 0.5
+                 num_classes: 90
+                 num_coords: 4
+                 tensor_mapping {
+                   detections_tensor_index: 0
+                   classes_tensor_index: 1
+                   scores_tensor_index: 2
+                   num_detections_tensor_index: 3
+                 }
+                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
+               }
+          )pb"))));
+  EXPECT_THAT(
+      options_out.detection_label_ids_to_text_options().label_items_size(),
+      90);
+}
+
+TEST_F(ConfigureTest, SucceedsWithAllowlist) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
+                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
+  proto::DetectorOptions options_in;
+  options_in.add_category_allowlist("bicycle");
+  proto::DetectionPostprocessingGraphOptions options_out;
+  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
+                                                     options_in, options_out));
+  // Clear label ids to text and compare the rest of the options.
+  options_out.clear_detection_label_ids_to_text_options();
+  EXPECT_THAT(
+      options_out,
+      Approximately(EqualsProto(
+          R"pb(tensors_to_detections_options {
+                 min_score_thresh: -3.4028235e+38
+                 num_classes: 90
+                 num_coords: 4
+                 allow_classes: 1
+                 tensor_mapping {
+                   detections_tensor_index: 0
+                   classes_tensor_index: 1
+                   scores_tensor_index: 2
+                   num_detections_tensor_index: 3
+                 }
+                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
+               }
+          )pb")));
+}
+
+TEST_F(ConfigureTest, SucceedsWithDenylist) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
+                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
+  proto::DetectorOptions options_in;
+  options_in.add_category_denylist("person");
+  proto::DetectionPostprocessingGraphOptions options_out;
+  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
+                                                     options_in, options_out));
+  // Clear label ids to text and compare the rest of the options.
+  options_out.clear_detection_label_ids_to_text_options();
+  EXPECT_THAT(
+      options_out,
+      Approximately(EqualsProto(
+          R"pb(tensors_to_detections_options {
+                 min_score_thresh: -3.4028235e+38
+                 num_classes: 90
+                 num_coords: 4
+                 ignore_classes: 0
+                 tensor_mapping {
+                   detections_tensor_index: 0
+                   classes_tensor_index: 1
+                   scores_tensor_index: 2
+                   num_detections_tensor_index: 3
+                 }
+                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
+               }
+          )pb")));
+}
+
+TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(kMobileSsdWithDummyScoreCalibration));
+  proto::DetectorOptions options_in;
+  proto::DetectionPostprocessingGraphOptions options_out;
+  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
+                                                     options_in, options_out));
+  // Clear label ids to text.
+  options_out.clear_detection_label_ids_to_text_options();
+  // Check sigmoids size and first element.
+  ASSERT_EQ(options_out.score_calibration_options().sigmoids_size(), 89);
+  EXPECT_THAT(options_out.score_calibration_options().sigmoids()[0],
+              EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb"));
+  options_out.mutable_score_calibration_options()->clear_sigmoids();
+  // Compare the rest of the options.
+  EXPECT_THAT(
+      options_out,
+      Approximately(EqualsProto(
+          R"pb(tensors_to_detections_options {
+                 min_score_thresh: -3.4028235e+38
+                 num_classes: 90
+                 num_coords: 4
+                 tensor_mapping {
+                   detections_tensor_index: 0
+                   classes_tensor_index: 1
+                   scores_tensor_index: 2
+                   num_detections_tensor_index: 3
+                 }
+                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
+               }
+               score_calibration_options {
+                 score_transformation: IDENTITY
+                 default_score: 0.5
+               }
+          )pb")));
+}
+
+class PostprocessingTest : public tflite::testing::Test {
+ protected:
+  absl::StatusOr<OutputStreamPoller> BuildGraph(
+      absl::string_view model_name, const proto::DetectorOptions& options) {
+    ASSIGN_OR_RETURN(auto model_resources,
+                     CreateModelResourcesForModel(model_name));
+
+    Graph graph;
+    auto& postprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors."
+        "DetectionPostprocessingGraph");
+    MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
+        *model_resources, options,
+        postprocessing
+            .GetOptions<proto::DetectionPostprocessingGraphOptions>()));
+    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(
+        std::string(kTensorsName)) >>
+        postprocessing.In(kTensorsTag);
+    postprocessing.Out(kDetectionsTag).SetName(std::string(kDetectionsName)) >>
+        graph[Output<std::vector<Detection>>(kDetectionsTag)];
+    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
+    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
+                                      std::string(kDetectionsName)));
+    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+    return poller;
+  }
+
+  template <typename T>
+  void AddTensor(const std::vector<T>& tensor,
+                 const Tensor::ElementType& element_type,
+                 const Tensor::Shape& shape) {
+    tensors_->emplace_back(element_type, shape);
+    auto view = tensors_->back().GetCpuWriteView();
+    T* buffer = view.buffer<T>();
+    std::copy(tensor.begin(), tensor.end(), buffer);
+  }
+
+  absl::Status Run(int timestamp = 0) {
+    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+        std::string(kTensorsName),
+        Adopt(tensors_.release()).At(Timestamp(timestamp))));
+    // Reset tensors for future calls.
+    tensors_ = absl::make_unique<std::vector<Tensor>>();
+    return absl::OkStatus();
+  }
+
+  template <typename T>
+  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
+    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
+
+    Packet packet;
+    if (!poller.Next(&packet)) {
+      return absl::InternalError("Unable to get output packet");
+    }
+    auto result = packet.Get<T>();
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
+    return result;
+  }
+
+ private:
+  CalculatorGraph calculator_graph_;
+  std::unique_ptr<std::vector<Tensor>> tensors_ =
+      absl::make_unique<std::vector<Tensor>>();
+};
+
+TEST_F(PostprocessingTest, SucceedsWithMetadata) {
+  // Build graph.
+  proto::DetectorOptions options;
+  options.set_max_results(3);
+  MP_ASSERT_OK_AND_ASSIGN(auto poller,
+                          BuildGraph(kMobileSsdWithMetadata, options));
+
+  // Build input tensors.
+  constexpr int kBboxesNum = 5;
+  // Location tensor.
+  std::vector<float> location_tensor(kBboxesNum * 4, 0);
+  for (int i = 0; i < kBboxesNum; ++i) {
+    location_tensor[i * 4] = 0.1f;
+    location_tensor[i * 4 + 1] = 0.1f;
+    location_tensor[i * 4 + 2] = 0.4f;
+    location_tensor[i * 4 + 3] = 0.5f;
+  }
+  // Category tensor.
+  std::vector<float> category_tensor(kBboxesNum, 0);
+  for (int i = 0; i < kBboxesNum; ++i) {
+    category_tensor[i] = i + 1;
+  }
+
+  // Score tensor. Postprocessed tensor scores are in descending order.
+  std::vector<float> score_tensor(kBboxesNum, 0);
+  for (int i = 0; i < kBboxesNum; ++i) {
+    score_tensor[i] = static_cast<float>(kBboxesNum - i) / kBboxesNum;
+  }
+
+  // Number of detections tensor.
+  std::vector<float> num_detections_tensor(1, 0);
+  num_detections_tensor[0] = kBboxesNum;
+
+  // Send tensors and get results.
+  AddTensor(location_tensor, Tensor::ElementType::kFloat32,
+            {1, kBboxesNum, 4});
+  AddTensor(category_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum});
+  AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum});
+  AddTensor(num_detections_tensor, Tensor::ElementType::kFloat32, {1});
+  MP_ASSERT_OK(Run());
+
+  // Validate results.
+  EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
+              IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
+                                           R"pb(
+                                             label: "bicycle"
+                                             score: 1
+                                             location_data {
+                                               format: RELATIVE_BOUNDING_BOX
+                                               relative_bounding_box {
+                                                 xmin: 0.1
+                                                 ymin: 0.1
+                                                 width: 0.4
+                                                 height: 0.3
+                                               }
+                                             }
+                                           )pb")),
+                                       Approximately(EqualsProto(
+                                           R"pb(
+                                             label: "car"
+                                             score: 0.8
+                                             location_data {
+                                               format: RELATIVE_BOUNDING_BOX
+                                               relative_bounding_box {
+                                                 xmin: 0.1
+                                                 ymin: 0.1
+                                                 width: 0.4
+                                                 height: 0.3
+                                               }
+                                             }
+                                           )pb")),
+                                       Approximately(EqualsProto(
+                                           R"pb(
+                                             label: "motorcycle"
+                                             score: 0.6
+                                             location_data {
+                                               format: RELATIVE_BOUNDING_BOX
+                                               relative_bounding_box {
+                                                 xmin: 0.1
+                                                 ymin: 0.1
+                                                 width: 0.4
+                                                 height: 0.3
+                                               }
+                                             }
+                                           )pb")))));
+}
+
+TEST_F(PostprocessingTest, SucceedsWithOutModelNms) {
+  // Build graph.
+  proto::DetectorOptions options;
+  options.set_max_results(3);
+  MP_ASSERT_OK_AND_ASSIGN(auto poller,
+                          BuildGraph(kEfficientDetWithoutNms, options));
+
+  // Build input tensors.
+  constexpr int kBboxesNum = 19206;
+  constexpr int kBicycleBboxIdx = 1000;
+  constexpr int kCarBboxIdx = 2000;
+  constexpr int kMotoCycleBboxIdx = 4000;
+  // Location tensor.
+  std::vector<float> location_tensor(kBboxesNum * 4, 0);
+  for (int i = 0; i < kBboxesNum; ++i) {
+    location_tensor[i * 4] = 0.5f;
+    location_tensor[i * 4 + 1] = 0.5f;
+    location_tensor[i * 4 + 2] = 0.001f;
+    location_tensor[i * 4 + 3] = 0.001f;
+  }
+
+  // Detected three objects.
+  location_tensor[kBicycleBboxIdx * 4] = 0.7f;
+  location_tensor[kBicycleBboxIdx * 4 + 1] = 0.8f;
+  location_tensor[kBicycleBboxIdx * 4 + 2] = 0.2f;
+  location_tensor[kBicycleBboxIdx * 4 + 3] = 0.1f;
+
+  location_tensor[kCarBboxIdx * 4] = 0.1f;
+  location_tensor[kCarBboxIdx * 4 + 1] = 0.1f;
+  location_tensor[kCarBboxIdx * 4 + 2] = 0.1f;
+  location_tensor[kCarBboxIdx * 4 + 3] = 0.1f;
+
+  location_tensor[kMotoCycleBboxIdx * 4] = 0.2f;
+  location_tensor[kMotoCycleBboxIdx * 4 + 1] = 0.8f;
+  location_tensor[kMotoCycleBboxIdx * 4 + 2] = 0.1f;
+  location_tensor[kMotoCycleBboxIdx * 4 + 3] = 0.2f;
+
+  // Score tensor.
+  constexpr int kClassesNum = 90;
+  std::vector<float> score_tensor(kBboxesNum * kClassesNum,
+                                  1.f / kClassesNum);
+
+  // Detected three objects.
+  score_tensor[kBicycleBboxIdx * kClassesNum + 1] = 1.0f;    // bicycle.
+  score_tensor[kCarBboxIdx * kClassesNum + 2] = 0.9f;        // car.
+  score_tensor[kMotoCycleBboxIdx * kClassesNum + 3] = 0.8f;  // motorcycle.
+
+  // Send tensors and get results.
+  AddTensor(score_tensor, Tensor::ElementType::kFloat32,
+            {1, kBboxesNum, 90});
+  AddTensor(location_tensor, Tensor::ElementType::kFloat32,
+            {1, kBboxesNum, 4});
+  MP_ASSERT_OK(Run());
+
+  // Validate results.
+  EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
+              IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
+                                           R"pb(
+                                             label: "bicycle"
+                                             score: 1
+                                             location_data {
+                                               format: RELATIVE_BOUNDING_BOX
+                                               relative_bounding_box {
+                                                 xmin: 0.8137423
+                                                 ymin: 0.067235775
+                                                 width: 0.117221
+                                                 height: 0.064774655
+                                               }
+                                             }
+                                           )pb")),
+                                       Approximately(EqualsProto(
+                                           R"pb(
+                                             label: "car"
+                                             score: 0.9
+                                             location_data {
+                                               format: RELATIVE_BOUNDING_BOX
+                                               relative_bounding_box {
+                                                 xmin: 0.53849804
+                                                 ymin: 0.08949606
+                                                 width: 0.05861056
+                                                 height: 0.11722109
+                                               }
+                                             }
+                                           )pb")),
+                                       Approximately(EqualsProto(
+                                           R"pb(
+                                             label: "motorcycle"
+                                             score: 0.8
+                                             location_data {
+                                               format: RELATIVE_BOUNDING_BOX
+                                               relative_bounding_box {
+                                                 xmin: 0.13779688
+                                                 ymin: 0.26394117
+                                                 width: 0.16322193
+                                                 height: 0.07384467
+                                               }
+                                             }
+                                           )pb")))));
+}
+
+}  // namespace
+}  // namespace processors
+}  // namespace components
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/processors/embedder_options.cc b/mediapipe/tasks/cc/components/processors/embedder_options.cc
index fce7baa44..2f2121dd6 100644
--- a/mediapipe/tasks/cc/components/processors/embedder_options.cc
+++ b/mediapipe/tasks/cc/components/processors/embedder_options.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/embedder_options.h b/mediapipe/tasks/cc/components/processors/embedder_options.h
index b37171592..483e48670 100644
--- a/mediapipe/tasks/cc/components/processors/embedder_options.h
+++ b/mediapipe/tasks/cc/components/processors/embedder_options.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
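With the graph, its Configure function, and the tests in place, a typical integration mirrors the test's `BuildGraph` helper above: configure the postprocessing options from the model's metadata, then connect the inference output tensors to the subgraph. A minimal sketch follows; the `WirePostprocessing` helper, the `inference` node, and the option values are illustrative assumptions, not part of this change:

```cpp
#include "absl/status/status.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

// Sketch: wiring the new DetectionPostprocessingGraph behind an inference
// node. `graph`, `model_resources`, and `inference` (a node producing a
// "TENSORS" stream) are assumed to already exist in the caller.
absl::Status WirePostprocessing(
    mediapipe::api2::builder::Graph& graph,
    const mediapipe::tasks::core::ModelResources& model_resources,
    mediapipe::api2::builder::GenericNode& inference) {
  namespace processors = ::mediapipe::tasks::components::processors;

  // Task-level knobs; everything model-specific (tensor mapping, SSD
  // anchors, NMS defaults) is filled in from the TFLite metadata by the
  // Configure call, depending on whether the model has in-model NMS.
  processors::proto::DetectorOptions detector_options;
  detector_options.set_max_results(5);
  detector_options.set_score_threshold(0.3);

  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
  MP_RETURN_IF_ERROR(processors::ConfigureDetectionPostprocessingGraph(
      model_resources, detector_options,
      postprocessing.GetOptions<
          processors::proto::DetectionPostprocessingGraphOptions>()));

  // TENSORS in, DETECTIONS out, per the graph's documented I/O contract.
  inference.Out("TENSORS") >> postprocessing.In("TENSORS");
  postprocessing.Out("DETECTIONS") >> graph.Out("DETECTIONS");
  return absl::OkStatus();
}
```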
diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 7b023ba41..ec28d6294 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 889992463..bb96dfbd5 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 94a2a7f3f..0f0710405 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index 7a45bb148..1040701c4 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h index 455a9b316..2f7473fe4 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc index c69a51a65..2e182ee85 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 816ba47e3..a45c91633 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,6 +23,11 @@ mediapipe_proto_library( srcs = ["classifier_options.proto"], ) +mediapipe_proto_library( + name = "detector_options_proto", + srcs = ["detector_options.proto"], +) + mediapipe_proto_library( name = "classification_postprocessing_graph_options_proto", srcs = ["classification_postprocessing_graph_options.proto"], @@ -35,6 +40,20 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "detection_postprocessing_graph_options_proto", + srcs = ["detection_postprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto", + "//mediapipe/calculators/util:non_max_suppression_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", + ], +) + mediapipe_proto_library( name = "embedder_options_proto", srcs = ["embedder_options.proto"], @@ -74,3 +93,8 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", ], ) + +mediapipe_proto_library( + name = "transformer_params_proto", + srcs = ["transformer_params.proto"], +) diff --git a/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto index 84ba95222..e393f58ca 100644 --- a/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto index 12ece7249..345c7e1fc 100644 --- a/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto new file mode 100644 index 000000000..ec11df2b4 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto @@ -0,0 +1,49 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.components.processors.proto; + +import "mediapipe/calculators/tensor/tensors_to_detections_calculator.proto"; +import "mediapipe/calculators/tflite/ssd_anchors_calculator.proto"; +import "mediapipe/calculators/util/detection_label_id_to_text_calculator.proto"; +import "mediapipe/calculators/util/non_max_suppression_calculator.proto"; +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; + +message DetectionPostprocessingGraphOptions { + // Optional SsdAnchorsCalculatorOptions for models without + // non-maximum-suppression in tflite model graph. + optional mediapipe.SsdAnchorsCalculatorOptions ssd_anchors_options = 1; + + // Optional TensorsToDetectionsCalculatorOptions for models without + // non-maximum-suppression in tflite model graph. + optional mediapipe.TensorsToDetectionsCalculatorOptions + tensors_to_detections_options = 2; + + // Optional NonMaxSuppressionCalculatorOptions for models without + // non-maximum-suppression in tflite model graph. + optional mediapipe.NonMaxSuppressionCalculatorOptions + non_max_suppression_options = 3; + + // Optional score calibration options for models with non-maximum-suppression + // in tflite model graph. + optional ScoreCalibrationCalculatorOptions score_calibration_options = 4; + + // Optional detection label id to text calculator options. + optional mediapipe.DetectionLabelIdToTextCalculatorOptions + detection_label_ids_to_text_options = 5; +} diff --git a/mediapipe/tasks/cc/components/processors/proto/detector_options.proto b/mediapipe/tasks/cc/components/processors/proto/detector_options.proto new file mode 100644 index 000000000..c70b1f7a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/detector_options.proto @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe.tasks.components.processors.proto;
+
+option java_package = "com.google.mediapipe.tasks.components.processors.proto";
+option java_outer_classname = "DetectorOptionsProto";
+
+// Shared options used by all detection tasks.
+message DetectorOptions {
+  // The locale to use for display names specified through the TFLite Model
+  // Metadata, if any. Defaults to English.
+  optional string display_names_locale = 1 [default = "en"];
+
+  // The maximum number of top-scored detection results to return. If < 0,
+  // all available results will be returned. If 0, an invalid argument error
+  // is returned.
+  optional int32 max_results = 2 [default = -1];
+
+  // Score threshold; overrides the ones provided in the model metadata
+  // (if any). Results below this value are rejected.
+  optional float score_threshold = 3;
+
+  // Overlapping threshold for the non-maximum-suppression calculator. Only
+  // used for models without built-in non-maximum-suppression, i.e., models
+  // that don't use the Detection_Postprocess TFLite Op.
+  optional float min_suppression_threshold = 6;
+
+  // Optional allowlist of category names. If non-empty, detections whose
+  // category name is not in this set will be filtered out. Duplicate or
+  // unknown category names are ignored. Mutually exclusive with
+  // category_denylist.
+  repeated string category_allowlist = 4;
+
+  // Optional denylist of category names. If non-empty, detections whose
+  // category name is in this set will be filtered out. Duplicate or unknown
+  // category names are ignored. Mutually exclusive with category_allowlist.
+  repeated string category_denylist = 5;
+}
diff --git a/mediapipe/tasks/cc/components/processors/proto/embedder_options.proto b/mediapipe/tasks/cc/components/processors/proto/embedder_options.proto
index 8973ab248..4595bfc86 100644
--- a/mediapipe/tasks/cc/components/processors/proto/embedder_options.proto
+++ b/mediapipe/tasks/cc/components/processors/proto/embedder_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto
index 3a50818f6..5e04a8d1f 100644
--- a/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto
+++ b/mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto
index bf4fc9067..e8704679d 100644
--- a/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto
+++ b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto index db363c75d..1e9cf48c4 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index 8ddd52439..dbb28c9e1 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto new file mode 100644 index 000000000..b2d13c3a2 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto @@ -0,0 +1,49 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.components.processors.proto; + +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "TransformerParametersProto"; + +// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf) +message TransformerParameters { + // Batch size of tensors. + int32 batch_size = 1; + + // Maximum sequence length of the input/output tensor. + int32 max_seq_length = 2; + + // Embedding dimension (or model dimension), `d_model` in the paper. + // `d_k` == `d_v` == `d_model`/`h`. + int32 embedding_dim = 3; + + // Hidden dimension used in the feedforward layer, `d_ff` in the paper. + int32 hidden_dimension = 4; + + // Head dimension, `d_k` or `d_v` in the paper. + int32 head_dimension = 5; + + // Number of heads, `h` in the paper. + int32 num_heads = 6; + + // Number of stacked transformers, `N` in the paper. + int32 num_stacks = 7; + + // Whether to use Multi-Query-Attention (MQA). 
+ bool use_mqa = 8; +} diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index 7587161ae..ecf59e2d1 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h index 43d57be29..6d5fdb5f8 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 2e0ea3ce6..f2ebcd4b3 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity.cc b/mediapipe/tasks/cc/components/utils/cosine_similarity.cc index 1403700c8..8ef33073c 100644 --- a/mediapipe/tasks/cc/components/utils/cosine_similarity.cc +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity.h b/mediapipe/tasks/cc/components/utils/cosine_similarity.h index 45ddce76e..b2ebabca4 100644 --- a/mediapipe/tasks/cc/components/utils/cosine_similarity.h +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc b/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc index 4ff9dfc3a..f476a7d43 100644 --- a/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc +++ b/mediapipe/tasks/cc/components/utils/cosine_similarity_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
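Stepping back to the `TransformerParameters` message added above: its field comments encode the paper's dimensional relation `d_k == d_v == d_model / h`, i.e. `head_dimension * num_heads` should equal `embedding_dim` in a well-formed message. A small, hypothetical consistency check built on the generated proto (this helper is not part of the change):

```cpp
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/tasks/cc/components/processors/proto/transformer_params.pb.h"

// Hypothetical validator: restates the relation documented on the fields
// above (d_k == d_v == d_model / h) as an explicit arithmetic check.
absl::Status ValidateTransformerParameters(
    const mediapipe::tasks::components::processors::proto::
        TransformerParameters& params) {
  if (params.num_heads() <= 0 ||
      params.head_dimension() * params.num_heads() != params.embedding_dim()) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected head_dimension (%d) * num_heads (%d) == embedding_dim (%d).",
        params.head_dimension(), params.num_heads(), params.embedding_dim()));
  }
  return absl::OkStatus();
}
```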
diff --git a/mediapipe/tasks/cc/components/utils/gate.h b/mediapipe/tasks/cc/components/utils/gate.h index 139205fc5..2b1fc4658 100644 --- a/mediapipe/tasks/cc/components/utils/gate.h +++ b/mediapipe/tasks/cc/components/utils/gate.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/components/utils/gate_test.cc b/mediapipe/tasks/cc/components/utils/gate_test.cc index 47ff25074..94231496d 100644 --- a/mediapipe/tasks/cc/components/utils/gate_test.cc +++ b/mediapipe/tasks/cc/components/utils/gate_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 5aa9c9729..ce9181d51 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/memory", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -61,15 +62,13 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/util:resource_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - ] + select({ - "//mediapipe:windows": ["@bazel_tools//tools/cpp/runfiles"], - "//conditions:default": [], - }), + ], ) cc_library( @@ -81,6 +80,7 @@ cc_library( "//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite", "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", + "//mediapipe/tasks/cc/vision/custom_ops:fused_batch_norm", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:max_pool_argmax", "//mediapipe/util/tflite/operations:max_unpooling", @@ -108,13 +108,13 @@ cc_library( "//mediapipe/framework:subgraph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/framework/port:logging", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -318,6 +318,9 @@ cc_library( ":model_resources", 
":task_runner", ":utils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:requires", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", diff --git a/mediapipe/tasks/cc/core/base_options.cc b/mediapipe/tasks/cc/core/base_options.cc index 8bee6b469..7f7db525c 100644 --- a/mediapipe/tasks/cc/core/base_options.cc +++ b/mediapipe/tasks/cc/core/base_options.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,62 @@ limitations under the License. #include #include +#include +#include "absl/log/absl_log.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" namespace mediapipe { namespace tasks { namespace core { +proto::Acceleration ConvertDelegateOptionsToAccelerationProto( + const BaseOptions::CpuOptions& options) { + proto::Acceleration acceleration_proto = proto::Acceleration(); + acceleration_proto.mutable_tflite(); + return acceleration_proto; +} + +proto::Acceleration ConvertDelegateOptionsToAccelerationProto( + const BaseOptions::GpuOptions& options) { + proto::Acceleration acceleration_proto = proto::Acceleration(); + auto* gpu = acceleration_proto.mutable_gpu(); + gpu->set_use_advanced_gpu_api(true); + if (!options.cached_kernel_path.empty()) { + gpu->set_cached_kernel_path(options.cached_kernel_path); + } + if (!options.serialized_model_dir.empty()) { + gpu->set_serialized_model_dir(options.serialized_model_dir); + } + if (!options.model_token.empty()) { + gpu->set_model_token(options.model_token); + } + return acceleration_proto; +} + +template +void SetDelegateOptionsOrDie(const BaseOptions* base_options, + proto::BaseOptions& base_options_proto) { + if (base_options->delegate_options.has_value()) { + if (!std::holds_alternative(*base_options->delegate_options)) { + ABSL_LOG(FATAL) << "Specified Delegate type does not match the provided " + "delegate options."; + } else { + std::visit( + [&base_options_proto](const auto& delegate_options) { + proto::Acceleration acceleration_proto = + ConvertDelegateOptionsToAccelerationProto(delegate_options); + base_options_proto.mutable_acceleration()->Swap( + &acceleration_proto); + }, + *base_options->delegate_options); + } + } +} + proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { proto::BaseOptions base_options_proto; if (!base_options->model_asset_path.empty()) { @@ -53,9 +100,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { switch (base_options->delegate) { case BaseOptions::Delegate::CPU: base_options_proto.mutable_acceleration()->mutable_tflite(); + SetDelegateOptionsOrDie(base_options, + base_options_proto); break; case BaseOptions::Delegate::GPU: - base_options_proto.mutable_acceleration()->mutable_gpu(); + base_options_proto.mutable_acceleration() + ->mutable_gpu() + ->set_use_advanced_gpu_api(true); + SetDelegateOptionsOrDie(base_options, + base_options_proto); break; case BaseOptions::Delegate::EDGETPU_NNAPI: base_options_proto.mutable_acceleration() @@ -63,7 +116,6 @@ proto::BaseOptions 
ConvertBaseOptionsToProto(BaseOptions* base_options) { ->set_accelerator_name("google-edgetpu"); break; } - return base_options_proto; } } // namespace core diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index b6a0f0556..6cfc8a7aa 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,10 @@ limitations under the License. #define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_ #include +#include #include +#include +#include #include "absl/memory/memory.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" @@ -38,7 +41,8 @@ struct BaseOptions { std::string model_asset_path = ""; // The delegate to run MediaPipe. If the delegate is not set, the default - // delegate CPU is used. + // delegate CPU is used. Use `delegate_options` to configure advanced + // features of the selected delegate. enum Delegate { CPU = 0, GPU = 1, @@ -48,6 +52,30 @@ struct BaseOptions { Delegate delegate = CPU; + // Options for CPU. + struct CpuOptions {}; + + // Options for GPU. + struct GpuOptions { + // Load a pre-compiled serialized binary cache to accelerate the init + // process. Only available on Android. Kernel caching will only be enabled + // if this path is set. NOTE: binary cache usage may be skipped if a valid + // serialized model, specified by "serialized_model_dir", exists. + std::string cached_kernel_path; + + // A directory to load from and save to a pre-compiled serialized model + // used to accelerate the init process. + // NOTE: the serialized model takes precedence over the binary cache + // specified by "cached_kernel_path", which can still be used if the + // serialized model is invalid or missing. + std::string serialized_model_dir; + + // Unique token identifying the model. Used in conjunction with + // "serialized_model_dir". It is the caller's responsibility to ensure + // there is no clash of the tokens. + std::string model_token; + }; + // The file descriptor to a file opened with open(2), with optional additional // offset and length information. struct FileDescriptorMeta { @@ -67,6 +95,10 @@ struct BaseOptions { // built-in Ops. std::unique_ptr<tflite::OpResolver> op_resolver = absl::make_unique<MediaPipeBuiltinOpResolver>(); + + // Options for the chosen delegate. If not set, the default delegate options + // are used. + std::optional<std::variant<CpuOptions, GpuOptions>> delegate_options; }; // Converts a BaseOptions to a BaseOptionsProto.
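The new BaseOptions::delegate_options field pairs the delegate enum with structured per-delegate settings, and ConvertBaseOptionsToProto maps them onto the Acceleration proto. A minimal usage sketch, assuming a hypothetical model path and cache locations (error handling elided):

#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"

using ::mediapipe::tasks::core::BaseOptions;
using ::mediapipe::tasks::core::ConvertBaseOptionsToProto;

int main() {
  BaseOptions base_options;
  base_options.model_asset_path = "/path/to/model.task";  // hypothetical path
  base_options.delegate = BaseOptions::Delegate::GPU;

  // Configure GPU caching; the directories and token are hypothetical.
  BaseOptions::GpuOptions gpu_options;
  gpu_options.cached_kernel_path = "/data/local/tmp/kernel_cache/";
  gpu_options.serialized_model_dir = "/data/local/tmp/model_cache/";
  gpu_options.model_token = "my_model";
  base_options.delegate_options = gpu_options;

  // The enum and the variant alternative must agree: pairing Delegate::CPU
  // with GpuOptions would hit the ABSL_LOG(FATAL) in SetDelegateOptionsOrDie.
  mediapipe::tasks::core::proto::BaseOptions proto =
      ConvertBaseOptionsToProto(&base_options);
  return proto.acceleration().gpu().use_advanced_gpu_api() ? 0 : 1;
}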
diff --git a/mediapipe/tasks/cc/core/base_options_test.cc b/mediapipe/tasks/cc/core/base_options_test.cc index dce95050d..390663515 100644 --- a/mediapipe/tasks/cc/core/base_options_test.cc +++ b/mediapipe/tasks/cc/core/base_options_test.cc @@ -1,6 +1,9 @@ #include "mediapipe/tasks/cc/core/base_options.h" +#include +#include #include +#include #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/port/gmock.h" @@ -11,6 +14,8 @@ constexpr char kTestModelBundlePath[] = "mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task"; +constexpr char kCachedModelDir[] = "/data/local/tmp"; +constexpr char kModelToken[] = "dummy_model_token"; namespace mediapipe { namespace tasks { @@ -40,6 +45,45 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) { EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu"); } +TEST(DelegateOptionsTest, SucceedCpuOptions) { + BaseOptions base_options; + base_options.delegate = BaseOptions::Delegate::CPU; + BaseOptions::CpuOptions cpu_options; + base_options.delegate_options = cpu_options; + proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); + EXPECT_TRUE(proto.acceleration().has_tflite()); + ASSERT_FALSE(proto.acceleration().has_gpu()); +} + +TEST(DelegateOptionsTest, SucceedGpuOptions) { + BaseOptions base_options; + base_options.delegate = BaseOptions::Delegate::GPU; + BaseOptions::GpuOptions gpu_options; + gpu_options.serialized_model_dir = kCachedModelDir; + gpu_options.model_token = kModelToken; + base_options.delegate_options = gpu_options; + proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); + ASSERT_TRUE(proto.acceleration().has_gpu()); + ASSERT_FALSE(proto.acceleration().has_tflite()); + EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api()); + EXPECT_FALSE(proto.acceleration().gpu().has_cached_kernel_path()); + EXPECT_EQ(proto.acceleration().gpu().serialized_model_dir(), kCachedModelDir); + EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken); +} + +TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) { + BaseOptions base_options; + base_options.delegate = BaseOptions::Delegate::CPU; + BaseOptions::GpuOptions gpu_options; + gpu_options.cached_kernel_path = kCachedModelDir; + gpu_options.model_token = kModelToken; + base_options.delegate_options = gpu_options; + ASSERT_DEATH( + { proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); }, + "Specified Delegate type does not match the provided " + "delegate options."); +} + } // namespace } // namespace core } // namespace tasks diff --git a/mediapipe/tasks/cc/core/base_task_api.h b/mediapipe/tasks/cc/core/base_task_api.h index 92d41cc84..5238527a3 100644 --- a/mediapipe/tasks/cc/core/base_task_api.h +++ b/mediapipe/tasks/cc/core/base_task_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index a56f03d55..af304c466 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,10 +43,7 @@ limitations under the License. #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" - -#ifdef _WIN32 -#include "tools/cpp/runfiles/runfiles.h" -#endif // _WIN32 +#include "mediapipe/util/resource_util.h" namespace mediapipe { namespace tasks { @@ -96,30 +93,6 @@ ExternalFileHandler::CreateFromExternalFile( return handler; } -absl::StatusOr<std::string> PathToResourceAsFile(std::string path) { -#ifndef _WIN32 - return path; -#else - std::string qualified_path = path; - if (absl::StartsWith(qualified_path, "./")) { - qualified_path = "mediapipe" + qualified_path.substr(1); - } else if (path[0] != '/') { - qualified_path = "mediapipe/" + qualified_path; - } - - std::string error; - // TODO: We should ideally use `CreateForTests` when this is - // accessed from unit tests. - std::unique_ptr<::bazel::tools::cpp::runfiles::Runfiles> runfiles( - ::bazel::tools::cpp::runfiles::Runfiles::Create("", &error)); - if (!runfiles) { - // Return the original path when Runfiles is not available (e.g. for Python) - return path; - } - return runfiles->Rlocation(qualified_path); -#endif // _WIN32 -} - absl::Status ExternalFileHandler::MapExternalFile() { if (!external_file_.file_content().empty()) { return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/core/external_file_handler.h b/mediapipe/tasks/cc/core/external_file_handler.h index 04a3e1ac4..3150fde59 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.h +++ b/mediapipe/tasks/cc/core/external_file_handler.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index 80097fd09..04bc75057 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" +#include "mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" #include "mediapipe/util/tflite/operations/max_unpooling.h" @@ -56,6 +57,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER()); AddCustom("RaggedTensorToTensor", mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR()); + AddCustom("FusedBatchNormV3", + mediapipe::tflite_operations::Register_FusedBatchNorm()); } } // namespace core } // namespace tasks diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h index a7c28aa71..6c045d73d 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc index 2f53ff2d5..58b30630d 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h index 02d989d4b..5b334fabd 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc index 85a94ccc7..4cf53fe42 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 76695125a..1a917f72f 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/core/model_resources.h b/mediapipe/tasks/cc/core/model_resources.h index 1bc1b65eb..d8e8dada0 100644 --- a/mediapipe/tasks/cc/core/model_resources.h +++ b/mediapipe/tasks/cc/core/model_resources.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_resources_cache.cc b/mediapipe/tasks/cc/core/model_resources_cache.cc index affcb6dea..ced64babd 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.cc +++ b/mediapipe/tasks/cc/core/model_resources_cache.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_resources_cache.h b/mediapipe/tasks/cc/core/model_resources_cache.h index 32909f93d..113cfb2d4 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.h +++ b/mediapipe/tasks/cc/core/model_resources_cache.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -103,8 +103,8 @@ class ModelResourcesCache { }; // Global service for mediapipe task model resources cache. -const mediapipe::GraphService<ModelResourcesCache> kModelResourcesCacheService( - "mediapipe::tasks::ModelResourcesCacheService"); +inline constexpr mediapipe::GraphService<ModelResourcesCache> + kModelResourcesCacheService("mediapipe::tasks::ModelResourcesCacheService"); } // namespace core } // namespace tasks diff --git a/mediapipe/tasks/cc/core/model_resources_calculator.cc b/mediapipe/tasks/cc/core/model_resources_calculator.cc index d5c8cd502..8db35818c 100644 --- a/mediapipe/tasks/cc/core/model_resources_calculator.cc +++ b/mediapipe/tasks/cc/core/model_resources_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_resources_calculator_test.cc b/mediapipe/tasks/cc/core/model_resources_calculator_test.cc index 6ba52e521..cf10a209e 100644 --- a/mediapipe/tasks/cc/core/model_resources_calculator_test.cc +++ b/mediapipe/tasks/cc/core/model_resources_calculator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/model_resources_test.cc b/mediapipe/tasks/cc/core/model_resources_test.cc index 036d0e784..ffd78f3de 100644 --- a/mediapipe/tasks/cc/core/model_resources_test.cc +++ b/mediapipe/tasks/cc/core/model_resources_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
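Switching kModelResourcesCacheService from a namespace-scope const object to a C++17 inline constexpr variable means every translation unit that includes the header sees one shared entity, instead of a separate internal-linkage copy per .cc file. A minimal sketch of the idiom with a hypothetical cache type:

#include "mediapipe/framework/graph_service.h"

struct MyCache {};  // hypothetical cached object type

// C++17 inline variable: one definition shared program-wide. A plain `const`
// at namespace scope in a header has internal linkage, so each including
// .cc file would otherwise get its own copy with a distinct address.
inline constexpr mediapipe::GraphService<MyCache> kMyCacheService(
    "example::MyCacheService");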
diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 8767fb48b..a68d40ae0 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" @@ -30,7 +31,6 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator.pb.h" -#include "mediapipe/framework/port/logging.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -165,7 +165,7 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources( if (!model_resources_cache_service.IsAvailable()) { ASSIGN_OR_RETURN(auto local_model_resource, ModelResources::Create("", std::move(external_file))); - LOG(WARNING) + ABSL_LOG(WARNING) << "A local ModelResources object is created. Please consider using " "ModelResourcesCacheService to cache the created ModelResources " "object in the CalculatorGraph."; @@ -186,6 +186,21 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources( return model_resources_cache_service.GetObject().GetModelResources(tag); } +absl::StatusOr<const ModelResources*> ModelTaskGraph::GetOrCreateModelResources( + SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file, + std::string tag_suffix) { + auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); + if (model_resources_cache_service.IsAvailable()) { + std::string tag = + absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix); + if (model_resources_cache_service.GetObject().Exists(tag)) { + return model_resources_cache_service.GetObject().GetModelResources(tag); + } + } + return ModelTaskGraph::CreateModelResources(sc, std::move(external_file), + tag_suffix); +} + absl::StatusOr<const ModelAssetBundleResources*> ModelTaskGraph::CreateModelAssetBundleResources( SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file, @@ -200,7 +215,7 @@ ModelTaskGraph::CreateModelAssetBundleResources( auto local_model_asset_bundle_resource, ModelAssetBundleResources::Create("", std::move(external_file))); if (!has_file_pointer_meta) { - LOG(WARNING) + ABSL_LOG(WARNING) << "A local ModelResources object is created. Please consider using " "ModelResourcesCacheService to cache the created ModelResources " "object in the CalculatorGraph."; diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index aa864c9fc..38367da8f 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
@@ -87,6 +87,20 @@ class ModelTaskGraph : public Subgraph { SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file, std::string tag_suffix = ""); + template <typename Options> + absl::StatusOr<const ModelResources*> GetOrCreateModelResources( + SubgraphContext* sc, std::string tag_suffix = "") { + auto external_file = std::make_unique<proto::ExternalFile>(); + external_file->Swap(sc->MutableOptions<Options>() + ->mutable_base_options() + ->mutable_model_asset()); + return GetOrCreateModelResources(sc, std::move(external_file), tag_suffix); + } + + absl::StatusOr<const ModelResources*> GetOrCreateModelResources( + SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file, + std::string tag_suffix = ""); + // If the model resources graph service is available, creates a model asset // bundle resources object from the subgraph context, and caches the created // model asset bundle resources into the model resources graph service on diff --git a/mediapipe/tasks/cc/core/proto/BUILD b/mediapipe/tasks/cc/core/proto/BUILD index fff935b24..72de1be85 100644 --- a/mediapipe/tasks/cc/core/proto/BUILD +++ b/mediapipe/tasks/cc/core/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/proto/acceleration.proto b/mediapipe/tasks/cc/core/proto/acceleration.proto index c7215604a..64165f292 100644 --- a/mediapipe/tasks/cc/core/proto/acceleration.proto +++ b/mediapipe/tasks/cc/core/proto/acceleration.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/proto/base_options.proto b/mediapipe/tasks/cc/core/proto/base_options.proto index b7c0629e8..9c9571923 100644 --- a/mediapipe/tasks/cc/core/proto/base_options.proto +++ b/mediapipe/tasks/cc/core/proto/base_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/proto/external_file.proto b/mediapipe/tasks/cc/core/proto/external_file.proto index 3147a2224..dbb3da37c 100644 --- a/mediapipe/tasks/cc/core/proto/external_file.proto +++ b/mediapipe/tasks/cc/core/proto/external_file.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/proto/inference_subgraph.proto b/mediapipe/tasks/cc/core/proto/inference_subgraph.proto index 2232a1153..d4c80020b 100644 --- a/mediapipe/tasks/cc/core/proto/inference_subgraph.proto +++ b/mediapipe/tasks/cc/core/proto/inference_subgraph.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
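GetOrCreateModelResources first consults the ModelResourcesCacheService for an entry under the node's tag and only falls back to CreateModelResources on a miss, so repeated graph builds reuse one ModelResources instance. A standalone sketch of that lookup-then-create flow, using a plain map as a stand-in for the graph service:

#include <memory>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"

struct Resources {
  std::string tag;  // stand-in for a loaded model's state
};

// Stand-in for ModelResourcesCacheService.
using Cache = absl::flat_hash_map<std::string, std::unique_ptr<Resources>>;

absl::StatusOr<const Resources*> GetOrCreate(Cache& cache,
                                             const std::string& tag) {
  if (auto it = cache.find(tag); it != cache.end()) {
    return it->second.get();  // hit: reuse the previously created object
  }
  auto created = std::make_unique<Resources>(Resources{tag});
  const Resources* raw = created.get();
  cache.emplace(tag, std::move(created));  // miss: create once, then cache
  return raw;
}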
diff --git a/mediapipe/tasks/cc/core/proto/model_resources_calculator.proto b/mediapipe/tasks/cc/core/proto/model_resources_calculator.proto index dd67bb479..ec200fd19 100644 --- a/mediapipe/tasks/cc/core/proto/model_resources_calculator.proto +++ b/mediapipe/tasks/cc/core/proto/model_resources_calculator.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/task_api_factory.h b/mediapipe/tasks/cc/core/task_api_factory.h index 631696b4c..6f604dd4c 100644 --- a/mediapipe/tasks/cc/core/task_api_factory.h +++ b/mediapipe/tasks/cc/core/task_api_factory.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/port/requires.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -54,6 +58,8 @@ class TaskApiFactory { std::unique_ptr<tflite::OpResolver> resolver, PacketsCallback packets_callback = nullptr) { bool found_task_subgraph = false; + // This for-loop ensures there's only one subgraph besides + // FlowLimiterCalculator. for (const auto& node : graph_config.node()) { if (node.calculator() == "FlowLimiterCalculator") { continue; } @@ -64,13 +70,7 @@ class TaskApiFactory { "Task graph config should only contain one task subgraph node.", MediaPipeTasksStatus::kInvalidTaskGraphConfigError); } else { - if (!node.options().HasExtension(Options::ext)) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat(node.calculator(), - " is missing the required task options field."), - MediaPipeTasksStatus::kInvalidTaskGraphConfigError); - } + MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node)); found_task_subgraph = true; } } @@ -80,6 +80,34 @@ class TaskApiFactory { std::move(packets_callback))); return std::make_unique(std::move(runner)); } + + template <typename Options> + static absl::Status CheckHasValidOptions( + const CalculatorGraphConfig::Node& node) { + if constexpr (mediapipe::Requires<Options>( + [](auto&& o) -> decltype(o.ext) {})) { + if (node.options().HasExtension(Options::ext)) { + return absl::OkStatus(); + } + } else { +#ifndef MEDIAPIPE_PROTO_LITE + for (const auto& option : node.node_options()) { + if (absl::StrContains(option.type_url(), + Options::descriptor()->full_name())) { + return absl::OkStatus(); + } + } +#else // MEDIAPIPE_PROTO_LITE + // Skip the check for proto lite, as Options::descriptor() is unavailable.
+ return absl::OkStatus(); +#endif // MEDIAPIPE_PROTO_LITE + } + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat(node.calculator(), + " is missing the required task options field."), + MediaPipeTasksStatus::kInvalidTaskGraphConfigError); + } }; } // namespace core diff --git a/mediapipe/tasks/cc/core/task_runner.cc b/mediapipe/tasks/cc/core/task_runner.cc index fc933d547..d97c4e480 100644 --- a/mediapipe/tasks/cc/core/task_runner.cc +++ b/mediapipe/tasks/cc/core/task_runner.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/task_runner.h b/mediapipe/tasks/cc/core/task_runner.h index 8123a45aa..cd77c0555 100644 --- a/mediapipe/tasks/cc/core/task_runner.h +++ b/mediapipe/tasks/cc/core/task_runner.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/task_runner_test.cc b/mediapipe/tasks/cc/core/task_runner_test.cc index 75c6260af..6a53a0ff1 100644 --- a/mediapipe/tasks/cc/core/task_runner_test.cc +++ b/mediapipe/tasks/cc/core/task_runner_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/utils.cc b/mediapipe/tasks/cc/core/utils.cc index 1e44109c3..168c4363c 100644 --- a/mediapipe/tasks/cc/core/utils.cc +++ b/mediapipe/tasks/cc/core/utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/core/utils.h b/mediapipe/tasks/cc/core/utils.h index 4ca51fa91..54d63866d 100644 --- a/mediapipe/tasks/cc/core/utils.h +++ b/mediapipe/tasks/cc/core/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.cc b/mediapipe/tasks/cc/metadata/metadata_extractor.cc index 4d6f526f5..16e05d7a3 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.cc +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.h b/mediapipe/tasks/cc/metadata/metadata_extractor.h index b88eda863..c1d8a4c2a 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.h +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_parser.h.template b/mediapipe/tasks/cc/metadata/metadata_parser.h.template index 28c38af82..852791a03 100644 --- a/mediapipe/tasks/cc/metadata/metadata_parser.h.template +++ b/mediapipe/tasks/cc/metadata/metadata_parser.h.template @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_populator.cc b/mediapipe/tasks/cc/metadata/metadata_populator.cc index a6fd496a3..872bc2b20 100644 --- a/mediapipe/tasks/cc/metadata/metadata_populator.cc +++ b/mediapipe/tasks/cc/metadata/metadata_populator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_populator.h b/mediapipe/tasks/cc/metadata/metadata_populator.h index c0554f704..c3856e104 100644 --- a/mediapipe/tasks/cc/metadata/metadata_populator.h +++ b/mediapipe/tasks/cc/metadata/metadata_populator.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_version.cc b/mediapipe/tasks/cc/metadata/metadata_version.cc index 7e2414dd5..aab3d8c56 100644 --- a/mediapipe/tasks/cc/metadata/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/metadata_version.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/metadata_version.h b/mediapipe/tasks/cc/metadata/metadata_version.h index a53caa547..bb0bd36f9 100644 --- a/mediapipe/tasks/cc/metadata/metadata_version.h +++ b/mediapipe/tasks/cc/metadata/metadata_version.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/python/metadata_version.cc b/mediapipe/tasks/cc/metadata/python/metadata_version.cc index e3072bc9e..e1e50459e 100644 --- a/mediapipe/tasks/cc/metadata/python/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/python/metadata_version.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
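The CheckHasValidOptions helper added to TaskApiFactory above accepts task options delivered either as a proto2 extension (Options::ext) or, for proto3, as an Any packed into node_options; which branch compiles is decided by probing for the ext member with mediapipe::Requires. A self-contained sketch of that probe idiom, assuming Requires is the usual std::is_invocable wrapper (its actual definition lives in mediapipe/framework/port/requires.h):

#include <type_traits>

// Assumed shape of mediapipe::Requires: true iff the generic lambda is
// invocable with Ts..., i.e. iff the expression named in its trailing return
// type is well-formed for those types.
template <typename... Ts, typename F>
constexpr bool Requires(F) {
  return std::is_invocable_v<F, Ts...>;
}

struct WithExt { int ext = 0; };  // mimics a proto2 options type with `ext`
struct WithoutExt {};             // mimics a proto3 options type

// The lambda's return type decltype(o.ext) SFINAEs away when `ext` is absent.
static_assert(Requires<WithExt>([](auto&& o) -> decltype(o.ext) {}));
static_assert(!Requires<WithoutExt>([](auto&& o) -> decltype(o.ext) {}));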
diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc index 0e05e5167..bed1610d6 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_extractor_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc index 0d613e65e..9b70f7a28 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_parser_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc index 63cd2ff9c..188560b54 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc index eaaa39f0e..ae4315f90 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/BUILD b/mediapipe/tasks/cc/metadata/utils/BUILD index 881b88962..9e912c925 100644 --- a/mediapipe/tasks/cc/metadata/utils/BUILD +++ b/mediapipe/tasks/cc/metadata/utils/BUILD @@ -36,6 +36,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@zlib//:zlib_minizip", diff --git a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc index a231afc40..cd5fa6316 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc +++ b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h index fcd22d6d6..b99664b75 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h +++ b/mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.cc b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc index 2c09e1961..b9dd784c4 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_utils.cc +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "contrib/minizip/ioapi.h" @@ -63,7 +64,7 @@ absl::StatusOr GetCurrentZipFileInfo(const unzFile& zf) { absl::Cleanup unzipper_closer = [zf]() { auto status = UnzipErrorToStatus(unzCloseCurrentFile(zf)); if (!status.ok()) { - LOG(ERROR) << "Failed to close the current zip file: " << status; + ABSL_LOG(ERROR) << "Failed to close the current zip file: " << status; } }; if (method != Z_NO_COMPRESSION) { @@ -125,7 +126,7 @@ absl::Status ExtractFilesfromZipFile( } absl::Cleanup unzipper_closer = [zf]() { if (unzClose(zf) != UNZ_OK) { - LOG(ERROR) << "Unable to close zip archive."; + ABSL_LOG(ERROR) << "Unable to close zip archive."; } }; // Get number of files. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.h b/mediapipe/tasks/cc/metadata/utils/zip_utils.h index 10ad0a5a9..e69abf9ff 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_utils.h +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc index 3dc1f1950..1f27c5bfd 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc +++ b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h index ca06476ec..6a28e2d8a 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h +++ b/mediapipe/tasks/cc/metadata/utils/zip_writable_mem_file.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
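The zip utilities above log teardown failures from inside absl::Cleanup handlers, where no status can be returned to the caller; moving those call sites to ABSL_LOG(ERROR) drops the dependency on mediapipe's logging port. A minimal sketch of the pattern with a plain FILE handle (the path is a hypothetical placeholder):

#include <cstdio>

#include "absl/cleanup/cleanup.h"
#include "absl/log/absl_log.h"

void ReadAndClose() {
  std::FILE* f = std::fopen("/tmp/example.bin", "rb");  // hypothetical path
  if (f == nullptr) return;
  // Runs on every exit path; a close failure can only be reported, not
  // propagated, so it is logged.
  absl::Cleanup closer = [f] {
    if (std::fclose(f) != 0) {
      ABSL_LOG(ERROR) << "Unable to close file.";
    }
  };
  // ... read from f; `closer` fires when the scope unwinds ...
}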
diff --git a/mediapipe/tasks/cc/text/custom_ops/ragged/BUILD b/mediapipe/tasks/cc/text/custom_ops/ragged/BUILD index 00e8fa1e7..b37104e8b 100644 --- a/mediapipe/tasks/cc/text/custom_ops/ragged/BUILD +++ b/mediapipe/tasks/cc/text/custom_ops/ragged/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# RaggedTensors suppport in TFLite +# RaggedTensors support in TFLite package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.cc b/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.cc index 4ba1a9291..1894dfa8d 100644 --- a/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.cc +++ b/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -357,7 +357,7 @@ void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids, }; int current_output_column = 0; int current_value_rowid = value_rowids_val(0); - // DCHECK_LT(current_value_rowid, parent_output_index.size()); + // ABSL_DCHECK_LT(current_value_rowid, parent_output_index.size()); int current_output_index = parent_output_index[current_value_rowid]; result->push_back(current_output_index); for (int i = 1; i < index_size; ++i) { @@ -374,12 +374,12 @@ void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids, } else { current_output_column = 0; current_value_rowid = next_value_rowid; - // DCHECK_LT(next_value_rowid, parent_output_index.size()); + // ABSL_DCHECK_LT(next_value_rowid, parent_output_index.size()); current_output_index = parent_output_index[next_value_rowid]; } result->push_back(current_output_index); } - // DCHECK_EQ(result->size(), value_rowids.size()); + // ABSL_DCHECK_EQ(result->size(), value_rowids.size()); } void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split, @@ -420,7 +420,7 @@ void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split, } } // if (row_split_size > 0) { - // DCHECK_EQ(result->size(), row_split(row_split_size - 1)); + // ABSL_DCHECK_EQ(result->size(), row_split(row_split_size - 1)); //} } diff --git a/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h b/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h index 02536c97a..65da8d055 100644 --- a/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h +++ b/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite_test.cc b/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite_test.cc index 38e220344..e0c8604e1 100644 --- a/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite_test.cc +++ b/mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD index 19f843c4e..334ed74d4 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs index 16408ffee..6fb5f1bff 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h index c3b568f1c..d6497c747 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc index f492b5c48..85f6c2f59 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h index 94c50bffc..90c8f6066 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc index 60a78e126..a53c9be1e 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs index 2e7836803..3aed17ee3 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs @@ -1,4 +1,4 @@ -// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc index 3a831f3d7..0c407448e 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h index 828db16da..49410fa56 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc index 365b1a5ad..0dbc061fc 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h index 849a47849..d4712dc94 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc index e65bd1850..fcef6b864 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h index faf481844..bb50c4f96 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc index 468a3a54f..481fd5237 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h index 8a9fa8aef..f6c89bdb0 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h index c1b7728cc..04380ff5a 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/BUILD b/mediapipe/tasks/cc/text/language_detector/BUILD index 57b9c7b51..a7229cdfd 100644 --- a/mediapipe/tasks/cc/text/language_detector/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD index 090f528ef..bb33bd200 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ cc_test( deps = [ ":kmeans_embedding_lookup", "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/log:absl_check", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/kernels:test_util", @@ -66,6 +67,7 @@ cc_test( ":ngram_hash", "//mediapipe/framework/port:gtest_main", "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/types:optional", "@flatbuffers", "@org_tensorflow//tensorflow/lite:framework", diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc index 2c9b7a172..9df8d6d59 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h index 31dd4abbd..030e0663c 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc index f1ee661d4..54b5161fe 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc @@ -6,6 +6,7 @@ #include #include +#include "absl/log/absl_check.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "tensorflow/lite/c/common.h" @@ -45,8 +46,8 @@ class KmeansEmbeddingLookupModel : public tflite::SingleOpModel { void Invoke(const std::vector& input, const std::vector& encoding_table, const std::vector& codebook) { - CHECK_EQ(SetUpInputTensor(input, encoding_table, codebook), kTfLiteOk); - CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); + ABSL_CHECK_EQ(SetUpInputTensor(input, encoding_table, codebook), kTfLiteOk); + ABSL_CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); } TfLiteStatus InvokeUnchecked(const std::vector& input, diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc index efe39a01f..b20d77b91 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h index c32e91c62..53599bf29 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc index b799afc2f..1e348bdd1 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/log/absl_check.h" #include "absl/types/optional.h" #include "flatbuffers/flexbuffers.h" #include "mediapipe/framework/port/gmock.h" @@ -78,13 +79,13 @@ class NGramHashModel : public tflite::SingleOpModel { void SetupInputTensor(const std::string& input) { PopulateStringTensor(input_, {input}); - CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + ABSL_CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; } void Invoke(const std::string& input) { SetupInputTensor(input); - CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); + ABSL_CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); } TfLiteStatus InvokeUnchecked(const std::string& input) { diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD index 9f2fe298a..894f2b9bb 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD index 86b659245..9a6f6f389 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc index 75dd161bf..4889ae665 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // Forked from a library written by Austin Appelby and Jyrki Alakuijala. // Original copyright message below. -// Copyright 2009 Google Inc. All Rights Reserved. +// Copyright 2009 Google Inc. // Author: aappleby@google.com (Austin Appleby) // jyrki@google.com (Jyrki Alakuijala) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h index abcb41a6b..cf6a44d50 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ limitations under the License. 
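Editor's note on the test hunks above (kmeans_embedding_lookup_test.cc and ngram_hash_test.cc): they migrate the unprefixed CHECK/CHECK_EQ macros to their ABSL_-prefixed equivalents and add the matching "@com_google_absl//absl/log:absl_check" dependency in BUILD. A minimal sketch of the pattern, with a hypothetical helper name, assuming only absl and the TFLite C API:

#include "absl/log/absl_check.h"
#include "tensorflow/lite/c/common.h"

// Hypothetical helper: dies with a message unless the TFLite call succeeded.
// ABSL_CHECK_EQ comes from absl/log/absl_check.h and, unlike the unprefixed
// CHECK_EQ, cannot collide with CHECK macros from other logging libraries.
void ExpectOk(TfLiteStatus status) {
  ABSL_CHECK_EQ(status, kTfLiteOk) << "TFLite call failed";
}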
==============================================================================*/ // Forked from a library written by Austin Appelby and Jyrki Alakuijala. // Original copyright message below. -// Copyright 2009 Google Inc. All Rights Reserved. +// Copyright 2009 Google Inc. // Author: aappleby@google.com (Austin Appelby) // jyrki@google.com (Jyrki Alakuijala) // diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc index 6658965bf..10ea0ffdf 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // Forked from a test library written by Jyrki Alakuijala. // Original copyright message below. -// Copyright 2009 Google Inc. All Rights Reserved. +// Copyright 2009 Google Inc. // Author: jyrki@google.com (Jyrki Alakuijala) // // Tests for the fast hashing algorithm based on Austin Appleby's diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc index f1ad71fc1..7f46917e0 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h index 9a80554c8..24daf5152 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc index d22af1c95..b36391d20 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD index a71845305..b633d6812 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c index b74450f44..80ac027ca 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c index 1dd8abdbd..949a652fd 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h index 66d1dfc19..ba9146851 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h index 24d9b9dbe..016ab1ebf 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector.cc b/mediapipe/tasks/cc/text/language_detector/language_detector.cc index e3841211b..476427729 100644 --- a/mediapipe/tasks/cc/text/language_detector/language_detector.cc +++ b/mediapipe/tasks/cc/text/language_detector/language_detector.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector.h b/mediapipe/tasks/cc/text/language_detector/language_detector.h index bbe58dedf..95e2e9c97 100644 --- a/mediapipe/tasks/cc/text/language_detector/language_detector.h +++ b/mediapipe/tasks/cc/text/language_detector/language_detector.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc index 92dc493e0..1ff12fda1 100644 --- a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe::tasks::text::language_detector { namespace { @@ -75,7 +75,7 @@ absl::Status MatchesLanguageDetectorResult( } // namespace -class LanguageDetectorTest : public tflite_shims::testing::Test {}; +class LanguageDetectorTest : public tflite::testing::Test {}; TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) { auto options = std::make_unique(); diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 4bf773270..121b4f5e6 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -75,6 +75,7 @@ cc_test( "//mediapipe/tasks/testdata/text:bert_text_classifier_models", "//mediapipe/tasks/testdata/text:text_classifier_models", ], + tags = ["not_run:arm"], deps = [ ":text_classifier", ":text_classifier_test_utils", @@ -85,10 +86,9 @@ cc_test( "//mediapipe/tasks/cc/components/containers:classification_result", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", - "@com_google_sentencepiece//src:sentencepiece_processor", + "@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep "@org_tensorflow//tensorflow/lite:test_util", ], ) diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/BUILD b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD index f2b544d87..fc298575c 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
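Editor's note on the language_detector_test.cc hunk above: it drops TFLite's shim layer, so the fixture now derives from tflite::testing::Test, declared in tensorflow/lite/test_util.h, instead of tflite_shims::testing::Test. A hedged sketch of a fixture written against the new header (fixture and test names are hypothetical):

#include <gtest/gtest.h>

#include "tensorflow/lite/test_util.h"

// tflite::testing::Test is a googletest fixture base class; deriving from it
// replaces the old tflite_shims::testing::Test spelling one-for-one.
class ExampleTaskTest : public tflite::testing::Test {};

TEST_F(ExampleTaskTest, BuildsAgainstUnshimmedHeader) { SUCCEED(); }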
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index 41f87b519..7f693ee71 100644 --- a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc index d174fac47..0ffd57ce8 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h index 03569c5a6..0fd9e56a8 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 3be92f309..bd032cdf2 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index f800a0e52..dfb78c07f 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ limitations under the License. #include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" -#include -#include #include #include #include @@ -24,7 +22,6 @@ limitations under the License. 
#include "absl/flags/flag.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc index d12370372..0e3d8b895 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h index a427b561c..d6ebaf502 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index addb971f1..c925abcbd 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,6 +66,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/text/utils:text_model_utils", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -81,6 +82,7 @@ cc_test( "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", "//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa", ], + tags = ["not_run:arm"], deps = [ ":text_embedder", "//mediapipe/framework/deps:file_path", diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/BUILD b/mediapipe/tasks/cc/text/text_embedder/proto/BUILD index 146483af1..0e42f40e8 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto index fc8e02858..f367d3531 100644 --- a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc index 375058d57..aa98c1e27 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.h b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h index d729ff3c2..a7e9736ab 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder.h +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index 518695138..d5bdda4ff 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -86,7 +87,7 @@ class TextEmbedderGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - CHECK(sc != nullptr); + ABSL_CHECK(sc != nullptr); ASSIGN_OR_RETURN(const ModelResources* model_resources, CreateModelResources(sc)); Graph graph; diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index 474f0ca35..76634a922 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ constexpr char kUniversalSentenceEncoderModel[] = // Tolerance for embedding vector coordinate values. constexpr float kEpsilon = 1e-4; // Tolerancy for cosine similarity evaluation. 
-constexpr double kSimilarityTolerancy = 1e-6; +constexpr double kSimilarityTolerancy = 2e-2; using ::mediapipe::file::JoinPath; using ::testing::HasSubstr; @@ -79,6 +79,8 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) { ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512); #ifdef _WIN32 ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon); +#elif defined(__FMA__) + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.3605f, kEpsilon); #else ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon); #endif // _WIN32 @@ -87,7 +89,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) { auto result1, text_embedder->Embed("what a great and fantastic trip")); ASSERT_EQ(result1.embeddings.size(), 1); ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512); +#ifdef __FMA__ + ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 21.254150f, kEpsilon); +#else ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon); +#endif // Check cosine similarity. MP_ASSERT_OK_AND_ASSIGN( diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 92fac8eaa..b299f1c73 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"]) +default_visibility = ["//mediapipe/calculators/tensor:__subpackages__"] + +package(default_visibility = default_visibility) licenses(["notice"]) @@ -34,6 +36,7 @@ cc_library( hdrs = [ "bert_tokenizer.h", ], + visibility = default_visibility + ["//mediapipe/tasks:users"], deps = [ ":tokenizer", "//mediapipe/framework/port:integral_types", @@ -68,6 +71,7 @@ cc_library( deps = [ ":tokenizer", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_sentencepiece//src:sentencepiece_processor", ], @@ -83,6 +87,7 @@ cc_test( ":sentencepiece_tokenizer", "//mediapipe/framework/port:gtest_main", "//mediapipe/tasks/cc/core:utils", + "@com_google_absl//absl/log:absl_check", "@com_google_sentencepiece//src:sentencepiece_processor", ], ) @@ -102,6 +107,7 @@ cc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -116,6 +122,7 @@ cc_test( "//mediapipe/tasks/testdata/text:albert_model", "//mediapipe/tasks/testdata/text:mobile_bert_model", "//mediapipe/tasks/testdata/text:text_classifier_models", + "@com_google_absl//absl/log:absl_check", ], linkopts = ["-ldl"], deps = [ diff --git a/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc index 3348abff5..caec04ded 100644 --- a/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc +++ b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
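Editor's note on the text_embedder_test.cc hunks above: they fork the expected embedding coordinates on __FMA__ and widen kSimilarityTolerancy from 1e-6 to 2e-2. The reason is that a fused multiply-add rounds a*b + c once instead of twice, so accumulated float results legitimately differ between builds with and without FMA contraction. A self-contained illustration of that rounding difference (not from the patch):

#include <cmath>
#include <cstdio>

int main() {
  const float e = 0x1p-23f;  // one ulp at 1.0f
  const float a = 1.0f + e, b = 1.0f - e, c = -1.0f;
  // The exact product a*b is 1 - e*e. As a separate float multiply it rounds
  // to 1.0f, so the unfused expression yields exactly 0; fmaf rounds only
  // once and keeps the -e*e term.
  std::printf("unfused: %g\n", a * b + c);           // prints 0
  std::printf("fused:   %g\n", std::fmaf(a, b, c));  // ~ -1.42e-14
  return 0;
}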
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h index d655fcadd..67149b788 100644 --- a/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc index 8a21136cc..8a06c5806 100644 --- a/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/bert_tokenizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc index 6a1dc2506..8fe75d8b5 100644 --- a/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc +++ b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h index c84dd33d2..0617167af 100644 --- a/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc index 150304d71..cc13f256b 100644 --- a/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/regex_tokenizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h index 9798f7bde..97ef1848c 100644 --- a/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" @@ -36,20 +37,27 @@ class SentencePieceTokenizer : public Tokenizer { public: // Initialize the SentencePiece tokenizer from model file path. explicit SentencePieceTokenizer(const std::string& path_to_model) { - CHECK_OK(sp_.Load(path_to_model)); + // Can't use ABSL_CHECK_OK here because in internal builds + // the return type is absl::Status while the open source builds + // use sentencepiece/src/deps/status.h's util::Status which + // doesn't work with the absl CHECK macros. + const auto status = sp_.Load(path_to_model); + ABSL_CHECK(status.ok()) << status.ToString(); } explicit SentencePieceTokenizer(const char* spmodel_buffer_data, size_t spmodel_buffer_size) { absl::string_view buffer_binary(spmodel_buffer_data, spmodel_buffer_size); - CHECK_OK(sp_.LoadFromSerializedProto(buffer_binary)); + const auto status = sp_.LoadFromSerializedProto(buffer_binary); + ABSL_CHECK(status.ok()) << status.ToString(); } // Perform tokenization, return tokenized results. TokenizerResult Tokenize(const std::string& input) override { TokenizerResult result; std::vector& subwords = result.subwords; - CHECK_OK(sp_.Encode(input, &subwords)); + const auto status = sp_.Encode(input, &subwords); + ABSL_CHECK(status.ok()) << status.ToString(); return result; } diff --git a/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc index 88afabe1e..eb9032cc0 100644 --- a/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/tokenizer.h b/mediapipe/tasks/cc/text/tokenizers/tokenizer.h index ae984808e..d5d7d8894 100644 --- a/mediapipe/tasks/cc/text/tokenizers/tokenizer.h +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc index 839c0818e..1c2faa677 100644 --- a/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h index c6bea1418..039c6dd5e 100644 --- a/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
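Editor's note on the SentencePieceTokenizer hunk above: it swaps CHECK_OK for an explicit ok()/ToString() check because, as the new comment explains, the open-source SentencePiece build returns its own util::Status rather than absl::Status, and the absl CHECK_OK macro only accepts the latter. The workaround, sketched standalone (header path per the Bazel external-repo layout; the function name is hypothetical):

#include <string>

#include "absl/log/absl_check.h"
#include "src/sentencepiece_processor.h"

// Loads a SentencePiece model or dies. Works whether Load() returns
// absl::Status or sentencepiece's util::Status: both expose ok() and
// ToString(), which is all this pattern relies on.
void LoadModelOrDie(sentencepiece::SentencePieceProcessor& sp,
                    const std::string& path) {
  const auto status = sp.Load(path);
  ABSL_CHECK(status.ok()) << status.ToString();
}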
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc index 337d5ec7d..7c2344e3c 100644 --- a/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/tokenizer_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD index 15af7683b..d13d6b304 100644 --- a/mediapipe/tasks/cc/text/utils/BUILD +++ b/mediapipe/tasks/cc/text/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ cc_library( "vocab_utils.h", ], deps = [ + "//mediapipe/util:resource_util", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", ], diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc index 7cd03d848..1100e83f7 100644 --- a/mediapipe/tasks/cc/text/utils/text_model_utils.cc +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h index da8783d33..730616787 100644 --- a/mediapipe/tasks/cc/text/utils/text_model_utils.h +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc index 2ec5686f7..a0d9cf0c9 100644 --- a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc +++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/utils/vocab_utils.cc b/mediapipe/tasks/cc/text/utils/vocab_utils.cc index 068272f7f..66b0587d6 100644 --- a/mediapipe/tasks/cc/text/utils/vocab_utils.cc +++ b/mediapipe/tasks/cc/text/utils/vocab_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/strings/str_split.h" +#include "mediapipe/util/resource_util.h" namespace mediapipe { namespace tasks { @@ -34,7 +35,11 @@ void ReadIStreamLineByLine( std::string str; while (std::getline(*istream, str)) { if (!str.empty()) { - line_processor(str); + if (str.back() == '\r') { // Remove \r on Windows + line_processor(str.substr(0, str.length() - 1)); + } else { + line_processor(str); + } } } } @@ -64,7 +69,8 @@ std::vector ReadIStreamByLine(std::istream* istream) { std::vector LoadVocabFromFile(const std::string& path_to_vocab) { std::vector vocab_from_file; - std::ifstream in(path_to_vocab.c_str()); + std::string file_name = *PathToResourceAsFile(path_to_vocab); + std::ifstream in(file_name.c_str()); return ReadIStreamByLine(&in); } @@ -79,7 +85,8 @@ std::vector LoadVocabFromBuffer(const char* vocab_buffer_data, absl::node_hash_map LoadVocabAndIndexFromFile( const std::string& path_to_vocab) { absl::node_hash_map vocab_index_map; - std::ifstream in(path_to_vocab.c_str()); + std::string file_name = *PathToResourceAsFile(path_to_vocab); + std::ifstream in(file_name.c_str()); return ReadIStreamLineSplits(&in); } diff --git a/mediapipe/tasks/cc/text/utils/vocab_utils.h b/mediapipe/tasks/cc/text/utils/vocab_utils.h index a2da349dc..80c448e86 100644 --- a/mediapipe/tasks/cc/text/utils/vocab_utils.h +++ b/mediapipe/tasks/cc/text/utils/vocab_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/text/utils/vocab_utils_test.cc b/mediapipe/tasks/cc/text/utils/vocab_utils_test.cc index e4db9628d..a255458e6 100644 --- a/mediapipe/tasks/cc/text/utils/vocab_utils_test.cc +++ b/mediapipe/tasks/cc/text/utils/vocab_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index 0815b5b2b..6bcf2f5d6 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -43,6 +43,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc/components/containers:rect", "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", @@ -58,6 +59,7 @@ cc_library( ":base_vision_task_api", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/tasks/cc/core:task_api_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index 8e6105e18..e65e67bdc 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h index e2647be71..03610e983 100644 --- a/mediapipe/tasks/cc/vision/core/image_processing_options.h +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/core/running_mode.h b/mediapipe/tasks/cc/vision/core/running_mode.h index 330c335f2..068045142 100644 --- a/mediapipe/tasks/cc/vision/core/running_mode.h +++ b/mediapipe/tasks/cc/vision/core/running_mode.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/core/vision_task_api_factory.h b/mediapipe/tasks/cc/vision/core/vision_task_api_factory.h index 8872a2a04..48fc33848 100644 --- a/mediapipe/tasks/cc/vision/core/vision_task_api_factory.h +++ b/mediapipe/tasks/cc/vision/core/vision_task_api_factory.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -60,13 +61,8 @@ class VisionTaskApiFactory { "Task graph config should only contain one task subgraph node.", MediaPipeTasksStatus::kInvalidTaskGraphConfigError); } else { - if (!node.options().HasExtension(Options::ext)) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat(node.calculator(), - " is missing the required task options field."), - MediaPipeTasksStatus::kInvalidTaskGraphConfigError); - } + MP_RETURN_IF_ERROR( + tasks::core::TaskApiFactory::CheckHasValidOptions(node)); found_task_subgraph = true; } } diff --git a/mediapipe/tasks/cc/vision/custom_ops/BUILD b/mediapipe/tasks/cc/vision/custom_ops/BUILD new file mode 100644 index 000000000..71eda50d3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/custom_ops/BUILD @@ -0,0 +1,35 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "fused_batch_norm", + srcs = ["fused_batch_norm.cc"], + hdrs = ["fused_batch_norm.h"], + visibility = [ + "//visibility:public", + ], + deps = + [ + "@eigen_archive//:eigen3", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/c:private_common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) diff --git a/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.cc b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.cc new file mode 100644 index 000000000..650f723e2 --- /dev/null +++ b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.cc @@ -0,0 +1,293 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h" + +#include + +#include "Eigen/Core" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace mediapipe::tflite_operations { +namespace vision::batch_norm { +namespace { + +using tflite::GetTensorData; + +constexpr int kInputIndex = 0; +constexpr int kInputScaleIndex = 1; +constexpr int kInputOffsetIndex = 2; +constexpr int kInputEstimatedMeanIndex = 3; +constexpr int kInputEstimatedVarIndex = 4; + +constexpr int kOutputIndex = 0; +constexpr int kOutputBatchMeanIndex = 1; +constexpr int kOutputBatchVarIndex = 2; +constexpr int kOutputSavedMeanIndex = 3; +constexpr int kOutputSavedVarIndex = 4; + +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap> + Tensor; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap> Vec; + typedef Eigen::TensorMap< + Eigen::Tensor> + ConstVec; +}; + +template +void FusedBarchNorm(TfLiteContext* context, TfLiteTensor* x_input, + TfLiteTensor* scale_input, TfLiteTensor* offset_input, + TfLiteTensor* running_mean_input, + TfLiteTensor* running_variance_input, + TfLiteTensor* y_output, TfLiteTensor* running_mean_output, + TfLiteTensor* running_var_output, + TfLiteTensor* saved_batch_mean_output, + TfLiteTensor* saved_batch_var_output, + U exponential_avg_factor, U epsilon) { + const int batches = x_input->dims->data[0]; + const int height = x_input->dims->data[1]; + const int width = x_input->dims->data[2]; + const int depth = x_input->dims->data[3]; + + Eigen::array x_dims = {batches, height, width, depth}; + Eigen::array depth_dims = {depth}; + + const int rest_size = batches * height * width; + + typename TTypes::Tensor x(GetTensorData(x_input), x_dims); + typename TTypes::ConstVec scale(GetTensorData(scale_input), depth_dims); + typename TTypes::ConstVec offset(GetTensorData(offset_input), + depth_dims); + typename TTypes::ConstVec old_mean(GetTensorData(running_mean_input), + depth_dims); + typename TTypes::ConstVec old_variance( + GetTensorData(running_variance_input), depth_dims); + typename TTypes::Tensor y(GetTensorData(y_output), x_dims); + typename TTypes::Vec new_mean(GetTensorData(running_mean_output), + depth_dims); + typename TTypes::Vec new_variance(GetTensorData(running_var_output), + depth_dims); + typename TTypes::Vec saved_batch_mean( + GetTensorData(saved_batch_mean_output), depth_dims); + typename TTypes::Vec saved_batch_var( + GetTensorData(saved_batch_var_output), depth_dims); + + Eigen::DSizes rest_by_depth(rest_size, depth); + Eigen::DSizes tensor_shape(batches, height, width, depth); + + Eigen::IndexList, Eigen::Index> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList> reduce_dims; + Eigen::IndexList> bcast_spec; + bcast_spec.set(0, rest_size); + + auto x_rest_by_depth = x.reshape(rest_by_depth).template cast(); + const int rest_size_minus_one = (rest_size > 1) ? 
(rest_size - 1) : 1; + U rest_size_inv = static_cast(1.0f / static_cast(rest_size)); + // This adjustment is for Bessel's correction + U rest_size_adjust = + static_cast(rest_size) / static_cast(rest_size_minus_one); + + Eigen::Tensor batch_mean(depth); + Eigen::Tensor batch_variance(depth); + + batch_mean = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv); + auto x_centered = + x_rest_by_depth - batch_mean.reshape(one_by_depth).broadcast(bcast_spec); + + batch_variance = x_centered.square().sum(reduce_dims) * rest_size_inv; + auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale) + .eval() + .reshape(one_by_depth) + .broadcast(bcast_spec); + auto x_scaled = x_centered * scaling_factor; + auto x_shifted = + (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec)) + .template cast(); + + y.reshape(rest_by_depth) = x_shifted; + if (exponential_avg_factor == U(1.0)) { + saved_batch_var = batch_variance; + saved_batch_mean = batch_mean; + new_variance = batch_variance * rest_size_adjust; + new_mean = batch_mean; + } else { + U one_minus_factor = U(1) - exponential_avg_factor; + saved_batch_var = batch_variance; + saved_batch_mean = batch_mean; + new_variance = one_minus_factor * old_variance + + (exponential_avg_factor * rest_size_adjust) * batch_variance; + new_mean = + one_minus_factor * old_mean + exponential_avg_factor * batch_mean; + } +} + +} // namespace + +// Initializes FusedBatchNorm object from serialized parameters. +void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} + +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 5); + TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 6); + + TfLiteTensor* output = tflite::GetOutput(context, node, kOutputIndex); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteTensor* batch_mean = + tflite::GetOutput(context, node, kOutputBatchMeanIndex); + TF_LITE_ENSURE(context, batch_mean != nullptr); + TfLiteTensor* batch_var = + tflite::GetOutput(context, node, kOutputBatchVarIndex); + TF_LITE_ENSURE(context, batch_var != nullptr); + TfLiteTensor* saved_mean = + tflite::GetOutput(context, node, kOutputSavedMeanIndex); + TF_LITE_ENSURE(context, saved_mean != nullptr); + TfLiteTensor* saved_var = + tflite::GetOutput(context, node, kOutputSavedVarIndex); + TF_LITE_ENSURE(context, saved_var != nullptr); + TfLiteTensor* dummy_reserve_space = tflite::GetOutput(context, node, 5); + TF_LITE_ENSURE(context, dummy_reserve_space != nullptr); + + const TfLiteTensor* input = tflite::GetInput(context, node, kInputIndex); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteTensor* scale = tflite::GetInput(context, node, kInputScaleIndex); + TF_LITE_ENSURE(context, scale != nullptr); + const TfLiteTensor* offset = + tflite::GetInput(context, node, kInputOffsetIndex); + TF_LITE_ENSURE(context, offset != nullptr); + const TfLiteTensor* estimated_mean = + tflite::GetInput(context, node, kInputEstimatedMeanIndex); + TF_LITE_ENSURE(context, estimated_mean != nullptr); + const TfLiteTensor* estimated_var = + tflite::GetInput(context, node, kInputEstimatedVarIndex); + TF_LITE_ENSURE(context, estimated_var != nullptr); + + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(scale), 1); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(offset), 1); + TF_LITE_ENSURE_EQ(context, 
tflite::NumDimensions(estimated_mean), 1); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(estimated_var), 1); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, scale->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, offset->type, kTfLiteFloat32); + + int batches = input->dims->data[0]; + int height = input->dims->data[1]; + int width = input->dims->data[2]; + int depth = input->dims->data[3]; + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = height; + output_size->data[2] = width; + output_size->data[3] = depth; + if (context->ResizeTensor(context, output, output_size) != kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* batch_mean_size = TfLiteIntArrayCreate(1); + batch_mean_size->data[0] = depth; + if (context->ResizeTensor(context, batch_mean, batch_mean_size) != + kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* batch_var_size = TfLiteIntArrayCreate(1); + batch_var_size->data[0] = depth; + if (context->ResizeTensor(context, batch_var, batch_var_size) != kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* saved_mean_size = TfLiteIntArrayCreate(1); + saved_mean_size->data[0] = depth; + if (context->ResizeTensor(context, saved_mean, saved_mean_size) != + kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* saved_var_size = TfLiteIntArrayCreate(1); + saved_var_size->data[0] = depth; + if (context->ResizeTensor(context, saved_var, saved_var_size) != kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* dummy_reserve_size = TfLiteIntArrayCreate(1); + dummy_reserve_size->data[0] = 1; + if (context->ResizeTensor(context, dummy_reserve_space, dummy_reserve_size) != + kTfLiteOk) { + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = tflite::GetInput(context, node, kInputIndex); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteTensor* scale = tflite::GetInput(context, node, kInputScaleIndex); + TF_LITE_ENSURE(context, scale != nullptr); + const TfLiteTensor* offset = + tflite::GetInput(context, node, kInputOffsetIndex); + TF_LITE_ENSURE(context, offset != nullptr); + const TfLiteTensor* estimated_mean = + tflite::GetInput(context, node, kInputEstimatedMeanIndex); + TF_LITE_ENSURE(context, estimated_mean != nullptr); + const TfLiteTensor* estimated_var = + tflite::GetInput(context, node, kInputEstimatedVarIndex); + TF_LITE_ENSURE(context, estimated_var != nullptr); + + TfLiteTensor* output = tflite::GetOutput(context, node, kOutputIndex); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteTensor* batch_mean = + tflite::GetOutput(context, node, kOutputBatchMeanIndex); + TF_LITE_ENSURE(context, batch_mean != nullptr); + TfLiteTensor* batch_var = + tflite::GetOutput(context, node, kOutputBatchVarIndex); + TF_LITE_ENSURE(context, batch_var != nullptr); + TfLiteTensor* saved_mean = + tflite::GetOutput(context, node, kOutputSavedMeanIndex); + TF_LITE_ENSURE(context, saved_mean != nullptr); + TfLiteTensor* saved_var = + tflite::GetOutput(context, node, kOutputSavedVarIndex); + TF_LITE_ENSURE(context, saved_var != nullptr); + + FusedBarchNorm( + context, const_cast(input), + const_cast(scale), const_cast(offset), + const_cast(estimated_mean), + const_cast(estimated_var), output, batch_mean, batch_var, + saved_mean, saved_var, /*exponential_avg_factor=*/0.001f, + /*epsilon=*/0.001f); + + return 
kTfLiteOk; +} +} // namespace vision::batch_norm + +TfLiteRegistration* Register_FusedBatchNorm() { + static TfLiteRegistration r = { + vision::batch_norm::Initialize, vision::batch_norm::Free, + vision::batch_norm::Prepare, vision::batch_norm::Eval}; + return &r; +} + +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h new file mode 100644 index 000000000..98e16ff92 --- /dev/null +++ b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h @@ -0,0 +1,28 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_CUSTOM_OPS_FUSED_BATCH_NORM_H_ +#define MEDIAPIPE_TASKS_CC_VISION_CUSTOM_OPS_FUSED_BATCH_NORM_H_ + +#include "tensorflow/lite/core/c/common.h" + +namespace mediapipe::tflite_operations { + +// The FusedBatchNorm op resolver is CPU-friendly only. +TfLiteRegistration* Register_FusedBatchNorm(); + +} // namespace mediapipe::tflite_operations + +#endif // MEDIAPIPE_TASKS_CC_VISION_CUSTOM_OPS_FUSED_BATCH_NORM_H_ diff --git a/mediapipe/tasks/cc/vision/face_detector/BUILD b/mediapipe/tasks/cc/vision/face_detector/BUILD index 2cae8b79b..fbfd94628 100644 --- a/mediapipe/tasks/cc/vision/face_detector/BUILD +++ b/mediapipe/tasks/cc/vision/face_detector/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,7 @@ # limitations under the License. package(default_visibility = [ - # "//mediapipe/tasks:internal", - "//visibility:public", + "//mediapipe/tasks:internal", ]) licenses(["notice"]) @@ -63,6 +62,7 @@ cc_library( name = "face_detector", srcs = ["face_detector.cc"], hdrs = ["face_detector.h"], + visibility = ["//visibility:public"], deps = [ ":face_detector_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector.cc index 80e114bf8..a21b6edcf 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector.h b/mediapipe/tasks/cc/vision/face_detector/face_detector.h index ae485819d..545b7d6bc 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector.h +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. 
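Editor's note on the FusedBatchNorm kernel added above: it computes per-channel batch statistics over the n = N·H·W positions, normalizes with scale γ and offset β, and folds the batch moments into the running estimates with mixing factor α (exponential_avg_factor). In LaTeX, what the code does is

y = \gamma\,\frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta,
\qquad
\mu_{\mathrm{new}} = (1-\alpha)\,\mu_{\mathrm{old}} + \alpha\,\mu,
\qquad
\sigma^2_{\mathrm{new}} = (1-\alpha)\,\sigma^2_{\mathrm{old}} + \alpha\,\frac{n}{n-1}\,\sigma^2,

where n/(n-1) is the Bessel's correction named in the code comment; the α = 1 branch reduces to publishing the corrected batch moments directly, and Eval hard-codes α = ε = 0.001. To use the op, a resolver registration along these lines would be expected (the string name the converted models use for this custom op is an assumption here, not taken from the patch):

#include "mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h"
#include "tensorflow/lite/mutable_op_resolver.h"

void RegisterCustomOps(tflite::MutableOpResolver& resolver) {
  // "FusedBatchNormV3" is a hypothetical op name; use whatever name the
  // converted model's custom op actually carries.
  resolver.AddCustom("FusedBatchNormV3",
                     mediapipe::tflite_operations::Register_FusedBatchNorm());
}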
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc index bf62d2988..2e5f7e416 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -127,9 +127,9 @@ void ConfigureNonMaxSuppressionCalculator( void ConfigureDetectionsToRectsCalculator( mediapipe::DetectionsToRectsCalculatorOptions* options) { - // Left eye. + // Left eye from the observer’s point of view. options->set_rotation_vector_start_keypoint_index(0); - // Right ete. + // Right eye from the observer’s point of view. options->set_rotation_vector_end_keypoint_index(1); options->set_rotation_vector_target_angle_degrees(0); } @@ -242,7 +242,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph { auto matrix = preprocessing.Out(kMatrixTag); auto image_size = preprocessing.Out(kImageSizeTag); - // Face detection model inferece. + // Face detection model inference. auto& inference = AddInference( model_resources, subgraph_options.base_options().acceleration(), graph); preprocessed_tensors >> inference.In(kTensorsTag); diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc index 7be08ec2d..651ad722d 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -119,8 +120,9 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner( Detection GetExpectedFaceDetectionResult(absl::string_view file_name) { Detection detection; - CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), - &detection, Defaults())) + ABSL_CHECK_OK( + GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) << "Expected face detection result does not exist."; return detection; } diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_test.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_test.cc index b2db21e7e..fcb32a7d3 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector_test.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License.
#include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/port/file_helpers.h" @@ -57,8 +58,9 @@ constexpr float kKeypointErrorThreshold = 1e-2; FaceDetectorResult GetExpectedFaceDetectorResult(absl::string_view file_name) { mediapipe::Detection detection; - CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), - &detection, Defaults())) + ABSL_CHECK_OK( + GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) << "Expected face detection result does not exist."; return components::containers::ConvertToDetectionResult({detection}); } diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/BUILD b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD index ca9a6f8c4..bdfe65ee7 100644 --- a/mediapipe/tasks/cc/vision/face_detector/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_detector/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") package(default_visibility = [ "//mediapipe/tasks:internal", + "//mediapipe/tasks:users", ]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto index 3d0f1288d..0b082c650 100644 --- a/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc index 8c69a31fd..bf4006a94 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc +++ b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc index 74baff5d8..933ad75c7 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
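Aside: the fused_batch_norm kernel added earlier in this diff only becomes usable once its TfLiteRegistration is attached to an op resolver. The sketch below (not part of this diff) shows one plausible way to wire Register_FusedBatchNorm() into a stock TFLite interpreter; the custom-op name string "FusedBatchNormV3" is an assumption for illustration and must match the custom op name recorded in the .tflite model, which this diff does not show.

// Sketch only, not MediaPipe code: hooking the custom kernel into a TFLite
// interpreter. The op name "FusedBatchNormV3" is assumed for illustration.
#include <memory>

#include "mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

std::unique_ptr<tflite::Interpreter> BuildInterpreter(
    const tflite::FlatBufferModel& model) {
  // Start from the builtin resolver and layer the custom op on top.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom("FusedBatchNormV3",
                     mediapipe::tflite_operations::Register_FusedBatchNorm());
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;  // Null if the model references ops the resolver lacks.
}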
diff --git a/mediapipe/tasks/cc/vision/face_landmarker/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/BUILD index 1004069bd..04e33c141 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,6 +73,8 @@ cc_library( ":tensors_to_face_landmarks_graph", "//mediapipe/calculators/core:begin_loop_calculator", "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/core:get_vector_item_calculator", + "//mediapipe/calculators/core:get_vector_item_calculator_cc_proto", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/image:image_properties_calculator", @@ -86,6 +88,8 @@ cc_library( "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:landmark_letterbox_removal_calculator", "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/calculators/util:landmarks_smoothing_calculator", + "//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto", "//mediapipe/calculators/util:landmarks_to_detection_calculator", "//mediapipe/calculators/util:rect_transformation_calculator", "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", @@ -119,6 +123,7 @@ cc_library( name = "face_landmarker_result", srcs = ["face_landmarker_result.cc"], hdrs = ["face_landmarker_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", @@ -133,6 +138,7 @@ cc_library( name = "face_landmarker", srcs = ["face_landmarker.cc"], hdrs = ["face_landmarker.h"], + visibility = ["//visibility:public"], deps = [ ":face_landmarker_graph", ":face_landmarker_result", @@ -160,20 +166,6 @@ cc_library( ], ) -cc_library( - name = "face_landmarker_result_cc", - srcs = ["face_landmarker_result.cc"], - hdrs = ["face_landmarker_result.h"], - deps = [ - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:matrix_data_cc_proto", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers:landmark", - ], -) - cc_library( name = "face_landmarker_graph", srcs = ["face_landmarker_graph.cc"], @@ -194,8 +186,6 @@ cc_library( "//mediapipe/calculators/util:association_norm_rect_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", - "//mediapipe/calculators/util:landmarks_smoothing_calculator", - "//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", @@ -223,7 +213,13 @@ cc_library( "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:str_format", ], 
alwayslink = 1, ) + +cc_library( + name = "face_landmarks_connections", + hdrs = ["face_landmarks_connections.h"], +) diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph.cc index e1c743a9b..68b87bb6b 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc index 5c342a8e9..e83f16da5 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc index b40ea3324..88b0b5eb1 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h index 2c93fcba5..0b23fffd4 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc index c681cd2de..57b56f4bb 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/log/absl_log.h" #include "absl/strings/str_format.h" #include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" #include "mediapipe/calculators/core/concatenate_vector_calculator.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" #include "mediapipe/calculators/util/association_calculator.pb.h" #include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h" -#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/classification.pb.h" @@ -166,25 +166,12 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->mutable_base_options() ->mutable_acceleration() ->mutable_xnnpack(); - LOG(WARNING) << "Face blendshape model contains CPU only ops. Sets " - << "FaceBlendshapesGraph acceleration to Xnnpack."; + ABSL_LOG(WARNING) << "Sets FaceBlendshapesGraph acceleration to xnnpack " + << "by default."; } return absl::OkStatus(); } - -void ConfigureLandmarksSmoothingCalculator( - mediapipe::LandmarksSmoothingCalculatorOptions& options) { - // Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter when - // landmark is static. - options.mutable_one_euro_filter()->set_min_cutoff(0.05f); - // Beta 80.0 in combintation with min_cutoff 0.05 results into ~0.94 - // alpha in landmark EMA filter when landmark is moving fast. - options.mutable_one_euro_filter()->set_beta(80.0f); - // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity - // EMA filter. - options.mutable_one_euro_filter()->set_derivate_cutoff(1.0f); -} } // namespace // A "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph" performs face @@ -464,32 +451,17 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { auto image_size = image_properties.Out(kSizeTag); // Apply smoothing filter only on the single face landmarks, because - // landmakrs smoothing calculator doesn't support multiple landmarks yet. + // landmarks smoothing calculator doesn't support multiple landmarks yet. if (face_detector_options.num_faces() == 1) { - // Get the single face landmarks - auto& get_vector_item = - graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); - get_vector_item.GetOptions() - .set_item_index(0); - face_landmarks >> get_vector_item.In(kVectorTag); - auto single_face_landmarks = get_vector_item.Out(kItemTag); - - // Apply smoothing filter on face landmarks. - auto& landmarks_smoothing = graph.AddNode("LandmarksSmoothingCalculator"); - ConfigureLandmarksSmoothingCalculator( - landmarks_smoothing - .GetOptions()); - single_face_landmarks >> landmarks_smoothing.In(kNormLandmarksTag); - image_size >> landmarks_smoothing.In(kImageSizeTag); - auto smoothed_single_face_landmarks = - landmarks_smoothing.Out(kNormFilteredLandmarksTag); - - // Wrap the single face landmarks into a vector of landmarks. - auto& concatenate_vector = - graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator"); - smoothed_single_face_landmarks >> concatenate_vector.In(""); - face_landmarks = concatenate_vector.Out("") - .Cast>(); + face_landmarks_detector_graph + .GetOptions() + .set_smooth_landmarks(tasks_options.base_options().use_stream_mode()); + } else if (face_detector_options.num_faces() > 1 && + face_landmarks_detector_graph + .GetOptions() + .smooth_landmarks()) { + return absl::InvalidArgumentError( + "Currently face landmarks smoothing only support a single face."); } if (tasks_options.base_options().use_stream_mode()) { @@ -533,9 +505,10 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { // Back edge. 
face_rects_for_next_frame >> previous_loopback.In(kLoopTag); } else { - // While not in stream mode, the input images are not guaranteed to be in - // series, and we don't want to enable the tracking and rect associations - // between input images. Always use the face detector graph. + // While not in stream mode, the input images are not guaranteed to be + // in series, and we don't want to enable the tracking and rect + // associations between input images. Always use the face detector + // graph. image_in >> face_detector.In(kImageTag); if (norm_rect_in) { *norm_rect_in >> face_detector.In(kNormRectTag); @@ -571,7 +544,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { } // TODO: Replace PassThroughCalculator with a calculator that - // converts the pixel data to be stored on the target storage (CPU vs GPU). + // converts the pixel data to be stored on the target storage (CPU vs + // GPU). auto& pass_through = graph.AddNode("PassThroughCalculator"); image_in >> pass_through.In(""); diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph_test.cc index 063d6835e..a8ab0b9e9 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include #include "absl/flags/flag.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" @@ -31,6 +33,7 @@ limitations under the License. #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" @@ -95,6 +98,8 @@ constexpr float kLandmarksDiffMargin = 0.03; constexpr float kBlendshapesDiffMargin = 0.1; constexpr float kFaceGeometryDiffMargin = 0.02; +constexpr char kLandmarksSmoothingCalculator[] = "LandmarksSmoothingCalculator"; + template <typename ProtoT> ProtoT GetExpectedProto(absl::string_view filename) { ProtoT expected_proto; @@ -103,6 +108,13 @@ ProtoT GetExpectedProto(absl::string_view filename) { return expected_proto; } +struct VerifyExpandedConfigTestParams { + std::string test_name; + bool use_stream_mode; + int num_faces; + bool has_smoothing_calculator; +}; + // Struct holding the parameters for parameterized FaceLandmarkerGraphTest // class.
struct FaceLandmarkerGraphTestParams { @@ -165,6 +177,25 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateFaceLandmarkerGraphTaskRunner( absl::make_unique<core::MediaPipeBuiltinOpResolver>()); } +absl::StatusOr<CalculatorGraphConfig> ExpandConfig( + const std::string& config_str) { + auto config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str); + CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize(config)); + return graph.Config(); +} + +bool HasCalculatorInConfig(const std::string& calculator_name, + const CalculatorGraphConfig& config) { + for (const auto& node : config.node()) { + if (node.calculator() == calculator_name) { + return true; + } + } + return false; +} + // Helper function to construct NormalizeRect proto. NormalizedRect MakeNormRect(float x_center, float y_center, float width, float height, float rotation) { @@ -177,6 +208,71 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width, return face_rect; } +constexpr char kGraphConfigString[] = R"pb( + node { + calculator: "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph" + input_stream: "IMAGE:image_in" + output_stream: "NORM_LANDMARKS:face_landmarks" + options { + [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarkerGraphOptions + .ext] { + base_options { + model_asset { + file_name: "mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task" + } + use_stream_mode: $0 + } + face_detector_graph_options { num_faces: $1 } + } + } + } + input_stream: "IMAGE:image_in" +)pb"; + +class VerifyExpandedConfig + : public testing::TestWithParam<VerifyExpandedConfigTestParams> {}; + +TEST_P(VerifyExpandedConfig, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + auto actual_graph, + ExpandConfig(absl::Substitute( + kGraphConfigString, GetParam().use_stream_mode ? "true" : "false", + std::to_string(GetParam().num_faces)))); + if (GetParam().has_smoothing_calculator) { + EXPECT_TRUE( + HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph)); + } else { + EXPECT_FALSE( + HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph)); + } +} + +INSTANTIATE_TEST_SUITE_P( + VerifyExpandedConfig, VerifyExpandedConfig, + Values(VerifyExpandedConfigTestParams{ + /*test_name=*/"NonStreamOneFaceHasNoSmoothing", + /*use_stream_mode=*/false, + /*num_faces=*/1, + /*has_smoothing_calculator=*/false}, + VerifyExpandedConfigTestParams{ + /*test_name=*/"NonStreamTwoFaceHasNoSmoothing", + /*use_stream_mode=*/false, + /*num_faces=*/2, + /*has_smoothing_calculator=*/false}, + VerifyExpandedConfigTestParams{ + /*test_name=*/"StreamOneFaceHasSmoothing", + /*use_stream_mode=*/true, + /*num_faces=*/1, + /*has_smoothing_calculator=*/true}, + VerifyExpandedConfigTestParams{ + /*test_name=*/"StreamTwoFaceHasNoSmoothing", + /*use_stream_mode=*/true, + /*num_faces=*/2, + /*has_smoothing_calculator=*/false}), + [](const TestParamInfo<VerifyExpandedConfig::ParamType>& info) { + return info.param.test_name; + }); + class FaceLandmarkerGraphTest : public testing::TestWithParam<FaceLandmarkerGraphTestParams> {}; diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc index 53a171ed5..090781cc2 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
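The VerifyExpandedConfig cases above pin down the new smoothing behavior. As a standalone restatement (illustration only, not MediaPipe code), the gate that FaceLandmarkerGraph now applies reduces to the following, with the same outcomes the four parameterized test cases expect:

// Standalone sketch of the smoothing gate added in face_landmarker_graph.cc:
// smoothing is propagated to FaceLandmarksDetectorGraph only in stream mode
// with a single face; multiple faces with smoothing requested is an error.
#include <stdexcept>

bool ShouldEnableSmoothing(bool use_stream_mode, int num_faces,
                           bool smooth_requested) {
  if (num_faces == 1) {
    // Mirrors set_smooth_landmarks(use_stream_mode).
    return use_stream_mode;
  }
  if (num_faces > 1 && smooth_requested) {
    // Mirrors the InvalidArgumentError returned by the graph.
    throw std::invalid_argument(
        "Currently face landmarks smoothing only support a single face.");
  }
  return false;
}

// Matches the test table: (stream=false, 1 face) -> no smoothing,
// (false, 2) -> none, (true, 1) -> smoothing, (true, 2) -> none.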
diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h index bc097d6c3..f7f06cf03 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc index 4123a81f3..42a08e431 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc index 411693ecf..41b5ede6a 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,7 +67,7 @@ constexpr char kPortraitExpectedFaceLandmarksName[] = "portrait_expected_face_landmarks.pbtxt"; constexpr char kPortraitExpectedBlendshapesName[] = "portrait_expected_blendshapes.pbtxt"; -constexpr char kPortaitExpectedFaceGeomertyName[] = +constexpr char kPortraitExpectedFaceGeometryName[] = "portrait_expected_face_geometry.pbtxt"; constexpr float kLandmarksDiffMargin = 0.03; @@ -100,7 +100,7 @@ struct FaceLandmarkerTestParams { mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() { auto face_geometry = GetExpectedProto<face_geometry::proto::FaceGeometry>( - kPortaitExpectedFaceGeomertyName); + kPortraitExpectedFaceGeometryName); return face_geometry.pose_transform_matrix(); } diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h new file mode 100644 index 000000000..360083a7f --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h @@ -0,0 +1,651 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKS_CONNECTIONS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKS_CONNECTIONS_H_ + +#include <array> + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_landmarker { + +struct FaceLandmarksConnections { + static constexpr std::array<std::array<int, 2>, 40> kFaceLandmarksLips{ + {{61, 146}, {146, 91}, {91, 181}, {181, 84}, {84, 17}, {17, 314}, + {314, 405}, {405, 321}, {321, 375}, {375, 291}, {61, 185}, {185, 40}, + {40, 39}, {39, 37}, {37, 0}, {0, 267}, {267, 269}, {269, 270}, + {270, 409}, {409, 291}, {78, 95}, {95, 88}, {88, 178}, {178, 87}, + {87, 14}, {14, 317}, {317, 402}, {402, 318}, {318, 324}, {324, 308}, + {78, 191}, {191, 80}, {80, 81}, {81, 82}, {82, 13}, {13, 312}, + {312, 311}, {311, 310}, {310, 415}, {415, 308}}}; + + static constexpr std::array<std::array<int, 2>, 16> kFaceLandmarksLeftEye{ + {{263, 249}, + {249, 390}, + {390, 373}, + {373, 374}, + {374, 380}, + {380, 381}, + {381, 382}, + {382, 362}, + {263, 466}, + {466, 388}, + {388, 387}, + {387, 386}, + {386, 385}, + {385, 384}, + {384, 398}, + {398, 362}}}; + + static constexpr std::array<std::array<int, 2>, 8> kFaceLandmarksLeftEyeBrow{ + {{276, 283}, + {283, 282}, + {282, 295}, + {295, 285}, + {300, 293}, + {293, 334}, + {334, 296}, + {296, 336}}}; + + static constexpr std::array<std::array<int, 2>, 4> kFaceLandmarksLeftIris{ + {{474, 475}, {475, 476}, {476, 477}, {477, 474}}}; + + static constexpr std::array<std::array<int, 2>, 16> kFaceLandmarksRightEye{ + {{33, 7}, + {7, 163}, + {163, 144}, + {144, 145}, + {145, 153}, + {153, 154}, + {154, 155}, + {155, 133}, + {33, 246}, + {246, 161}, + {161, 160}, + {160, 159}, + {159, 158}, + {158, 157}, + {157, 173}, + {173, 133}}}; + + static constexpr std::array<std::array<int, 2>, 8> kFaceLandmarksRightEyeBrow{ + {{46, 53}, + {53, 52}, + {52, 65}, + {65, 55}, + {70, 63}, + {63, 105}, + {105, 66}, + {66, 107}}}; + + static constexpr std::array<std::array<int, 2>, 4> kFaceLandmarksRightIris{ + {{469, 470}, {470, 471}, {471, 472}, {472, 469}}}; + + static constexpr std::array<std::array<int, 2>, 36> kFaceLandmarksFaceOval{ + {{10, 338}, {338, 297}, {297, 332}, {332, 284}, {284, 251}, {251, 389}, + {389, 356}, {356, 454}, {454, 323}, {323, 361}, {361, 288}, {288, 397}, + {397, 365}, {365, 379}, {379, 378}, {378, 400}, {400, 377}, {377, 152}, + {152, 148}, {148, 176}, {176, 149}, {149, 150}, {150, 136}, {136, 172}, + {172, 58}, {58, 132}, {132, 93}, {93, 234}, {234, 127}, {127, 162}, + {162, 21}, {21, 54}, {54, 103}, {103, 67}, {67, 109}, {109, 10}}}; + + // Lips + Left Eye + Left Eye Brows + Right Eye + Right Eye Brows + Face Oval.
+ static constexpr std::array<std::array<int, 2>, 132> kFaceLandmarksConnectors{ + {{61, 146}, {146, 91}, {91, 181}, {181, 84}, {84, 17}, {17, 314}, + {314, 405}, {405, 321}, {321, 375}, {375, 291}, {61, 185}, {185, 40}, + {40, 39}, {39, 37}, {37, 0}, {0, 267}, {267, 269}, {269, 270}, + {270, 409}, {409, 291}, {78, 95}, {95, 88}, {88, 178}, {178, 87}, + {87, 14}, {14, 317}, {317, 402}, {402, 318}, {318, 324}, {324, 308}, + {78, 191}, {191, 80}, {80, 81}, {81, 82}, {82, 13}, {13, 312}, + {312, 311}, {311, 310}, {310, 415}, {415, 30}, {263, 249}, {249, 390}, + {390, 373}, {373, 374}, {374, 380}, {380, 381}, {381, 382}, {382, 362}, + {263, 466}, {466, 388}, {388, 387}, {387, 386}, {386, 385}, {385, 384}, + {384, 398}, {398, 362}, {276, 283}, {283, 282}, {282, 295}, {295, 285}, + {300, 293}, {293, 334}, {334, 296}, {296, 336}, {33, 7}, {7, 163}, + {163, 144}, {144, 145}, {145, 153}, {153, 154}, {154, 155}, {155, 133}, + {33, 246}, {246, 161}, {161, 160}, {160, 159}, {159, 158}, {158, 157}, + {157, 173}, {173, 13}, {46, 53}, {53, 52}, {52, 65}, {65, 55}, + {70, 63}, {63, 105}, {105, 66}, {66, 107}, {10, 338}, {338, 297}, + {297, 332}, {332, 284}, {284, 251}, {251, 389}, {389, 356}, {356, 454}, + {454, 323}, {323, 361}, {361, 288}, {288, 397}, {397, 365}, {365, 379}, + {379, 378}, {378, 400}, {400, 377}, {377, 152}, {152, 148}, {148, 176}, + {176, 149}, {149, 150}, {150, 136}, {136, 172}, {172, 58}, {58, 132}, + {132, 93}, {93, 234}, {234, 127}, {127, 162}, {162, 21}, {21, 54}, + {54, 103}, {103, 67}, {67, 109}, {109, 10}}}; + + static constexpr std::array<std::array<int, 2>, 2556> + kFaceLandmarksTesselation{ + {{127, 34}, {34, 139}, {139, 127}, {11, 0}, {0, 37}, + {37, 11}, {232, 231}, {231, 120}, {120, 232}, {72, 37}, + {37, 39}, {39, 72}, {128, 121}, {121, 47}, {47, 128}, + {232, 121}, {121, 128}, {128, 232}, {104, 69}, {69, 67}, + {67, 104}, {175, 171}, {171, 148}, {148, 175}, {118, 50}, + {50, 101}, {101, 118}, {73, 39}, {39, 40}, {40, 73}, + {9, 151}, {151, 108}, {108, 9}, {48, 115}, {115, 131}, + {131, 48}, {194, 204}, {204, 211}, {211, 194}, {74, 40}, + {40, 185}, {185, 74}, {80, 42}, {42, 183}, {183, 80}, + {40, 92}, {92, 186}, {186, 40}, {230, 229}, {229, 118}, + {118, 230}, {202, 212}, {212, 214}, {214, 202}, {83, 18}, + {18, 17}, {17, 83}, {76, 61}, {61, 146}, {146, 76}, + {160, 29}, {29, 30}, {30, 160}, {56, 157}, {157, 173}, + {173, 56}, {106, 204}, {204, 194}, {194, 106}, {135, 214}, + {214, 192}, {192, 135}, {203, 165}, {165, 98}, {98, 203}, + {21, 71}, {71, 68}, {68, 21}, {51, 45}, {45, 4}, + {4, 51}, {144, 24}, {24, 23}, {23, 144}, {77, 146}, + {146, 91}, {91, 77}, {205, 50}, {50, 187}, {187, 205}, + {201, 200}, {200, 18}, {18, 201}, {91, 106}, {106, 182}, + {182, 91}, {90, 91}, {91, 181}, {181, 90}, {85, 84}, + {84, 17}, {17, 85}, {206, 203}, {203, 36}, {36, 206}, + {148, 171}, {171, 140}, {140, 148}, {92, 40}, {40, 39}, + {39, 92}, {193, 189}, {189, 244}, {244, 193}, {159, 158}, + {158, 28}, {28, 159}, {247, 246}, {246, 161}, {161, 247}, + {236, 3}, {3, 196}, {196, 236}, {54, 68}, {68, 104}, + {104, 54}, {193, 168}, {168, 8}, {8, 193}, {117, 228}, + {228, 31}, {31, 117}, {189, 193}, {193, 55}, {55, 189}, + {98, 97}, {97, 99}, {99, 98}, {126, 47}, {47, 100}, + {100, 126}, {166, 79}, {79, 218}, {218, 166}, {155, 154}, + {154, 26}, {26, 155}, {209, 49}, {49, 131}, {131, 209}, + {135, 136}, {136, 150}, {150, 135}, {47, 126}, {126, 217}, + {217, 47}, {223, 52}, {52, 53}, {53, 223}, {45, 51}, + {51, 134}, {134, 45}, {211, 170}, {170, 140}, {140, 211}, + {67, 69}, {69, 108}, {108, 67}, {43, 106}, {106, 91}, + {91, 43},
{230, 119}, {119, 120}, {120, 230}, {226, 130}, + {130, 247}, {247, 226}, {63, 53}, {53, 52}, {52, 63}, + {238, 20}, {20, 242}, {242, 238}, {46, 70}, {70, 156}, + {156, 46}, {78, 62}, {62, 96}, {96, 78}, {46, 53}, + {53, 63}, {63, 46}, {143, 34}, {34, 227}, {227, 143}, + {123, 117}, {117, 111}, {111, 123}, {44, 125}, {125, 19}, + {19, 44}, {236, 134}, {134, 51}, {51, 236}, {216, 206}, + {206, 205}, {205, 216}, {154, 153}, {153, 22}, {22, 154}, + {39, 37}, {37, 167}, {167, 39}, {200, 201}, {201, 208}, + {208, 200}, {36, 142}, {142, 100}, {100, 36}, {57, 212}, + {212, 202}, {202, 57}, {20, 60}, {60, 99}, {99, 20}, + {28, 158}, {158, 157}, {157, 28}, {35, 226}, {226, 113}, + {113, 35}, {160, 159}, {159, 27}, {27, 160}, {204, 202}, + {202, 210}, {210, 204}, {113, 225}, {225, 46}, {46, 113}, + {43, 202}, {202, 204}, {204, 43}, {62, 76}, {76, 77}, + {77, 62}, {137, 123}, {123, 116}, {116, 137}, {41, 38}, + {38, 72}, {72, 41}, {203, 129}, {129, 142}, {142, 203}, + {64, 98}, {98, 240}, {240, 64}, {49, 102}, {102, 64}, + {64, 49}, {41, 73}, {73, 74}, {74, 41}, {212, 216}, + {216, 207}, {207, 212}, {42, 74}, {74, 184}, {184, 42}, + {169, 170}, {170, 211}, {211, 169}, {170, 149}, {149, 176}, + {176, 170}, {105, 66}, {66, 69}, {69, 105}, {122, 6}, + {6, 168}, {168, 122}, {123, 147}, {147, 187}, {187, 123}, + {96, 77}, {77, 90}, {90, 96}, {65, 55}, {55, 107}, + {107, 65}, {89, 90}, {90, 180}, {180, 89}, {101, 100}, + {100, 120}, {120, 101}, {63, 105}, {105, 104}, {104, 63}, + {93, 137}, {137, 227}, {227, 93}, {15, 86}, {86, 85}, + {85, 15}, {129, 102}, {102, 49}, {49, 129}, {14, 87}, + {87, 86}, {86, 14}, {55, 8}, {8, 9}, {9, 55}, + {100, 47}, {47, 121}, {121, 100}, {145, 23}, {23, 22}, + {22, 145}, {88, 89}, {89, 179}, {179, 88}, {6, 122}, + {122, 196}, {196, 6}, {88, 95}, {95, 96}, {96, 88}, + {138, 172}, {172, 136}, {136, 138}, {215, 58}, {58, 172}, + {172, 215}, {115, 48}, {48, 219}, {219, 115}, {42, 80}, + {80, 81}, {81, 42}, {195, 3}, {3, 51}, {51, 195}, + {43, 146}, {146, 61}, {61, 43}, {171, 175}, {175, 199}, + {199, 171}, {81, 82}, {82, 38}, {38, 81}, {53, 46}, + {46, 225}, {225, 53}, {144, 163}, {163, 110}, {110, 144}, + {52, 65}, {65, 66}, {66, 52}, {229, 228}, {228, 117}, + {117, 229}, {34, 127}, {127, 234}, {234, 34}, {107, 108}, + {108, 69}, {69, 107}, {109, 108}, {108, 151}, {151, 109}, + {48, 64}, {64, 235}, {235, 48}, {62, 78}, {78, 191}, + {191, 62}, {129, 209}, {209, 126}, {126, 129}, {111, 35}, + {35, 143}, {143, 111}, {117, 123}, {123, 50}, {50, 117}, + {222, 65}, {65, 52}, {52, 222}, {19, 125}, {125, 141}, + {141, 19}, {221, 55}, {55, 65}, {65, 221}, {3, 195}, + {195, 197}, {197, 3}, {25, 7}, {7, 33}, {33, 25}, + {220, 237}, {237, 44}, {44, 220}, {70, 71}, {71, 139}, + {139, 70}, {122, 193}, {193, 245}, {245, 122}, {247, 130}, + {130, 33}, {33, 247}, {71, 21}, {21, 162}, {162, 71}, + {170, 169}, {169, 150}, {150, 170}, {188, 174}, {174, 196}, + {196, 188}, {216, 186}, {186, 92}, {92, 216}, {2, 97}, + {97, 167}, {167, 2}, {141, 125}, {125, 241}, {241, 141}, + {164, 167}, {167, 37}, {37, 164}, {72, 38}, {38, 12}, + {12, 72}, {38, 82}, {82, 13}, {13, 38}, {63, 68}, + {68, 71}, {71, 63}, {226, 35}, {35, 111}, {111, 226}, + {101, 50}, {50, 205}, {205, 101}, {206, 92}, {92, 165}, + {165, 206}, {209, 198}, {198, 217}, {217, 209}, {165, 167}, + {167, 97}, {97, 165}, {220, 115}, {115, 218}, {218, 220}, + {133, 112}, {112, 243}, {243, 133}, {239, 238}, {238, 241}, + {241, 239}, {214, 135}, {135, 169}, {169, 214}, {190, 173}, + {173, 133}, {133, 190}, {171, 208}, {208, 32}, {32, 171}, + 
{125, 44}, {44, 237}, {237, 125}, {86, 87}, {87, 178}, + {178, 86}, {85, 86}, {86, 179}, {179, 85}, {84, 85}, + {85, 180}, {180, 84}, {83, 84}, {84, 181}, {181, 83}, + {201, 83}, {83, 182}, {182, 201}, {137, 93}, {93, 132}, + {132, 137}, {76, 62}, {62, 183}, {183, 76}, {61, 76}, + {76, 184}, {184, 61}, {57, 61}, {61, 185}, {185, 57}, + {212, 57}, {57, 186}, {186, 212}, {214, 207}, {207, 187}, + {187, 214}, {34, 143}, {143, 156}, {156, 34}, {79, 239}, + {239, 237}, {237, 79}, {123, 137}, {137, 177}, {177, 123}, + {44, 1}, {1, 4}, {4, 44}, {201, 194}, {194, 32}, + {32, 201}, {64, 102}, {102, 129}, {129, 64}, {213, 215}, + {215, 138}, {138, 213}, {59, 166}, {166, 219}, {219, 59}, + {242, 99}, {99, 97}, {97, 242}, {2, 94}, {94, 141}, + {141, 2}, {75, 59}, {59, 235}, {235, 75}, {24, 110}, + {110, 228}, {228, 24}, {25, 130}, {130, 226}, {226, 25}, + {23, 24}, {24, 229}, {229, 23}, {22, 23}, {23, 230}, + {230, 22}, {26, 22}, {22, 231}, {231, 26}, {112, 26}, + {26, 232}, {232, 112}, {189, 190}, {190, 243}, {243, 189}, + {221, 56}, {56, 190}, {190, 221}, {28, 56}, {56, 221}, + {221, 28}, {27, 28}, {28, 222}, {222, 27}, {29, 27}, + {27, 223}, {223, 29}, {30, 29}, {29, 224}, {224, 30}, + {247, 30}, {30, 225}, {225, 247}, {238, 79}, {79, 20}, + {20, 238}, {166, 59}, {59, 75}, {75, 166}, {60, 75}, + {75, 240}, {240, 60}, {147, 177}, {177, 215}, {215, 147}, + {20, 79}, {79, 166}, {166, 20}, {187, 147}, {147, 213}, + {213, 187}, {112, 233}, {233, 244}, {244, 112}, {233, 128}, + {128, 245}, {245, 233}, {128, 114}, {114, 188}, {188, 128}, + {114, 217}, {217, 174}, {174, 114}, {131, 115}, {115, 220}, + {220, 131}, {217, 198}, {198, 236}, {236, 217}, {198, 131}, + {131, 134}, {134, 198}, {177, 132}, {132, 58}, {58, 177}, + {143, 35}, {35, 124}, {124, 143}, {110, 163}, {163, 7}, + {7, 110}, {228, 110}, {110, 25}, {25, 228}, {356, 389}, + {389, 368}, {368, 356}, {11, 302}, {302, 267}, {267, 11}, + {452, 350}, {350, 349}, {349, 452}, {302, 303}, {303, 269}, + {269, 302}, {357, 343}, {343, 277}, {277, 357}, {452, 453}, + {453, 357}, {357, 452}, {333, 332}, {332, 297}, {297, 333}, + {175, 152}, {152, 377}, {377, 175}, {347, 348}, {348, 330}, + {330, 347}, {303, 304}, {304, 270}, {270, 303}, {9, 336}, + {336, 337}, {337, 9}, {278, 279}, {279, 360}, {360, 278}, + {418, 262}, {262, 431}, {431, 418}, {304, 408}, {408, 409}, + {409, 304}, {310, 415}, {415, 407}, {407, 310}, {270, 409}, + {409, 410}, {410, 270}, {450, 348}, {348, 347}, {347, 450}, + {422, 430}, {430, 434}, {434, 422}, {313, 314}, {314, 17}, + {17, 313}, {306, 307}, {307, 375}, {375, 306}, {387, 388}, + {388, 260}, {260, 387}, {286, 414}, {414, 398}, {398, 286}, + {335, 406}, {406, 418}, {418, 335}, {364, 367}, {367, 416}, + {416, 364}, {423, 358}, {358, 327}, {327, 423}, {251, 284}, + {284, 298}, {298, 251}, {281, 5}, {5, 4}, {4, 281}, + {373, 374}, {374, 253}, {253, 373}, {307, 320}, {320, 321}, + {321, 307}, {425, 427}, {427, 411}, {411, 425}, {421, 313}, + {313, 18}, {18, 421}, {321, 405}, {405, 406}, {406, 321}, + {320, 404}, {404, 405}, {405, 320}, {315, 16}, {16, 17}, + {17, 315}, {426, 425}, {425, 266}, {266, 426}, {377, 400}, + {400, 369}, {369, 377}, {322, 391}, {391, 269}, {269, 322}, + {417, 465}, {465, 464}, {464, 417}, {386, 257}, {257, 258}, + {258, 386}, {466, 260}, {260, 388}, {388, 466}, {456, 399}, + {399, 419}, {419, 456}, {284, 332}, {332, 333}, {333, 284}, + {417, 285}, {285, 8}, {8, 417}, {346, 340}, {340, 261}, + {261, 346}, {413, 441}, {441, 285}, {285, 413}, {327, 460}, + {460, 328}, {328, 327}, {355, 371}, {371, 329}, {329, 
355}, + {392, 439}, {439, 438}, {438, 392}, {382, 341}, {341, 256}, + {256, 382}, {429, 420}, {420, 360}, {360, 429}, {364, 394}, + {394, 379}, {379, 364}, {277, 343}, {343, 437}, {437, 277}, + {443, 444}, {444, 283}, {283, 443}, {275, 440}, {440, 363}, + {363, 275}, {431, 262}, {262, 369}, {369, 431}, {297, 338}, + {338, 337}, {337, 297}, {273, 375}, {375, 321}, {321, 273}, + {450, 451}, {451, 349}, {349, 450}, {446, 342}, {342, 467}, + {467, 446}, {293, 334}, {334, 282}, {282, 293}, {458, 461}, + {461, 462}, {462, 458}, {276, 353}, {353, 383}, {383, 276}, + {308, 324}, {324, 325}, {325, 308}, {276, 300}, {300, 293}, + {293, 276}, {372, 345}, {345, 447}, {447, 372}, {352, 345}, + {345, 340}, {340, 352}, {274, 1}, {1, 19}, {19, 274}, + {456, 248}, {248, 281}, {281, 456}, {436, 427}, {427, 425}, + {425, 436}, {381, 256}, {256, 252}, {252, 381}, {269, 391}, + {391, 393}, {393, 269}, {200, 199}, {199, 428}, {428, 200}, + {266, 330}, {330, 329}, {329, 266}, {287, 273}, {273, 422}, + {422, 287}, {250, 462}, {462, 328}, {328, 250}, {258, 286}, + {286, 384}, {384, 258}, {265, 353}, {353, 342}, {342, 265}, + {387, 259}, {259, 257}, {257, 387}, {424, 431}, {431, 430}, + {430, 424}, {342, 353}, {353, 276}, {276, 342}, {273, 335}, + {335, 424}, {424, 273}, {292, 325}, {325, 307}, {307, 292}, + {366, 447}, {447, 345}, {345, 366}, {271, 303}, {303, 302}, + {302, 271}, {423, 266}, {266, 371}, {371, 423}, {294, 455}, + {455, 460}, {460, 294}, {279, 278}, {278, 294}, {294, 279}, + {271, 272}, {272, 304}, {304, 271}, {432, 434}, {434, 427}, + {427, 432}, {272, 407}, {407, 408}, {408, 272}, {394, 430}, + {430, 431}, {431, 394}, {395, 369}, {369, 400}, {400, 395}, + {334, 333}, {333, 299}, {299, 334}, {351, 417}, {417, 168}, + {168, 351}, {352, 280}, {280, 411}, {411, 352}, {325, 319}, + {319, 320}, {320, 325}, {295, 296}, {296, 336}, {336, 295}, + {319, 403}, {403, 404}, {404, 319}, {330, 348}, {348, 349}, + {349, 330}, {293, 298}, {298, 333}, {333, 293}, {323, 454}, + {454, 447}, {447, 323}, {15, 16}, {16, 315}, {315, 15}, + {358, 429}, {429, 279}, {279, 358}, {14, 15}, {15, 316}, + {316, 14}, {285, 336}, {336, 9}, {9, 285}, {329, 349}, + {349, 350}, {350, 329}, {374, 380}, {380, 252}, {252, 374}, + {318, 402}, {402, 403}, {403, 318}, {6, 197}, {197, 419}, + {419, 6}, {318, 319}, {319, 325}, {325, 318}, {367, 364}, + {364, 365}, {365, 367}, {435, 367}, {367, 397}, {397, 435}, + {344, 438}, {438, 439}, {439, 344}, {272, 271}, {271, 311}, + {311, 272}, {195, 5}, {5, 281}, {281, 195}, {273, 287}, + {287, 291}, {291, 273}, {396, 428}, {428, 199}, {199, 396}, + {311, 271}, {271, 268}, {268, 311}, {283, 444}, {444, 445}, + {445, 283}, {373, 254}, {254, 339}, {339, 373}, {282, 334}, + {334, 296}, {296, 282}, {449, 347}, {347, 346}, {346, 449}, + {264, 447}, {447, 454}, {454, 264}, {336, 296}, {296, 299}, + {299, 336}, {338, 10}, {10, 151}, {151, 338}, {278, 439}, + {439, 455}, {455, 278}, {292, 407}, {407, 415}, {415, 292}, + {358, 371}, {371, 355}, {355, 358}, {340, 345}, {345, 372}, + {372, 340}, {346, 347}, {347, 280}, {280, 346}, {442, 443}, + {443, 282}, {282, 442}, {19, 94}, {94, 370}, {370, 19}, + {441, 442}, {442, 295}, {295, 441}, {248, 419}, {419, 197}, + {197, 248}, {263, 255}, {255, 359}, {359, 263}, {440, 275}, + {275, 274}, {274, 440}, {300, 383}, {383, 368}, {368, 300}, + {351, 412}, {412, 465}, {465, 351}, {263, 467}, {467, 466}, + {466, 263}, {301, 368}, {368, 389}, {389, 301}, {395, 378}, + {378, 379}, {379, 395}, {412, 351}, {351, 419}, {419, 412}, + {436, 426}, {426, 322}, {322, 436}, {2, 
164}, {164, 393}, + {393, 2}, {370, 462}, {462, 461}, {461, 370}, {164, 0}, + {0, 267}, {267, 164}, {302, 11}, {11, 12}, {12, 302}, + {268, 12}, {12, 13}, {13, 268}, {293, 300}, {300, 301}, + {301, 293}, {446, 261}, {261, 340}, {340, 446}, {330, 266}, + {266, 425}, {425, 330}, {426, 423}, {423, 391}, {391, 426}, + {429, 355}, {355, 437}, {437, 429}, {391, 327}, {327, 326}, + {326, 391}, {440, 457}, {457, 438}, {438, 440}, {341, 382}, + {382, 362}, {362, 341}, {459, 457}, {457, 461}, {461, 459}, + {434, 430}, {430, 394}, {394, 434}, {414, 463}, {463, 362}, + {362, 414}, {396, 369}, {369, 262}, {262, 396}, {354, 461}, + {461, 457}, {457, 354}, {316, 403}, {403, 402}, {402, 316}, + {315, 404}, {404, 403}, {403, 315}, {314, 405}, {405, 404}, + {404, 314}, {313, 406}, {406, 405}, {405, 313}, {421, 418}, + {418, 406}, {406, 421}, {366, 401}, {401, 361}, {361, 366}, + {306, 408}, {408, 407}, {407, 306}, {291, 409}, {409, 408}, + {408, 291}, {287, 410}, {410, 409}, {409, 287}, {432, 436}, + {436, 410}, {410, 432}, {434, 416}, {416, 411}, {411, 434}, + {264, 368}, {368, 383}, {383, 264}, {309, 438}, {438, 457}, + {457, 309}, {352, 376}, {376, 401}, {401, 352}, {274, 275}, + {275, 4}, {4, 274}, {421, 428}, {428, 262}, {262, 421}, + {294, 327}, {327, 358}, {358, 294}, {433, 416}, {416, 367}, + {367, 433}, {289, 455}, {455, 439}, {439, 289}, {462, 370}, + {370, 326}, {326, 462}, {2, 326}, {326, 370}, {370, 2}, + {305, 460}, {460, 455}, {455, 305}, {254, 449}, {449, 448}, + {448, 254}, {255, 261}, {261, 446}, {446, 255}, {253, 450}, + {450, 449}, {449, 253}, {252, 451}, {451, 450}, {450, 252}, + {256, 452}, {452, 451}, {451, 256}, {341, 453}, {453, 452}, + {452, 341}, {413, 464}, {464, 463}, {463, 413}, {441, 413}, + {413, 414}, {414, 441}, {258, 442}, {442, 441}, {441, 258}, + {257, 443}, {443, 442}, {442, 257}, {259, 444}, {444, 443}, + {443, 259}, {260, 445}, {445, 444}, {444, 260}, {467, 342}, + {342, 445}, {445, 467}, {459, 458}, {458, 250}, {250, 459}, + {289, 392}, {392, 290}, {290, 289}, {290, 328}, {328, 460}, + {460, 290}, {376, 433}, {433, 435}, {435, 376}, {250, 290}, + {290, 392}, {392, 250}, {411, 416}, {416, 433}, {433, 411}, + {341, 463}, {463, 464}, {464, 341}, {453, 464}, {464, 465}, + {465, 453}, {357, 465}, {465, 412}, {412, 357}, {343, 412}, + {412, 399}, {399, 343}, {360, 363}, {363, 440}, {440, 360}, + {437, 399}, {399, 456}, {456, 437}, {420, 456}, {456, 363}, + {363, 420}, {401, 435}, {435, 288}, {288, 401}, {372, 383}, + {383, 353}, {353, 372}, {339, 255}, {255, 249}, {249, 339}, + {448, 261}, {261, 255}, {255, 448}, {133, 243}, {243, 190}, + {190, 133}, {133, 155}, {155, 112}, {112, 133}, {33, 246}, + {246, 247}, {247, 33}, {33, 130}, {130, 25}, {25, 33}, + {398, 384}, {384, 286}, {286, 398}, {362, 398}, {398, 414}, + {414, 362}, {362, 463}, {463, 341}, {341, 362}, {263, 359}, + {359, 467}, {467, 263}, {263, 249}, {249, 255}, {255, 263}, + {466, 467}, {467, 260}, {260, 466}, {75, 60}, {60, 166}, + {166, 75}, {238, 239}, {239, 79}, {79, 238}, {162, 127}, + {127, 139}, {139, 162}, {72, 11}, {11, 37}, {37, 72}, + {121, 232}, {232, 120}, {120, 121}, {73, 72}, {72, 39}, + {39, 73}, {114, 128}, {128, 47}, {47, 114}, {233, 232}, + {232, 128}, {128, 233}, {103, 104}, {104, 67}, {67, 103}, + {152, 175}, {175, 148}, {148, 152}, {119, 118}, {118, 101}, + {101, 119}, {74, 73}, {73, 40}, {40, 74}, {107, 9}, + {9, 108}, {108, 107}, {49, 48}, {48, 131}, {131, 49}, + {32, 194}, {194, 211}, {211, 32}, {184, 74}, {74, 185}, + {185, 184}, {191, 80}, {80, 183}, {183, 191}, {185, 40}, + {40, 
186}, {186, 185}, {119, 230}, {230, 118}, {118, 119}, + {210, 202}, {202, 214}, {214, 210}, {84, 83}, {83, 17}, + {17, 84}, {77, 76}, {76, 146}, {146, 77}, {161, 160}, + {160, 30}, {30, 161}, {190, 56}, {56, 173}, {173, 190}, + {182, 106}, {106, 194}, {194, 182}, {138, 135}, {135, 192}, + {192, 138}, {129, 203}, {203, 98}, {98, 129}, {54, 21}, + {21, 68}, {68, 54}, {5, 51}, {51, 4}, {4, 5}, + {145, 144}, {144, 23}, {23, 145}, {90, 77}, {77, 91}, + {91, 90}, {207, 205}, {205, 187}, {187, 207}, {83, 201}, + {201, 18}, {18, 83}, {181, 91}, {91, 182}, {182, 181}, + {180, 90}, {90, 181}, {181, 180}, {16, 85}, {85, 17}, + {17, 16}, {205, 206}, {206, 36}, {36, 205}, {176, 148}, + {148, 140}, {140, 176}, {165, 92}, {92, 39}, {39, 165}, + {245, 193}, {193, 244}, {244, 245}, {27, 159}, {159, 28}, + {28, 27}, {30, 247}, {247, 161}, {161, 30}, {174, 236}, + {236, 196}, {196, 174}, {103, 54}, {54, 104}, {104, 103}, + {55, 193}, {193, 8}, {8, 55}, {111, 117}, {117, 31}, + {31, 111}, {221, 189}, {189, 55}, {55, 221}, {240, 98}, + {98, 99}, {99, 240}, {142, 126}, {126, 100}, {100, 142}, + {219, 166}, {166, 218}, {218, 219}, {112, 155}, {155, 26}, + {26, 112}, {198, 209}, {209, 131}, {131, 198}, {169, 135}, + {135, 150}, {150, 169}, {114, 47}, {47, 217}, {217, 114}, + {224, 223}, {223, 53}, {53, 224}, {220, 45}, {45, 134}, + {134, 220}, {32, 211}, {211, 140}, {140, 32}, {109, 67}, + {67, 108}, {108, 109}, {146, 43}, {43, 91}, {91, 146}, + {231, 230}, {230, 120}, {120, 231}, {113, 226}, {226, 247}, + {247, 113}, {105, 63}, {63, 52}, {52, 105}, {241, 238}, + {238, 242}, {242, 241}, {124, 46}, {46, 156}, {156, 124}, + {95, 78}, {78, 96}, {96, 95}, {70, 46}, {46, 63}, + {63, 70}, {116, 143}, {143, 227}, {227, 116}, {116, 123}, + {123, 111}, {111, 116}, {1, 44}, {44, 19}, {19, 1}, + {3, 236}, {236, 51}, {51, 3}, {207, 216}, {216, 205}, + {205, 207}, {26, 154}, {154, 22}, {22, 26}, {165, 39}, + {39, 167}, {167, 165}, {199, 200}, {200, 208}, {208, 199}, + {101, 36}, {36, 100}, {100, 101}, {43, 57}, {57, 202}, + {202, 43}, {242, 20}, {20, 99}, {99, 242}, {56, 28}, + {28, 157}, {157, 56}, {124, 35}, {35, 113}, {113, 124}, + {29, 160}, {160, 27}, {27, 29}, {211, 204}, {204, 210}, + {210, 211}, {124, 113}, {113, 46}, {46, 124}, {106, 43}, + {43, 204}, {204, 106}, {96, 62}, {62, 77}, {77, 96}, + {227, 137}, {137, 116}, {116, 227}, {73, 41}, {41, 72}, + {72, 73}, {36, 203}, {203, 142}, {142, 36}, {235, 64}, + {64, 240}, {240, 235}, {48, 49}, {49, 64}, {64, 48}, + {42, 41}, {41, 74}, {74, 42}, {214, 212}, {212, 207}, + {207, 214}, {183, 42}, {42, 184}, {184, 183}, {210, 169}, + {169, 211}, {211, 210}, {140, 170}, {170, 176}, {176, 140}, + {104, 105}, {105, 69}, {69, 104}, {193, 122}, {122, 168}, + {168, 193}, {50, 123}, {123, 187}, {187, 50}, {89, 96}, + {96, 90}, {90, 89}, {66, 65}, {65, 107}, {107, 66}, + {179, 89}, {89, 180}, {180, 179}, {119, 101}, {101, 120}, + {120, 119}, {68, 63}, {63, 104}, {104, 68}, {234, 93}, + {93, 227}, {227, 234}, {16, 15}, {15, 85}, {85, 16}, + {209, 129}, {129, 49}, {49, 209}, {15, 14}, {14, 86}, + {86, 15}, {107, 55}, {55, 9}, {9, 107}, {120, 100}, + {100, 121}, {121, 120}, {153, 145}, {145, 22}, {22, 153}, + {178, 88}, {88, 179}, {179, 178}, {197, 6}, {6, 196}, + {196, 197}, {89, 88}, {88, 96}, {96, 89}, {135, 138}, + {138, 136}, {136, 135}, {138, 215}, {215, 172}, {172, 138}, + {218, 115}, {115, 219}, {219, 218}, {41, 42}, {42, 81}, + {81, 41}, {5, 195}, {195, 51}, {51, 5}, {57, 43}, + {43, 61}, {61, 57}, {208, 171}, {171, 199}, {199, 208}, + {41, 81}, {81, 38}, {38, 41}, {224, 53}, 
{53, 225}, + {225, 224}, {24, 144}, {144, 110}, {110, 24}, {105, 52}, + {52, 66}, {66, 105}, {118, 229}, {229, 117}, {117, 118}, + {227, 34}, {34, 234}, {234, 227}, {66, 107}, {107, 69}, + {69, 66}, {10, 109}, {109, 151}, {151, 10}, {219, 48}, + {48, 235}, {235, 219}, {183, 62}, {62, 191}, {191, 183}, + {142, 129}, {129, 126}, {126, 142}, {116, 111}, {111, 143}, + {143, 116}, {118, 117}, {117, 50}, {50, 118}, {223, 222}, + {222, 52}, {52, 223}, {94, 19}, {19, 141}, {141, 94}, + {222, 221}, {221, 65}, {65, 222}, {196, 3}, {3, 197}, + {197, 196}, {45, 220}, {220, 44}, {44, 45}, {156, 70}, + {70, 139}, {139, 156}, {188, 122}, {122, 245}, {245, 188}, + {139, 71}, {71, 162}, {162, 139}, {149, 170}, {170, 150}, + {150, 149}, {122, 188}, {188, 196}, {196, 122}, {206, 216}, + {216, 92}, {92, 206}, {164, 2}, {2, 167}, {167, 164}, + {242, 141}, {141, 241}, {241, 242}, {0, 164}, {164, 37}, + {37, 0}, {11, 72}, {72, 12}, {12, 11}, {12, 38}, + {38, 13}, {13, 12}, {70, 63}, {63, 71}, {71, 70}, + {31, 226}, {226, 111}, {111, 31}, {36, 101}, {101, 205}, + {205, 36}, {203, 206}, {206, 165}, {165, 203}, {126, 209}, + {209, 217}, {217, 126}, {98, 165}, {165, 97}, {97, 98}, + {237, 220}, {220, 218}, {218, 237}, {237, 239}, {239, 241}, + {241, 237}, {210, 214}, {214, 169}, {169, 210}, {140, 171}, + {171, 32}, {32, 140}, {241, 125}, {125, 237}, {237, 241}, + {179, 86}, {86, 178}, {178, 179}, {180, 85}, {85, 179}, + {179, 180}, {181, 84}, {84, 180}, {180, 181}, {182, 83}, + {83, 181}, {181, 182}, {194, 201}, {201, 182}, {182, 194}, + {177, 137}, {137, 132}, {132, 177}, {184, 76}, {76, 183}, + {183, 184}, {185, 61}, {61, 184}, {184, 185}, {186, 57}, + {57, 185}, {185, 186}, {216, 212}, {212, 186}, {186, 216}, + {192, 214}, {214, 187}, {187, 192}, {139, 34}, {34, 156}, + {156, 139}, {218, 79}, {79, 237}, {237, 218}, {147, 123}, + {123, 177}, {177, 147}, {45, 44}, {44, 4}, {4, 45}, + {208, 201}, {201, 32}, {32, 208}, {98, 64}, {64, 129}, + {129, 98}, {192, 213}, {213, 138}, {138, 192}, {235, 59}, + {59, 219}, {219, 235}, {141, 242}, {242, 97}, {97, 141}, + {97, 2}, {2, 141}, {141, 97}, {240, 75}, {75, 235}, + {235, 240}, {229, 24}, {24, 228}, {228, 229}, {31, 25}, + {25, 226}, {226, 31}, {230, 23}, {23, 229}, {229, 230}, + {231, 22}, {22, 230}, {230, 231}, {232, 26}, {26, 231}, + {231, 232}, {233, 112}, {112, 232}, {232, 233}, {244, 189}, + {189, 243}, {243, 244}, {189, 221}, {221, 190}, {190, 189}, + {222, 28}, {28, 221}, {221, 222}, {223, 27}, {27, 222}, + {222, 223}, {224, 29}, {29, 223}, {223, 224}, {225, 30}, + {30, 224}, {224, 225}, {113, 247}, {247, 225}, {225, 113}, + {99, 60}, {60, 240}, {240, 99}, {213, 147}, {147, 215}, + {215, 213}, {60, 20}, {20, 166}, {166, 60}, {192, 187}, + {187, 213}, {213, 192}, {243, 112}, {112, 244}, {244, 243}, + {244, 233}, {233, 245}, {245, 244}, {245, 128}, {128, 188}, + {188, 245}, {188, 114}, {114, 174}, {174, 188}, {134, 131}, + {131, 220}, {220, 134}, {174, 217}, {217, 236}, {236, 174}, + {236, 198}, {198, 134}, {134, 236}, {215, 177}, {177, 58}, + {58, 215}, {156, 143}, {143, 124}, {124, 156}, {25, 110}, + {110, 7}, {7, 25}, {31, 228}, {228, 25}, {25, 31}, + {264, 356}, {356, 368}, {368, 264}, {0, 11}, {11, 267}, + {267, 0}, {451, 452}, {452, 349}, {349, 451}, {267, 302}, + {302, 269}, {269, 267}, {350, 357}, {357, 277}, {277, 350}, + {350, 452}, {452, 357}, {357, 350}, {299, 333}, {333, 297}, + {297, 299}, {396, 175}, {175, 377}, {377, 396}, {280, 347}, + {347, 330}, {330, 280}, {269, 303}, {303, 270}, {270, 269}, + {151, 9}, {9, 337}, {337, 151}, {344, 278}, {278, 
360}, + {360, 344}, {424, 418}, {418, 431}, {431, 424}, {270, 304}, + {304, 409}, {409, 270}, {272, 310}, {310, 407}, {407, 272}, + {322, 270}, {270, 410}, {410, 322}, {449, 450}, {450, 347}, + {347, 449}, {432, 422}, {422, 434}, {434, 432}, {18, 313}, + {313, 17}, {17, 18}, {291, 306}, {306, 375}, {375, 291}, + {259, 387}, {387, 260}, {260, 259}, {424, 335}, {335, 418}, + {418, 424}, {434, 364}, {364, 416}, {416, 434}, {391, 423}, + {423, 327}, {327, 391}, {301, 251}, {251, 298}, {298, 301}, + {275, 281}, {281, 4}, {4, 275}, {254, 373}, {373, 253}, + {253, 254}, {375, 307}, {307, 321}, {321, 375}, {280, 425}, + {425, 411}, {411, 280}, {200, 421}, {421, 18}, {18, 200}, + {335, 321}, {321, 406}, {406, 335}, {321, 320}, {320, 405}, + {405, 321}, {314, 315}, {315, 17}, {17, 314}, {423, 426}, + {426, 266}, {266, 423}, {396, 377}, {377, 369}, {369, 396}, + {270, 322}, {322, 269}, {269, 270}, {413, 417}, {417, 464}, + {464, 413}, {385, 386}, {386, 258}, {258, 385}, {248, 456}, + {456, 419}, {419, 248}, {298, 284}, {284, 333}, {333, 298}, + {168, 417}, {417, 8}, {8, 168}, {448, 346}, {346, 261}, + {261, 448}, {417, 413}, {413, 285}, {285, 417}, {326, 327}, + {327, 328}, {328, 326}, {277, 355}, {355, 329}, {329, 277}, + {309, 392}, {392, 438}, {438, 309}, {381, 382}, {382, 256}, + {256, 381}, {279, 429}, {429, 360}, {360, 279}, {365, 364}, + {364, 379}, {379, 365}, {355, 277}, {277, 437}, {437, 355}, + {282, 443}, {443, 283}, {283, 282}, {281, 275}, {275, 363}, + {363, 281}, {395, 431}, {431, 369}, {369, 395}, {299, 297}, + {297, 337}, {337, 299}, {335, 273}, {273, 321}, {321, 335}, + {348, 450}, {450, 349}, {349, 348}, {359, 446}, {446, 467}, + {467, 359}, {283, 293}, {293, 282}, {282, 283}, {250, 458}, + {458, 462}, {462, 250}, {300, 276}, {276, 383}, {383, 300}, + {292, 308}, {308, 325}, {325, 292}, {283, 276}, {276, 293}, + {293, 283}, {264, 372}, {372, 447}, {447, 264}, {346, 352}, + {352, 340}, {340, 346}, {354, 274}, {274, 19}, {19, 354}, + {363, 456}, {456, 281}, {281, 363}, {426, 436}, {436, 425}, + {425, 426}, {380, 381}, {381, 252}, {252, 380}, {267, 269}, + {269, 393}, {393, 267}, {421, 200}, {200, 428}, {428, 421}, + {371, 266}, {266, 329}, {329, 371}, {432, 287}, {287, 422}, + {422, 432}, {290, 250}, {250, 328}, {328, 290}, {385, 258}, + {258, 384}, {384, 385}, {446, 265}, {265, 342}, {342, 446}, + {386, 387}, {387, 257}, {257, 386}, {422, 424}, {424, 430}, + {430, 422}, {445, 342}, {342, 276}, {276, 445}, {422, 273}, + {273, 424}, {424, 422}, {306, 292}, {292, 307}, {307, 306}, + {352, 366}, {366, 345}, {345, 352}, {268, 271}, {271, 302}, + {302, 268}, {358, 423}, {423, 371}, {371, 358}, {327, 294}, + {294, 460}, {460, 327}, {331, 279}, {279, 294}, {294, 331}, + {303, 271}, {271, 304}, {304, 303}, {436, 432}, {432, 427}, + {427, 436}, {304, 272}, {272, 408}, {408, 304}, {395, 394}, + {394, 431}, {431, 395}, {378, 395}, {395, 400}, {400, 378}, + {296, 334}, {334, 299}, {299, 296}, {6, 351}, {351, 168}, + {168, 6}, {376, 352}, {352, 411}, {411, 376}, {307, 325}, + {325, 320}, {320, 307}, {285, 295}, {295, 336}, {336, 285}, + {320, 319}, {319, 404}, {404, 320}, {329, 330}, {330, 349}, + {349, 329}, {334, 293}, {293, 333}, {333, 334}, {366, 323}, + {323, 447}, {447, 366}, {316, 15}, {15, 315}, {315, 316}, + {331, 358}, {358, 279}, {279, 331}, {317, 14}, {14, 316}, + {316, 317}, {8, 285}, {285, 9}, {9, 8}, {277, 329}, + {329, 350}, {350, 277}, {253, 374}, {374, 252}, {252, 253}, + {319, 318}, {318, 403}, {403, 319}, {351, 6}, {6, 419}, + {419, 351}, {324, 318}, {318, 325}, {325, 324}, 
{397, 367}, + {367, 365}, {365, 397}, {288, 435}, {435, 397}, {397, 288}, + {278, 344}, {344, 439}, {439, 278}, {310, 272}, {272, 311}, + {311, 310}, {248, 195}, {195, 281}, {281, 248}, {375, 273}, + {273, 291}, {291, 375}, {175, 396}, {396, 199}, {199, 175}, + {312, 311}, {311, 268}, {268, 312}, {276, 283}, {283, 445}, + {445, 276}, {390, 373}, {373, 339}, {339, 390}, {295, 282}, + {282, 296}, {296, 295}, {448, 449}, {449, 346}, {346, 448}, + {356, 264}, {264, 454}, {454, 356}, {337, 336}, {336, 299}, + {299, 337}, {337, 338}, {338, 151}, {151, 337}, {294, 278}, + {278, 455}, {455, 294}, {308, 292}, {292, 415}, {415, 308}, + {429, 358}, {358, 355}, {355, 429}, {265, 340}, {340, 372}, + {372, 265}, {352, 346}, {346, 280}, {280, 352}, {295, 442}, + {442, 282}, {282, 295}, {354, 19}, {19, 370}, {370, 354}, + {285, 441}, {441, 295}, {295, 285}, {195, 248}, {248, 197}, + {197, 195}, {457, 440}, {440, 274}, {274, 457}, {301, 300}, + {300, 368}, {368, 301}, {417, 351}, {351, 465}, {465, 417}, + {251, 301}, {301, 389}, {389, 251}, {394, 395}, {395, 379}, + {379, 394}, {399, 412}, {412, 419}, {419, 399}, {410, 436}, + {436, 322}, {322, 410}, {326, 2}, {2, 393}, {393, 326}, + {354, 370}, {370, 461}, {461, 354}, {393, 164}, {164, 267}, + {267, 393}, {268, 302}, {302, 12}, {12, 268}, {312, 268}, + {268, 13}, {13, 312}, {298, 293}, {293, 301}, {301, 298}, + {265, 446}, {446, 340}, {340, 265}, {280, 330}, {330, 425}, + {425, 280}, {322, 426}, {426, 391}, {391, 322}, {420, 429}, + {429, 437}, {437, 420}, {393, 391}, {391, 326}, {326, 393}, + {344, 440}, {440, 438}, {438, 344}, {458, 459}, {459, 461}, + {461, 458}, {364, 434}, {434, 394}, {394, 364}, {428, 396}, + {396, 262}, {262, 428}, {274, 354}, {354, 457}, {457, 274}, + {317, 316}, {316, 402}, {402, 317}, {316, 315}, {315, 403}, + {403, 316}, {315, 314}, {314, 404}, {404, 315}, {314, 313}, + {313, 405}, {405, 314}, {313, 421}, {421, 406}, {406, 313}, + {323, 366}, {366, 361}, {361, 323}, {292, 306}, {306, 407}, + {407, 292}, {306, 291}, {291, 408}, {408, 306}, {291, 287}, + {287, 409}, {409, 291}, {287, 432}, {432, 410}, {410, 287}, + {427, 434}, {434, 411}, {411, 427}, {372, 264}, {264, 383}, + {383, 372}, {459, 309}, {309, 457}, {457, 459}, {366, 352}, + {352, 401}, {401, 366}, {1, 274}, {274, 4}, {4, 1}, + {418, 421}, {421, 262}, {262, 418}, {331, 294}, {294, 358}, + {358, 331}, {435, 433}, {433, 367}, {367, 435}, {392, 289}, + {289, 439}, {439, 392}, {328, 462}, {462, 326}, {326, 328}, + {94, 2}, {2, 370}, {370, 94}, {289, 305}, {305, 455}, + {455, 289}, {339, 254}, {254, 448}, {448, 339}, {359, 255}, + {255, 446}, {446, 359}, {254, 253}, {253, 449}, {449, 254}, + {253, 252}, {252, 450}, {450, 253}, {252, 256}, {256, 451}, + {451, 252}, {256, 341}, {341, 452}, {452, 256}, {414, 413}, + {413, 463}, {463, 414}, {286, 441}, {441, 414}, {414, 286}, + {286, 258}, {258, 441}, {441, 286}, {258, 257}, {257, 442}, + {442, 258}, {257, 259}, {259, 443}, {443, 257}, {259, 260}, + {260, 444}, {444, 259}, {260, 467}, {467, 445}, {445, 260}, + {309, 459}, {459, 250}, {250, 309}, {305, 289}, {289, 290}, + {290, 305}, {305, 290}, {290, 460}, {460, 305}, {401, 376}, + {376, 435}, {435, 401}, {309, 250}, {250, 392}, {392, 309}, + {376, 411}, {411, 433}, {433, 376}, {453, 341}, {341, 464}, + {464, 453}, {357, 453}, {453, 465}, {465, 357}, {343, 357}, + {357, 412}, {412, 343}, {437, 343}, {343, 399}, {399, 437}, + {344, 360}, {360, 440}, {440, 344}, {420, 437}, {437, 456}, + {456, 420}, {360, 420}, {420, 363}, {363, 360}, {361, 401}, + {401, 288}, {288, 361}, {265, 
372}, {372, 353}, {353, 265}, + {390, 339}, {339, 249}, {249, 390}, {339, 448}, {448, 255}, + {255, 339}}}; +}; + +} // namespace face_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKS_CONNECTIONS_H_ diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc index 7ce0fcaa2..20d241b97 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,10 +19,13 @@ limitations under the License. #include #include +#include "mediapipe/calculators/core/get_vector_item_calculator.h" +#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" #include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_floats_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h" #include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" #include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" #include "mediapipe/calculators/util/thresholding_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" @@ -79,6 +82,9 @@ constexpr char kBatchEndTag[] = "BATCH_END"; constexpr char kItemTag[] = "ITEM"; constexpr char kDetectionTag[] = "DETECTION"; constexpr char kBlendshapesTag[] = "BLENDSHAPES"; +constexpr char kNormFilteredLandmarksTag[] = "NORM_FILTERED_LANDMARKS"; +constexpr char kSizeTag[] = "SIZE"; +constexpr char kVectorTag[] = "VECTOR"; // a landmarks tensor and a scores tensor constexpr int kFaceLandmarksOutputTensorsNum = 2; @@ -88,7 +94,6 @@ struct SingleFaceLandmarksOutputs { Stream rect_next_frame; Stream presence; Stream presence_score; - std::optional> face_blendshapes; }; struct MultiFaceLandmarksOutputs { @@ -148,6 +153,19 @@ void ConfigureFaceRectTransformationCalculator( options->set_square_long(true); } +void ConfigureLandmarksSmoothingCalculator( + mediapipe::LandmarksSmoothingCalculatorOptions& options) { + // Min cutoff 0.05 results in ~0.01 alpha in the landmark EMA filter when + // the landmark is static. + options.mutable_one_euro_filter()->set_min_cutoff(0.05f); + // Beta 80.0 in combination with min_cutoff 0.05 results in ~0.94 + // alpha in the landmark EMA filter when the landmark is moving fast. + options.mutable_one_euro_filter()->set_beta(80.0f); + // Derivative cutoff 1.0 results in ~0.17 alpha in the landmark velocity + // EMA filter. + options.mutable_one_euro_filter()->set_derivate_cutoff(1.0f); +} + } // namespace // A "mediapipe.tasks.vision.face_landmarker.SingleFaceLandmarksDetectorGraph" @@ -171,62 +189,6 @@ void ConfigureFaceRectTransformationCalculator( // Boolean value indicates whether the face is present. // PRESENCE_SCORE - float // Float value indicates the probability that the face is present. -// BLENDSHAPES - ClassificationList @optional -// Blendshape classification, available when face_blendshapes_graph_options -// is set.
-// All 52 blendshape coefficients: -// 0 - _neutral (ignore it) -// 1 - browDownLeft -// 2 - browDownRight -// 3 - browInnerUp -// 4 - browOuterUpLeft -// 5 - browOuterUpRight -// 6 - cheekPuff -// 7 - cheekSquintLeft -// 8 - cheekSquintRight -// 9 - eyeBlinkLeft -// 10 - eyeBlinkRight -// 11 - eyeLookDownLeft -// 12 - eyeLookDownRight -// 13 - eyeLookInLeft -// 14 - eyeLookInRight -// 15 - eyeLookOutLeft -// 16 - eyeLookOutRight -// 17 - eyeLookUpLeft -// 18 - eyeLookUpRight -// 19 - eyeSquintLeft -// 20 - eyeSquintRight -// 21 - eyeWideLeft -// 22 - eyeWideRight -// 23 - jawForward -// 24 - jawLeft -// 25 - jawOpen -// 26 - jawRight -// 27 - mouthClose -// 28 - mouthDimpleLeft -// 29 - mouthDimpleRight -// 30 - mouthFrownLeft -// 31 - mouthFrownRight -// 32 - mouthFunnel -// 33 - mouthLeft -// 34 - mouthLowerDownLeft -// 35 - mouthLowerDownRight -// 36 - mouthPressLeft -// 37 - mouthPressRight -// 38 - mouthPucker -// 39 - mouthRight -// 40 - mouthRollLower -// 41 - mouthRollUpper -// 42 - mouthShrugLower -// 43 - mouthShrugUpper -// 44 - mouthSmileLeft -// 45 - mouthSmileRight -// 46 - mouthStretchLeft -// 47 - mouthStretchRight -// 48 - mouthUpperUpLeft -// 49 - mouthUpperUpRight -// 50 - noseSneerLeft -// 51 - noseSneerRight // // Example: // node { @@ -238,7 +200,6 @@ void ConfigureFaceRectTransformationCalculator( // output_stream: "FACE_RECT_NEXT_FRAME:face_rect_next_frame" // output_stream: "PRESENCE:presence" // output_stream: "PRESENCE_SCORE:presence_score" -// output_stream: "BLENDSHAPES:blendshapes" // options { // [mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarksDetectorGraphOptions.ext] // { @@ -278,10 +239,6 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { graph.Out(kFaceRectNextFrameTag).Cast(); outs.presence >> graph.Out(kPresenceTag).Cast(); outs.presence_score >> graph.Out(kPresenceScoreTag).Cast(); - if (outs.face_blendshapes) { - outs.face_blendshapes.value() >> - graph.Out(kBlendshapesTag).Cast(); - } return graph.GetConfig(); } @@ -378,7 +335,7 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { auto& landmark_projection = graph.AddNode("LandmarkProjectionCalculator"); landmarks_letterbox_removed >> landmark_projection.In(kNormLandmarksTag); face_rect >> landmark_projection.In(kNormRectTag); - auto projected_landmarks = AllowIf( + Stream projected_landmarks = AllowIf( landmark_projection[Output(kNormLandmarksTag)], presence, graph); @@ -409,25 +366,11 @@ class SingleFaceLandmarksDetectorGraph : public core::ModelTaskGraph { AllowIf(face_rect_transformation.Out("").Cast(), presence, graph); - std::optional> face_blendshapes; - if (subgraph_options.has_face_blendshapes_graph_options()) { - auto& face_blendshapes_graph = graph.AddNode( - "mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph"); - face_blendshapes_graph.GetOptions() - .Swap(subgraph_options.mutable_face_blendshapes_graph_options()); - projected_landmarks >> face_blendshapes_graph.In(kLandmarksTag); - image_size >> face_blendshapes_graph.In(kImageSizeTag); - face_blendshapes = - std::make_optional(face_blendshapes_graph.Out(kBlendshapesTag) - .Cast()); - } - return {{ /* landmarks= */ projected_landmarks, /* rect_next_frame= */ face_rect_next_frame, /* presence= */ presence, /* presence_score= */ presence_score, - /* face_blendshapes= */ face_blendshapes, }}; } }; @@ -465,6 +408,59 @@ REGISTER_MEDIAPIPE_GRAPH( // BLENDSHAPES - std::vector @optional // Vector of face blendshape classification, available when // face_blendshapes_graph_options 
is set. +// All 52 blendshape coefficients: +// 0 - _neutral (ignore it) +// 1 - browDownLeft +// 2 - browDownRight +// 3 - browInnerUp +// 4 - browOuterUpLeft +// 5 - browOuterUpRight +// 6 - cheekPuff +// 7 - cheekSquintLeft +// 8 - cheekSquintRight +// 9 - eyeBlinkLeft +// 10 - eyeBlinkRight +// 11 - eyeLookDownLeft +// 12 - eyeLookDownRight +// 13 - eyeLookInLeft +// 14 - eyeLookInRight +// 15 - eyeLookOutLeft +// 16 - eyeLookOutRight +// 17 - eyeLookUpLeft +// 18 - eyeLookUpRight +// 19 - eyeSquintLeft +// 20 - eyeSquintRight +// 21 - eyeWideLeft +// 22 - eyeWideRight +// 23 - jawForward +// 24 - jawLeft +// 25 - jawOpen +// 26 - jawRight +// 27 - mouthClose +// 28 - mouthDimpleLeft +// 29 - mouthDimpleRight +// 30 - mouthFrownLeft +// 31 - mouthFrownRight +// 32 - mouthFunnel +// 33 - mouthLeft +// 34 - mouthLowerDownLeft +// 35 - mouthLowerDownRight +// 36 - mouthPressLeft +// 37 - mouthPressRight +// 38 - mouthPucker +// 39 - mouthRight +// 40 - mouthRollLower +// 41 - mouthRollUpper +// 42 - mouthShrugLower +// 43 - mouthShrugUpper +// 44 - mouthSmileLeft +// 45 - mouthSmileRight +// 46 - mouthStretchLeft +// 47 - mouthStretchRight +// 48 - mouthUpperUpLeft +// 49 - mouthUpperUpRight +// 50 - noseSneerLeft +// 51 - noseSneerRight // // Example: // node { @@ -566,8 +562,9 @@ class MultiFaceLandmarksDetectorGraph : public core::ModelTaskGraph { graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator"); batch_end >> end_loop_landmarks.In(kBatchEndTag); landmarks >> end_loop_landmarks.In(kItemTag); - auto landmark_lists = end_loop_landmarks.Out(kIterableTag) - .Cast>(); + Stream> landmark_lists = + end_loop_landmarks.Out(kIterableTag) + .Cast>(); auto& end_loop_rects_next_frame = graph.AddNode("EndLoopNormalizedRectCalculator"); @@ -576,16 +573,78 @@ auto face_rects_next_frame = end_loop_rects_next_frame.Out(kIterableTag) .Cast>(); + // Apply the smoothing filter only to the single-face landmarks, because + // the landmarks smoothing calculator doesn't support multiple landmarks yet. + // Note that the landmarks smoothing calculator cannot be put inside the for + // loop calculator, because the smoothing calculator utilizes the timestamp + // to smooth landmarks across frames, but the for loop calculator makes fake + // timestamps for the streams. + if (face_landmark_subgraph .GetOptions() .smooth_landmarks()) { + // Get the single face landmarks. + auto& get_vector_item = + graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); + get_vector_item.GetOptions() .set_item_index(0); + landmark_lists >> get_vector_item.In(kVectorTag); + Stream single_landmarks = + get_vector_item.Out(kItemTag).Cast(); + + auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_properties.In(kImageTag); + auto image_size = image_properties.Out(kSizeTag); + + // Apply the smoothing filter on the face landmarks. + auto& landmarks_smoothing = graph.AddNode("LandmarksSmoothingCalculator"); + ConfigureLandmarksSmoothingCalculator( + landmarks_smoothing .GetOptions()); + single_landmarks >> landmarks_smoothing.In(kNormLandmarksTag); + image_size >> landmarks_smoothing.In(kImageSizeTag); + single_landmarks = landmarks_smoothing.Out(kNormFilteredLandmarksTag) .Cast(); + + // Wrap the single face landmarks into a vector of landmarks.
+ auto& concatenate_vector = + graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator"); + single_landmarks >> concatenate_vector.In(""); + landmark_lists = concatenate_vector.Out("") + .Cast>(); + } + std::optional>> face_blendshapes_vector; if (face_landmark_subgraph .GetOptions() .has_face_blendshapes_graph_options()) { - auto blendshapes = face_landmark_subgraph.Out(kBlendshapesTag); + auto& begin_loop_multi_face_landmarks = + graph.AddNode("BeginLoopNormalizedLandmarkListVectorCalculator"); + landmark_lists >> begin_loop_multi_face_landmarks.In(kIterableTag); + image_in >> begin_loop_multi_face_landmarks.In(kCloneTag); + auto image = begin_loop_multi_face_landmarks.Out(kCloneTag); + auto batch_end = begin_loop_multi_face_landmarks.Out(kBatchEndTag); + auto landmarks = begin_loop_multi_face_landmarks.Out(kItemTag); + + auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); + image >> image_properties.In(kImageTag); + auto image_size = image_properties.Out(kSizeTag); + + auto& face_blendshapes_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph"); + face_blendshapes_graph.GetOptions() + .Swap(face_landmark_subgraph + .GetOptions() + .mutable_face_blendshapes_graph_options()); + landmarks >> face_blendshapes_graph.In(kLandmarksTag); + image_size >> face_blendshapes_graph.In(kImageSizeTag); + auto face_blendshapes = face_blendshapes_graph.Out(kBlendshapesTag) + .Cast(); + auto& end_loop_blendshapes = graph.AddNode("EndLoopClassificationListCalculator"); batch_end >> end_loop_blendshapes.In(kBatchEndTag); - blendshapes >> end_loop_blendshapes.In(kItemTag); + face_blendshapes >> end_loop_blendshapes.In(kItemTag); face_blendshapes_vector = std::make_optional(end_loop_blendshapes.Out(kIterableTag) .Cast>()); diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc index a415125d9..affa7bedd 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -99,8 +99,7 @@ constexpr float kBlendshapesDiffMargin = 0.1; // Helper function to create a Single Face Landmark TaskRunner. 
absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( - absl::string_view landmarks_model_name, - std::optional blendshapes_model_name) { + absl::string_view landmarks_model_name) { Graph graph; auto& face_landmark_detection = graph.AddNode( @@ -112,14 +111,6 @@ absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( JoinPath("./", kTestDataDirectory, landmarks_model_name)); options->set_min_detection_confidence(0.5); - if (blendshapes_model_name.has_value()) { - options->mutable_face_blendshapes_graph_options() - ->mutable_base_options() - ->mutable_model_asset() - ->set_file_name( - JoinPath("./", kTestDataDirectory, *blendshapes_model_name)); - } - face_landmark_detection.GetOptions() .Swap(options.get()); @@ -137,11 +128,6 @@ absl::StatusOr> CreateSingleFaceLandmarksTaskRunner( face_landmark_detection.Out(kFaceRectNextFrameTag) .SetName(kFaceRectNextFrameName) >> graph[Output(kFaceRectNextFrameTag)]; - if (blendshapes_model_name.has_value()) { - face_landmark_detection.Out(kBlendshapesTag).SetName(kBlendshapesName) >> - graph[Output(kBlendshapesTag)]; - } - return TaskRunner::Create( graph.GetConfig(), absl::make_unique()); } @@ -227,8 +213,6 @@ struct SingeFaceTestParams { std::string test_name; // The filename of landmarks model name. std::string landmarks_model_name; - // The filename of blendshape model name. - std::optional blendshape_model_name; // The filename of the test image. std::string test_image_name; // RoI on image to detect faces. @@ -237,13 +221,8 @@ struct SingeFaceTestParams { bool expected_presence; // The expected output landmarks positions. NormalizedLandmarkList expected_landmarks; - // The expected output blendshape classification; - std::optional expected_blendshapes; // The max value difference between expected_positions and detected positions. float landmarks_diff_threshold; - // The max value difference between expected blendshapes and actual - // blendshapes. 
- float blendshapes_diff_threshold; }; struct MultiFaceTestParams { @@ -279,8 +258,7 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) { GetParam().test_image_name))); MP_ASSERT_OK_AND_ASSIGN( auto task_runner, - CreateSingleFaceLandmarksTaskRunner(GetParam().landmarks_model_name, - GetParam().blendshape_model_name)); + CreateSingleFaceLandmarksTaskRunner(GetParam().landmarks_model_name)); auto output_packets = task_runner->Process( {{kImageName, MakePacket(std::move(image))}, @@ -301,15 +279,6 @@ TEST_P(SingleFaceLandmarksDetectionTest, Succeeds) { Approximately(Partially(EqualsProto(expected_landmarks)), /*margin=*/kAbsMargin, /*fraction=*/GetParam().landmarks_diff_threshold)); - if (GetParam().expected_blendshapes) { - const ClassificationList& actual_blendshapes = - (*output_packets)[kBlendshapesName].Get(); - const ClassificationList& expected_blendshapes = - *GetParam().expected_blendshapes; - EXPECT_THAT(actual_blendshapes, - Approximately(EqualsProto(expected_blendshapes), - GetParam().blendshapes_diff_threshold)); - } } } @@ -360,34 +329,15 @@ TEST_P(MultiFaceLandmarksDetectionTest, Succeeds) { INSTANTIATE_TEST_SUITE_P( FaceLandmarksDetectionTest, SingleFaceLandmarksDetectionTest, Values(SingeFaceTestParams{ - /* test_name= */ "PortraitV2", - /* landmarks_model_name= */ - kFaceLandmarksV2Model, - /* blendshape_model_name= */ std::nullopt, - /* test_image_name= */ kPortraitImageName, - /* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), - /* expected_presence= */ true, - /* expected_landmarks= */ - GetExpectedLandmarkList(kPortraitExpectedFaceLandmarksName), - /* expected_blendshapes= */ std::nullopt, - /* landmarks_diff_threshold= */ kFractionDiff, - /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}, - SingeFaceTestParams{ - /* test_name= */ "PortraitV2WithBlendshapes", - /* landmarks_model_name= */ - kFaceLandmarksV2Model, - /* blendshape_model_name= */ kFaceBlendshapesModel, - /* test_image_name= */ kPortraitImageName, - /* norm_rect= */ - MakeNormRect(0.48906386, 0.22731927, 0.42905223, 0.34357703, - 0.008304443), - /* expected_presence= */ true, - /* expected_landmarks= */ - GetExpectedLandmarkList(kPortraitExpectedFaceLandmarksName), - /* expected_blendshapes= */ - GetBlendshapes(kPortraitExpectedBlendshapesName), - /* landmarks_diff_threshold= */ kFractionDiff, - /* blendshapes_diff_threshold= */ kBlendshapesDiffMargin}), + /* test_name= */ "PortraitV2", + /* landmarks_model_name= */ + kFaceLandmarksV2Model, + /* test_image_name= */ kPortraitImageName, + /* norm_rect= */ MakeNormRect(0.4987, 0.2211, 0.2877, 0.2303, 0), + /* expected_presence= */ true, + /* expected_landmarks= */ + GetExpectedLandmarkList(kPortraitExpectedFaceLandmarksName), + /* landmarks_diff_threshold= */ kFractionDiff}), [](const TestParamInfo& info) { return info.param.test_name; diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD index d3e236619..aa839d912 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,6 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") package(default_visibility = [ "//mediapipe/tasks:internal", + "//mediapipe/tasks:users", ]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.proto b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.proto index 36e712ad8..c535f0e2e 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto index dc8654608..219437166 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto index c2fa49607..32e3636df 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -37,6 +37,13 @@ message FaceLandmarksDetectorGraphOptions { // successfully detecting a face in the image. optional float min_detection_confidence = 2 [default = 0.5]; + // Whether to smooth the detected landmarks over timestamps. Note that + // landmarks smoothing is only applicable to a single face. If landmarks for + // multiple faces are given and smooth_landmarks is true, only the first + // face's landmarks are smoothed, and the remaining landmarks are discarded + // from the returned landmarks list. + optional bool smooth_landmarks = 4; + // Optional options for FaceBlendshapeGraph. If this options is set, the // FaceLandmarksDetectorGraph would output the face blendshapes. optional FaceBlendshapesGraphOptions face_blendshapes_graph_options = 3; diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options.proto b/mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options.proto index 22414b361..d9772ea1f 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_landmarker/tensors_to_face_landmarks_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/tensors_to_face_landmarks_graph.cc index 4073c9e18..a765f4424 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/tensors_to_face_landmarks_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/tensors_to_face_landmarks_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/face_stylizer/BUILD b/mediapipe/tasks/cc/vision/face_stylizer/BUILD index 27b2f482d..72182f0c5 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/BUILD +++ b/mediapipe/tasks/cc/vision/face_stylizer/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,6 +50,7 @@ cc_library( "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", ], @@ -60,6 +61,7 @@ cc_library( name = "face_stylizer", srcs = ["face_stylizer.cc"], hdrs = ["face_stylizer.h"], + visibility = ["//visibility:public"], deps = [ ":face_stylizer_graph", # buildcleaner:keep "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/BUILD b/mediapipe/tasks/cc/vision/face_stylizer/calculators/BUILD index 4e070b43e..46f8944ac 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,6 +65,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", "//mediapipe/gpu:gpu_origin_cc_proto", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ] + select({ diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/strip_rotation_calculator.cc b/mediapipe/tasks/cc/vision/face_stylizer/calculators/strip_rotation_calculator.cc index c290f2725..11ae112c6 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/strip_rotation_calculator.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/strip_rotation_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
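
Aside (not part of the patch): the alpha values quoted in the ConfigureLandmarksSmoothingCalculator comments in face_landmarks_detector_graph.cc above follow from the standard one-euro filter relation between cutoff frequency and EMA coefficient. A minimal standalone sketch that reproduces those numbers; the EmaAlpha helper name and the 30 fps sample period are illustrative assumptions, not code or constants from this change:

#include <cmath>
#include <cstdio>

// Sketch only: EMA coefficient implied by a one-euro filter cutoff fc at a
// sample period Te, using alpha = 1 / (1 + tau / Te) with tau = 1 / (2*pi*fc).
static double EmaAlpha(double cutoff_hz, double period_s) {
  constexpr double kPi = 3.14159265358979323846;
  const double tau = 1.0 / (2.0 * kPi * cutoff_hz);
  return 1.0 / (1.0 + tau / period_s);
}

int main() {
  const double kPeriod = 1.0 / 30.0;  // Assumed 30 fps input stream.
  // Static landmark: cutoff ~= min_cutoff = 0.05 Hz -> alpha ~= 0.01.
  std::printf("static landmark alpha:  %.3f\n", EmaAlpha(0.05, kPeriod));
  // Fast-moving landmark: cutoff = min_cutoff + beta * |velocity|; with
  // beta = 80 and |velocity| ~= 1, cutoff ~= 80 Hz -> alpha ~= 0.94.
  std::printf("moving landmark alpha:  %.3f\n", EmaAlpha(0.05 + 80.0, kPeriod));
  // Velocity EMA: derivate_cutoff = 1.0 Hz -> alpha ~= 0.17.
  std::printf("velocity filter alpha:  %.3f\n", EmaAlpha(1.0, kPeriod));
  return 0;
}

At 30 fps this yields ~0.010, ~0.944, and ~0.173, matching the three alphas cited in the smoothing-calculator comments.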
diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc index d9825b15f..651b7efc3 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc @@ -16,6 +16,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" @@ -111,6 +112,7 @@ class TensorsToImageCalculator : public Node { private: TensorsToImageCalculatorOptions options_; absl::Status CpuProcess(CalculatorContext* cc); + int tensor_position_; #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_METAL_ENABLED @@ -161,11 +163,12 @@ absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) { #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU } else { - CHECK(options_.has_input_tensor_float_range() ^ - options_.has_input_tensor_uint_range()) + ABSL_CHECK(options_.has_input_tensor_float_range() ^ + options_.has_input_tensor_uint_range()) << "Must specify either `input_tensor_float_range` or " "`input_tensor_uint_range` in the calculator options"; } + tensor_position_ = options_.tensor_position(); return absl::OkStatus(); } @@ -202,17 +205,23 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); - const auto& input_tensor = input_tensors[0]; + const auto& input_tensor = input_tensors[tensor_position_]; const int tensor_in_height = input_tensor.shape().dims[1]; const int tensor_in_width = input_tensor.shape().dims[2]; const int tensor_in_channels = input_tensor.shape().dims[3]; - RET_CHECK_EQ(tensor_in_channels, 3); + RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1); - auto output_frame = std::make_shared( - mediapipe::ImageFormat::SRGB, tensor_in_width, tensor_in_height); + auto format = mediapipe::ImageFormat::SRGB; + if (tensor_in_channels == 1) { + format = mediapipe::ImageFormat::GRAY8; + } + + auto output_frame = + std::make_shared(format, tensor_in_width, tensor_in_height); cv::Mat output_matview = mediapipe::formats::MatView(output_frame.get()); constexpr float kOutputImageRangeMin = 0.0f; @@ -227,8 +236,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { GetValueRangeTransformation( input_range.min(), input_range.max(), kOutputImageRangeMin, kOutputImageRangeMax)); - tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale, - transform.offset); + tensor_matview.convertTo(output_matview, + CV_MAKETYPE(CV_8U, tensor_in_channels), + transform.scale, transform.offset); } else if (input_tensor.element_type() == Tensor::ElementType::kUInt8) { cv::Mat tensor_matview( cv::Size(tensor_in_width, tensor_in_height), @@ -239,8 +249,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { GetValueRangeTransformation( input_range.min(), input_range.max(), kOutputImageRangeMin, kOutputImageRangeMax)); - tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale, - transform.offset); + 
tensor_matview.convertTo(output_matview, + CV_MAKETYPE(CV_8U, tensor_in_channels), + transform.scale, transform.offset); } else { return absl::InvalidArgumentError( absl::Substitute("Type of tensor must be kFloat32 or kUInt8, got: $0", @@ -264,10 +275,14 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); - const int tensor_width = input_tensors[0].shape().dims[2]; - const int tensor_height = input_tensors[0].shape().dims[1]; + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); + const int tensor_width = input_tensors[tensor_position_].shape().dims[2]; + const int tensor_height = input_tensors[tensor_position_].shape().dims[1]; + const int tensor_channels = input_tensors[tensor_position_].shape().dims[3]; + // TODO: Add 1 channel support. + RET_CHECK(tensor_channels == 3); // TODO: Fix unused variable [[maybe_unused]] id device = gpu_helper_.mtlDevice; @@ -277,8 +292,8 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { [command_buffer computeCommandEncoder]; [compute_encoder setComputePipelineState:to_buffer_program_]; - auto input_view = - mediapipe::MtlBufferView::GetReadView(input_tensors[0], command_buffer); + auto input_view = mediapipe::MtlBufferView::GetReadView( + input_tensors[tensor_position_], command_buffer); [compute_encoder setBuffer:input_view.buffer() offset:0 atIndex:0]; mediapipe::GpuBuffer output = @@ -355,7 +370,7 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size_), R"( precision highp float; layout(rgba8, binding = 0) writeonly uniform highp image2D output_texture; - uniform ivec2 out_size; + uniform ivec3 out_size; )"); const std::string shader_body = R"( @@ -366,10 +381,11 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { void main() { int out_width = out_size.x; int out_height = out_size.y; + int out_channels = out_size.z; ivec2 gid = ivec2(gl_GlobalInvocationID.xy); if (gid.x >= out_width || gid.y >= out_height) { return; } - int linear_index = 3 * (gid.y * out_width + gid.x); + int linear_index = out_channels * (gid.y * out_width + gid.x); #ifdef FLIP_Y_COORD int y_coord = out_height - gid.y - 1; @@ -377,8 +393,14 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { int y_coord = gid.y; #endif // defined(FLIP_Y_COORD) + vec4 out_value; ivec2 out_coordinate = ivec2(gid.x, y_coord); - vec4 out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0); + if (out_channels == 3) { + out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0); + } else { + float in_value = input_data.elements[linear_index]; + out_value = vec4(in_value, in_value, in_value, 1.0); + } imageStore(output_texture, out_coordinate, out_value); })"; @@ -438,10 +460,15 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); - const int tensor_width = 
input_tensors[0].shape().dims[2]; - const int tensor_height = input_tensors[0].shape().dims[1]; + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); + + const auto& input_tensor = input_tensors[tensor_position_]; + const int tensor_width = input_tensor.shape().dims[2]; + const int tensor_height = input_tensor.shape().dims[1]; + const int tensor_in_channels = input_tensor.shape().dims[3]; + RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 @@ -454,7 +481,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { glBindImageTexture(output_index, out_texture->id(), 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_RGBA8); - auto read_view = input_tensors[0].GetOpenGlBufferReadView(); + auto read_view = input_tensor.GetOpenGlBufferReadView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name()); const tflite::gpu::uint3 workload = {tensor_width, tensor_height, 1}; @@ -462,8 +489,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { tflite::gpu::DivideRoundUp(workload, workgroup_size_); glUseProgram(gl_compute_program_->id()); - glUniform2i(glGetUniformLocation(gl_compute_program_->id(), "out_size"), - tensor_width, tensor_height); + glUniform3i(glGetUniformLocation(gl_compute_program_->id(), "out_size"), + tensor_width, tensor_height, tensor_in_channels); MP_RETURN_IF_ERROR(gl_compute_program_->Dispatch(workgroups)); @@ -481,8 +508,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { #else - if (!input_tensors[0].ready_as_opengl_texture_2d()) { - (void)input_tensors[0].GetCpuReadView(); + if (!input_tensor.ready_as_opengl_texture_2d()) { + (void)input_tensor.GetCpuReadView(); } auto output_texture = @@ -490,7 +517,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { gl_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, - input_tensors[0].GetOpenGlTexture2dReadView().name()); + input_tensor.GetOpenGlTexture2dReadView().name()); MP_RETURN_IF_ERROR(gl_renderer_->GlRender( tensor_width, tensor_height, output_texture.width(), diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto index 6bca86265..b0ecb8b5a 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto @@ -48,4 +48,8 @@ message TensorsToImageCalculatorOptions { FloatRange input_tensor_float_range = 2; UIntRange input_tensor_uint_range = 3; } + + // Determines which output tensor to slice when there are multiple output + // tensors available (e.g. network has multiple heads) + optional int32 tensor_position = 4 [default = 0]; } diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.cc b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.cc index 7ae8d1b26..89ec8766b 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ using FaceStylizerGraphOptionsProto = // "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph". CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, - bool enable_flow_limiting) { + bool enable_flow_limiting = false) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap(options.get()); @@ -87,8 +87,6 @@ ConvertFaceStylizerOptionsToProto(FaceStylizerOptions* options) { auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); - options_proto->mutable_base_options()->set_use_stream_mode( - options->running_mode != core::RunningMode::IMAGE); return options_proto; } @@ -125,10 +123,8 @@ absl::StatusOr> FaceStylizer::Create( } return core::VisionTaskApiFactory::Create( - CreateGraphConfig( - std::move(options_proto), - options->running_mode == core::RunningMode::LIVE_STREAM), - std::move(options->base_options.op_resolver), options->running_mode, + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver), core::RunningMode::IMAGE, std::move(packets_callback)); } @@ -154,52 +150,6 @@ absl::StatusOr> FaceStylizer::Stylize( output_packets[kStylizedImageName].Get()); } -absl::StatusOr> FaceStylizer::StylizeForVideo( - mediapipe::Image image, int64_t timestamp_ms, - std::optional image_processing_options) { - if (image.UsesGpu()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat("GPU input images are currently not supported."), - MediaPipeTasksStatus::kRunnerUnexpectedInputError); - } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image)); - ASSIGN_OR_RETURN( - auto output_packets, - ProcessVideoData( - {{kImageInStreamName, - MakePacket(std::move(image)) - .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, - {kNormRectName, - MakePacket(std::move(norm_rect)) - .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kStylizedImageName].IsEmpty() - ? std::nullopt - : std::optional( - output_packets[kStylizedImageName].Get()); -} - -absl::Status FaceStylizer::StylizeAsync( - Image image, int64_t timestamp_ms, - std::optional image_processing_options) { - if (image.UsesGpu()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat("GPU input images are currently not supported."), - MediaPipeTasksStatus::kRunnerUnexpectedInputError); - } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image)); - return SendLiveStreamData( - {{kImageInStreamName, - MakePacket(std::move(image)) - .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, - {kNormRectName, - MakePacket(std::move(norm_rect)) - .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); -} - } // namespace face_stylizer } // namespace vision } // namespace tasks diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h index 36bb11bd7..7342b291d 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,15 +41,6 @@ struct FaceStylizerOptions { // file with metadata, accelerator options, op resolver, etc. tasks::core::BaseOptions base_options; - // The running mode of the task. Default to the image mode. - // Face stylizer has three running modes: - // 1) The image mode for stylizing faces on single image inputs. - // 2) The video mode for stylizing faces on the decoded frames of a video. - // 3) The live stream mode for stylizing faces on the live stream of input - // data, such as from camera. In this mode, the "result_callback" below must - // be specified to receive the stylization results asynchronously. - core::RunningMode running_mode = core::RunningMode::IMAGE; - // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. @@ -90,62 +81,6 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { std::optional image_processing_options = std::nullopt); - // Performs face stylization on the provided video frame. - // - // The optional 'image_processing_options' parameter can be used to specify: - // - the rotation to apply to the image before performing stylization, by - // setting its 'rotation_degrees' field. - // and/or - // - the region-of-interest on which to perform stylization, by setting its - // 'region_of_interest' field. If not specified, the full image is used. - // If both are specified, the crop around the region-of-interest is extracted - // first, then the specified rotation is applied to the crop. - // - // Only use this method when the FaceStylizer is created with the video - // running mode. - // - // The image can be of any size with format RGB or RGBA. It's required to - // provide the video frame's timestamp (in milliseconds). The input timestamps - // must be monotonically increasing. - // When no face is detected on the input image, the method returns a - // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. The stylized output image size is the same as the model output size. - absl::StatusOr> StylizeForVideo( - mediapipe::Image image, int64_t timestamp_ms, - std::optional image_processing_options = - std::nullopt); - - // Sends live image data to perform face stylization, and the results will - // be available via the "result_callback" provided in the - // FaceStylizerOptions. - // - // The optional 'image_processing_options' parameter can be used to specify: - // - the rotation to apply to the image before performing stylization, by - // setting its 'rotation_degrees' field. - // and/or - // - the region-of-interest on which to perform stylization, by setting its - // 'region_of_interest' field. If not specified, the full image is used. - // If both are specified, the crop around the region-of-interest is extracted - // first, then the specified rotation is applied to the crop. - // - // Only use this method when the FaceStylizer is created with the live stream - // running mode. - // - // The image can be of any size with format RGB or RGBA. It's required to - // provide a timestamp (in milliseconds) to indicate when the input image is - // sent to the face stylizer. The input timestamps must be monotonically - // increasing. 
- // - // The "result_callback" provides: - // - When no face is detected on the input image, the method returns a - // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. The stylized output image size is the same as the model output - // size. - // - The input timestamp in milliseconds. - absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms, - std::optional - image_processing_options = std::nullopt); - // Shuts down the FaceStylizer when all works are done. absl::Status Close() { return runner_->Close(); } }; diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc index 27b8dacc1..6a50dccc4 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,6 +41,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.pb.h" +#include "mediapipe/util/graph_builder_utils.h" namespace mediapipe { namespace tasks { @@ -64,26 +65,27 @@ using ::mediapipe::tasks::vision::face_stylizer::proto:: FaceStylizerGraphOptions; constexpr char kDetectionTag[] = "DETECTION"; +constexpr char kFaceAlignmentTag[] = "FACE_ALIGNMENT"; constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite"; constexpr char kFaceLandmarksDetectorTFLiteName[] = "face_landmarks_detector.tflite"; constexpr char kFaceStylizerTFLiteName[] = "face_stylizer.tflite"; constexpr char kImageTag[] = "IMAGE"; -constexpr char kImageCpuTag[] = "IMAGE_CPU"; -constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kMatrixTag[] = "MATRIX"; constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kNormRectTag[] = "NORM_RECT"; -constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kSizeTag[] = "SIZE"; constexpr char kStylizedImageTag[] = "STYLIZED_IMAGE"; constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTransformationMatrixTag[] = "TRANSFORMATION_MATRIX"; // Struct holding the different output streams produced by the face stylizer // graph. 
struct FaceStylizerOutputStreams { - Source stylized_image; + std::optional> stylized_image; + std::optional> face_alignment_image; + std::optional>> transformation_matrix; Source original_image; }; @@ -106,8 +108,6 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, face_detector_graph_options->mutable_base_options() ->mutable_acceleration() ->CopyFrom(options->base_options().acceleration()); - face_detector_graph_options->mutable_base_options()->set_use_stream_mode( - options->base_options().use_stream_mode()); auto* face_landmarks_detector_graph_options = options->mutable_face_landmarker_graph_options() ->mutable_face_landmarks_detector_graph_options(); @@ -127,9 +127,11 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, face_landmarks_detector_graph_options->mutable_base_options() ->set_use_stream_mode(options->base_options().use_stream_mode()); - ASSIGN_OR_RETURN(const auto face_stylizer_file, - resources.GetFile(kFaceStylizerTFLiteName)); - SetExternalFile(face_stylizer_file, face_stylizer_external_file, is_copy); + if (face_stylizer_external_file) { + ASSIGN_OR_RETURN(const auto face_stylizer_file, + resources.GetFile(kFaceStylizerTFLiteName)); + SetExternalFile(face_stylizer_file, face_stylizer_external_file, is_copy); + } return absl::OkStatus(); } @@ -164,7 +166,7 @@ void ConfigureTensorsToImageCalculator( if (image_to_tensor_options.has_output_tensor_float_range()) { auto* mutable_range = tensors_to_image_options->mutable_input_tensor_float_range(); - // TODO: Make the float range flexiable. + // TODO: Make the float range flexible. mutable_range->set_min(0); mutable_range->set_max(1); } else if (image_to_tensor_options.has_output_tensor_uint_range()) { @@ -190,8 +192,19 @@ void ConfigureTensorsToImageCalculator( // @Optional: rect covering the whole image is used if not specified. // // Outputs: -// IMAGE - mediapipe::Image +// STYLIZED_IMAGE - mediapipe::Image // The face stylization output image. +// FACE_ALIGNMENT - mediapipe::Image +// The aligned face image that is fed to the face stylization model to +// perform stylization. Also useful for preparing face stylization training +// data. +// TRANSFORMATION_MATRIX - std::array +// An std::array representing a 4x4 row-major-order matrix that +// maps a point on the input image to a point on the output image, and +// can be used to reverse the mapping by inverting the matrix. +// IMAGE - mediapipe::Image +// The input image that the face landmarker runs on and has the pixel data +// stored on the target storage (CPU vs GPU). // // Example: // node { @@ -200,6 +213,7 @@ void ConfigureTensorsToImageCalculator( // input_stream: "NORM_RECT:norm_rect" // output_stream: "IMAGE:image_out" // output_stream: "STYLIZED_IMAGE:stylized_image" +// output_stream: "FACE_ALIGNMENT:face_alignment_image" // options { // [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext] // { @@ -215,18 +229,28 @@ class FaceStylizerGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN( - const auto* model_asset_bundle_resources, - CreateModelAssetBundleResources(sc)); - // Copies the file content instead of passing the pointer of file in - // memory if the subgraph model resource service is not available. 
+ bool output_stylized = HasOutput(sc->OriginalNode(), kStylizedImageTag); + bool output_alignment = HasOutput(sc->OriginalNode(), kFaceAlignmentTag); auto face_stylizer_external_file = absl::make_unique(); - MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( - *model_asset_bundle_resources, - sc->MutableOptions(), - face_stylizer_external_file.get(), - !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) - .IsAvailable())); + if (sc->Options().has_base_options()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // Copies the file content instead of passing the pointer of file in + // memory if the subgraph model resource service is not available. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + output_stylized ? face_stylizer_external_file.get() : nullptr, + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } else if (output_stylized) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Face stylizer must specify its base options when the " + "\"STYLIZED_IMAGE\" output stream is connected.", + MediaPipeTasksStatus::kInvalidArgumentError); + } Graph graph; ASSIGN_OR_RETURN( auto face_landmark_lists, @@ -235,15 +259,29 @@ class FaceStylizerGraph : public core::ModelTaskGraph { ->mutable_face_landmarker_graph_options(), graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - ASSIGN_OR_RETURN( - const auto* model_resources, - CreateModelResources(sc, std::move(face_stylizer_external_file))); + const ModelResources* face_stylizer_model_resources = nullptr; + if (output_stylized) { + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc, std::move(face_stylizer_external_file))); + face_stylizer_model_resources = model_resources; + } ASSIGN_OR_RETURN( auto output_streams, BuildFaceStylizerGraph(sc->Options(), - *model_resources, graph[Input(kImageTag)], + face_stylizer_model_resources, output_alignment, + graph[Input(kImageTag)], face_landmark_lists, graph)); - output_streams.stylized_image >> graph[Output(kStylizedImageTag)]; + if (output_stylized) { + output_streams.stylized_image.value() >> + graph[Output(kStylizedImageTag)]; + } + if (output_alignment) { + output_streams.face_alignment_image.value() >> + graph[Output(kFaceAlignmentTag)]; + } + output_streams.transformation_matrix.value() >> + graph[Output>(kTransformationMatrixTag)]; output_streams.original_image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -277,9 +315,11 @@ class FaceStylizerGraph : public core::ModelTaskGraph { absl::StatusOr BuildFaceStylizerGraph( const FaceStylizerGraphOptions& task_options, - const ModelResources& model_resources, Source image_in, + const ModelResources* model_resources, bool output_alignment, + Source image_in, Source> face_landmark_lists, Graph& graph) { + bool output_stylized = model_resources != nullptr; auto& split_face_landmark_list = graph.AddNode("SplitNormalizedLandmarkListVectorCalculator"); ConfigureSplitNormalizedLandmarkListVectorCalculator( @@ -303,15 +343,59 @@ class FaceStylizerGraph : public core::ModelTaskGraph { face_detection >> face_to_rect.In(kDetectionTag); image_size >> face_to_rect.In(kImageSizeTag); auto face_rect = face_to_rect.Out(kNormRectTag); - // Adds preprocessing calculators and connects them to the graph input image - // stream. + + std::optional> face_alignment; + // Output aligned face only. 
+ // In this case, the face stylization model inference is not required. + // However, to keep consistent with the inference preprocessing steps, the + // ImageToTensorCalculator is still used to perform image rotation, + // cropping, and resizing. + if (!output_stylized) { + auto& pass_through = graph.AddNode("PassThroughCalculator"); + image_in >> pass_through.In(""); + + auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); + auto& image_to_tensor_options = + image_to_tensor.GetOptions(); + image_to_tensor_options.mutable_output_tensor_float_range()->set_min(0); + image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1); + image_to_tensor_options.set_output_tensor_width( + task_options.face_alignment_size()); + image_to_tensor_options.set_output_tensor_height( + task_options.face_alignment_size()); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + image_in >> image_to_tensor.In(kImageTag); + face_rect >> image_to_tensor.In(kNormRectTag); + auto face_alignment_image = image_to_tensor.Out(kTensorsTag); + + auto& tensors_to_image = + graph.AddNode("mediapipe.tasks.TensorsToImageCalculator"); + auto& tensors_to_image_options = + tensors_to_image.GetOptions(); + tensors_to_image_options.mutable_input_tensor_float_range()->set_min(0); + tensors_to_image_options.mutable_input_tensor_float_range()->set_max(1); + face_alignment_image >> tensors_to_image.In(kTensorsTag); + face_alignment = tensors_to_image.Out(kImageTag).Cast(); + + return {{/*stylized_image=*/std::nullopt, + /*alignment_image=*/face_alignment, + /*transformation_matrix=*/ + image_to_tensor.Out(kMatrixTag).Cast>(), + /*original_image=*/pass_through.Out("").Cast()}}; + } + + std::optional> stylized; + // Adds preprocessing calculators and connects them to the graph input + // image stream. auto& preprocessing = graph.AddNode( "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); bool use_gpu = components::processors::DetermineImagePreprocessingGpuBackend( task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( - model_resources, use_gpu, + *model_resources, use_gpu, &preprocessing.GetOptions())); auto& image_to_tensor_options = @@ -329,7 +413,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph { // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. 
     auto& inference = AddInference(
-        model_resources, task_options.base_options().acceleration(), graph);
+        *model_resources, task_options.base_options().acceleration(), graph);
     preprocessed_tensors >> inference.In(kTensorsTag);
     auto model_output_tensors =
         inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
@@ -346,8 +430,22 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     image_converter.GetOptions<mediapipe::tasks::TensorsToImageCalculatorOptions>()
         .set_output_on_gpu(false);
     tensor_image >> image_converter.In("");
+    stylized = image_converter.Out("").Cast<Image>();
-    return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
+    if (output_alignment) {
+      auto& tensors_to_image =
+          graph.AddNode("mediapipe.tasks.TensorsToImageCalculator");
+      ConfigureTensorsToImageCalculator(
+          image_to_tensor_options,
+          &tensors_to_image.GetOptions<mediapipe::tasks::TensorsToImageCalculatorOptions>());
+      preprocessed_tensors >> tensors_to_image.In(kTensorsTag);
+      face_alignment = tensors_to_image.Out(kImageTag).Cast<Image>();
+    }
+
+    return {{/*stylized_image=*/stylized,
+             /*alignment_image=*/face_alignment,
+             /*transformation_matrix=*/
+             preprocessing.Out(kMatrixTag).Cast<std::array<float, 16>>(),
              /*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
   }
 };
diff --git a/mediapipe/tasks/cc/vision/face_stylizer/proto/BUILD b/mediapipe/tasks/cc/vision/face_stylizer/proto/BUILD
index 1800591d7..88f7b314f 100644
--- a/mediapipe/tasks/cc/vision/face_stylizer/proto/BUILD
+++ b/mediapipe/tasks/cc/vision/face_stylizer/proto/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2023 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
 
 package(default_visibility = [
     "//mediapipe/tasks:internal",
+    "//mediapipe/tasks:users",
 ])
 
 licenses(["notice"])
diff --git a/mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.proto b/mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.proto
index 6357b0655..9528fab09 100644
--- a/mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.proto
+++ b/mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2023 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -36,4 +36,7 @@ message FaceStylizerGraphOptions {
   // Options for face landmarker graph.
   optional vision.face_landmarker.proto.FaceLandmarkerGraphOptions
       face_landmarker_graph_options = 2;
+
+  // The width and height of the output face alignment images.
+  optional int32 face_alignment_size = 3 [default = 256];
 }
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
index 7ffae6ff2..11e484e9a 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
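// [Editor's sketch, not part of the patch] The face stylizer hunks above make
// STYLIZED_IMAGE optional, add a FACE_ALIGNMENT output that works without a
// stylizer model, and introduce the face_alignment_size option. A minimal,
// hypothetical api2-builder fragment requesting only the aligned face; the
// tag strings mirror the k...Tag constants referenced above and are assumed,
// as is the options accessor:
//
//   Graph graph;
//   auto& stylizer = graph.AddNode(
//       "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph");
//   stylizer.GetOptions<FaceStylizerGraphOptions>()
//       .set_face_alignment_size(256);  // new field, defaults to 256
//   graph.In("IMAGE").SetName("image_in") >> stylizer.In("IMAGE");
//   // base_options may be omitted here: STYLIZED_IMAGE is not connected, so
//   // the kInvalidArgument check above does not fire and no model is loaded.
//   stylizer.Out("FACE_ALIGNMENT").SetName("face_alignment") >>
//       graph.Out("FACE_ALIGNMENT");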
@@ -124,6 +124,7 @@ cc_library( "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], @@ -161,6 +162,7 @@ cc_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 8c2c2e593..b5c7f2bc7 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc index cb95091d4..d06c610b8 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto index 730e7dd78..1afe3d8be 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc index 509fac5f0..0c63c9e69 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc
index b6c973a1b..c806d5895 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -34,15 +34,15 @@ namespace api2 {
 
 namespace {
 
-using ::mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore;
+using ::mediapipe::tasks::vision::gesture_recognizer::GetRightHandScore;
 
 constexpr char kHandednessTag[] = "HANDEDNESS";
 constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX";
 
 absl::StatusOr<std::unique_ptr<Matrix>> HandednessToMatrix(
     const mediapipe::ClassificationList& classification_list) {
-  // Feature value is the probability that the hand is a left hand.
-  ASSIGN_OR_RETURN(float score, GetLeftHandScore(classification_list));
+  // Feature value is the probability that the hand is a right hand.
+  ASSIGN_OR_RETURN(float score, GetRightHandScore(classification_list));
   auto matrix = Matrix(1, 1);
   matrix(0, 0) = score;
   auto result = std::make_unique<Matrix>();
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc
index 30e5a958a..f0858e10b 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -38,10 +38,10 @@ mediapipe::ClassificationList ClassificationForHandedness(float handedness) { mediapipe::ClassificationList result; auto* h = result.add_classification(); if (handedness < 0.5f) { - h->set_label("Right"); + h->set_label("Left"); h->set_score(1.0f - handedness); } else { - h->set_label("Left"); + h->set_label("Right"); h->set_score(handedness); } return result; @@ -84,8 +84,8 @@ TEST_P(HandednessToMatrixCalculatorTest, OutputsCorrectResult) { INSTANTIATE_TEST_CASE_P( HandednessToMatrixCalculatorTests, HandednessToMatrixCalculatorTest, testing::ValuesIn( - {{/* test_name= */ "TestWithRightHand", /* handedness= */ 0.01f}, - {/* test_name= */ "TestWithLeftHand", /* handedness= */ 0.99f}}), + {{/* test_name= */ "TestWithLeftHand", /* handedness= */ 0.01f}, + {/* test_name= */ "TestWithRightHand", /* handedness= */ 0.99f}}), [](const testing::TestParamInfo< HandednessToMatrixCalculatorTest::ParamType>& info) { return info.param.test_name; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 088f97c29..624d8a822 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto index 10b034447..6f1472cdd 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index 70234a32d..064dc479f 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index a9f8fa8a1..f04d7d71a 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h index 392aa586f..b752840da 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 55db07cb8..9550112bf 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "mediapipe/framework/api2/builder.h" @@ -125,8 +126,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, hand_gesture_recognizer_graph_options->mutable_base_options() ->mutable_acceleration() ->mutable_xnnpack(); - LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets " - << "HandGestureRecognizerGraph acceleration to Xnnpack."; + ABSL_LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets " + << "HandGestureRecognizerGraph acceleration to Xnnpack."; } hand_gesture_recognizer_graph_options->mutable_base_options() ->set_use_stream_mode(options->base_options().use_stream_mode()); diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h index 914217801..9c7a8c714 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 527363d1f..fbe05b075 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. 
 #include
 #include
+#include "absl/log/absl_log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
@@ -246,7 +247,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
           options->base_options(),
           custom_gesture_classifier_graph_options->mutable_base_options());
     } else {
-      LOG(INFO) << "Custom gesture classifier is not defined.";
+      ABSL_LOG(INFO) << "Custom gesture classifier is not defined.";
     }
     return absl::OkStatus();
   }
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc
index 60ccae92c..aeb01602f 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ bool IsRightHand(const Classification& c) {
   return absl::EqualsIgnoreCase(c.label(), "Right");
 }
 
-absl::StatusOr<float> GetLeftHandScore(
+absl::StatusOr<float> GetRightHandScore(
     const ClassificationList& classification_list) {
   auto classifications = classification_list.classification();
   auto iter_max =
@@ -50,9 +50,9 @@ absl::StatusOr<float> GetLeftHandScore(
   RET_CHECK_GE(h.score(), 0.5f);
   RET_CHECK_LE(h.score(), 1.0f);
   if (IsLeftHand(h)) {
-    return h.score();
-  } else if (IsRightHand(h)) {
     return 1.0f - h.score();
+  } else if (IsRightHand(h)) {
+    return h.score();
   } else {
     // Unrecognized handedness label.
     RET_CHECK_FAIL() << "Unrecognized handedness: " << h.label();
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h
index ae4137d0f..077fbf1b9 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ bool IsLeftHand(const mediapipe::Classification& c);
 
 bool IsRightHand(const mediapipe::Classification& c);
 
-absl::StatusOr<float> GetLeftHandScore(
+absl::StatusOr<float> GetRightHandScore(
     const mediapipe::ClassificationList& classification_list);
 
 }  // namespace gesture_recognizer
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc
index 40a201ae8..ae1a5c6e7 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -26,49 +26,49 @@ namespace vision { namespace gesture_recognizer { namespace { -TEST(GetLeftHandScore, SingleLeftHandClassification) { - ClassificationList classifications; - auto& c = *classifications.add_classification(); - c.set_label("Left"); - c.set_score(0.6f); - - MP_ASSERT_OK_AND_ASSIGN(float score, GetLeftHandScore(classifications)); - EXPECT_FLOAT_EQ(score, 0.6f); -} - -TEST(GetLeftHandScore, SingleRightHandClassification) { +TEST(GetRightHandScore, SingleRightHandClassification) { ClassificationList classifications; auto& c = *classifications.add_classification(); c.set_label("Right"); + c.set_score(0.6f); + + MP_ASSERT_OK_AND_ASSIGN(float score, GetRightHandScore(classifications)); + EXPECT_FLOAT_EQ(score, 0.6f); +} + +TEST(GetRightHandScore, SingleLeftHandClassification) { + ClassificationList classifications; + auto& c = *classifications.add_classification(); + c.set_label("Left"); c.set_score(0.9f); - MP_ASSERT_OK_AND_ASSIGN(float score, GetLeftHandScore(classifications)); + MP_ASSERT_OK_AND_ASSIGN(float score, GetRightHandScore(classifications)); EXPECT_FLOAT_EQ(score, 0.1f); } -TEST(GetLeftHandScore, LeftAndRightHandClassification) { +TEST(GetRightHandScore, LeftAndRightHandClassification) { ClassificationList classifications; auto& right = *classifications.add_classification(); - right.set_label("Right"); + right.set_label("Left"); right.set_score(0.9f); auto& left = *classifications.add_classification(); - left.set_label("Left"); + left.set_label("Right"); left.set_score(0.1f); - MP_ASSERT_OK_AND_ASSIGN(float score, GetLeftHandScore(classifications)); + MP_ASSERT_OK_AND_ASSIGN(float score, GetRightHandScore(classifications)); EXPECT_FLOAT_EQ(score, 0.1f); } -TEST(GetLeftHandScore, LeftAndRightLowerCaseHandClassification) { +TEST(GetRightHandScore, LeftAndRightLowerCaseHandClassification) { ClassificationList classifications; auto& right = *classifications.add_classification(); - right.set_label("right"); + right.set_label("Left"); right.set_score(0.9f); auto& left = *classifications.add_classification(); - left.set_label("left"); + left.set_label("Right"); left.set_score(0.1f); - MP_ASSERT_OK_AND_ASSIGN(float score, GetLeftHandScore(classifications)); + MP_ASSERT_OK_AND_ASSIGN(float score, GetRightHandScore(classifications)); EXPECT_FLOAT_EQ(score, 0.1f); } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD index 0db47da7a..8e4f28060 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index edbabc018..51d3203f2 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
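// [Editor's sketch, not part of the patch] The rename above flips the
// handedness convention: GetRightHandScore returns the probability that the
// hand is a right hand, complementing "Left" classifications. A fragment
// mirroring the updated tests:
//
//   mediapipe::ClassificationList classifications;
//   auto* c = classifications.add_classification();
//   c->set_label("Left");
//   c->set_score(0.9f);
//   // "Left" with score 0.9 now yields a right-hand score of 0.1f;
//   // scores below 0.5 trip the RET_CHECK_GE shown in handedness_util.cc.
//   float right_score = GetRightHandScore(classifications).value();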
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto index df909a6db..dc6593727 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto index fef22c07c..ef66a2d3b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index ae85509da..ceff67392 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 55162d09b..7fc943d87 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 923eab1ca..1b964c5d8 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index f4e5f8c7d..eadc26ad9 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ limitations under the License.
 #include
 #include "absl/flags/flag.h"
+#include "absl/log/absl_check.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
@@ -76,8 +77,8 @@ using ::testing::proto::Partially;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
-constexpr char kTestRightHandsImage[] = "right_hands.jpg";
-constexpr char kTestRightHandsRotatedImage[] = "right_hands_rotated.jpg";
+constexpr char kTestLeftHandsImage[] = "left_hands.jpg";
+constexpr char kTestLeftHandsRotatedImage[] = "left_hands_rotated.jpg";
 constexpr char kTestModelResourcesTag[] = "test_model_resources";
 
 constexpr char kOneHandResultFile[] = "hand_detector_result_one_hand.pbtxt";
@@ -138,8 +139,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
 
 HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) {
   HandDetectorResult result;
-  CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
-                        &result, Defaults()))
+  ABSL_CHECK_OK(GetTextProto(
+      file::JoinPath("./", kTestDataDirectory, file_name), &result, Defaults()))
       << "Expected hand detector result does not exist.";
   return result;
 }
@@ -207,21 +208,21 @@ INSTANTIATE_TEST_SUITE_P(
     HandDetectionTest, HandDetectionTest,
     Values(TestParams{.test_name = "DetectOneHand",
                       .hand_detection_model_name = kPalmDetectionModel,
-                      .test_image_name = kTestRightHandsImage,
+                      .test_image_name = kTestLeftHandsImage,
                       .rotation = 0,
                       .num_hands = 1,
                       .expected_result =
                           GetExpectedHandDetectorResult(kOneHandResultFile)},
            TestParams{.test_name = "DetectTwoHands",
                       .hand_detection_model_name = kPalmDetectionModel,
-                      .test_image_name = kTestRightHandsImage,
+                      .test_image_name = kTestLeftHandsImage,
                       .rotation = 0,
                       .num_hands = 2,
                       .expected_result =
                           GetExpectedHandDetectorResult(kTwoHandsResultFile)},
            TestParams{.test_name = "DetectOneHandWithRotation",
                       .hand_detection_model_name = kPalmDetectionModel,
-                      .test_image_name = kTestRightHandsRotatedImage,
+                      .test_image_name = kTestLeftHandsRotatedImage,
                       .rotation = M_PI / 2.0f,
                       .num_hands = 1,
                       .expected_result = GetExpectedHandDetectorResult(
diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD
index 77f3b2649..c1453f420 100644
--- a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD
+++ b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto
index bede70da5..670d1e41f 100644
--- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto
+++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto index 00c179ca9..170ed7c39 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 7a83816b8..1e24256d1 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -153,6 +153,13 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "hand_landmarks_connections", + hdrs = ["hand_landmarks_connections.h"], +) + +# TODO: open source hand joints graph + cc_library( name = "hand_landmarker_result", srcs = ["hand_landmarker_result.cc"], diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD index 73d3f38eb..15806b516 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,6 +42,7 @@ cc_library( "//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:status", "//mediapipe/util:rectangle_util", + "@com_google_absl//absl/log:absl_check", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index 011bce2b9..5cbd72c3b 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ limitations under the License. 
 #include
 #include
+#include "absl/log/absl_check.h"
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/collection_item_id.h"
@@ -89,8 +90,8 @@ class HandAssociationCalculator : public CalculatorBase {
     cc->SetOffset(TimestampDiff(0));
 
     options_ = cc->Options<HandAssociationCalculatorOptions>();
-    CHECK_GT(options_.min_similarity_threshold(), 0.0);
-    CHECK_LE(options_.min_similarity_threshold(), 1.0);
+    ABSL_CHECK_GT(options_.min_similarity_threshold(), 0.0);
+    ABSL_CHECK_LE(options_.min_similarity_threshold(), 1.0);
 
     return absl::OkStatus();
   }
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto
index e7229b4a2..b37478860 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc
index c22b1a7e6..057d72e10 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
index d875de98f..ced757546 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h
index d7b435487..782565ee0 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
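// [Editor's sketch, not part of the patch] With the ABSL_CHECKs added to
// HandAssociationCalculator::Open() above, a node must be configured with
// 0 < min_similarity_threshold <= 1 before the graph starts. A hypothetical
// builder fragment (the options message comes from
// hand_association_calculator.proto in this patch):
//
//   auto& hand_association = graph.AddNode("HandAssociationCalculator");
//   hand_association.GetOptions<mediapipe::HandAssociationCalculatorOptions>()
//       .set_min_similarity_threshold(0.5f);  // values outside (0, 1] abort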
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h
index c8dbc9254..3f70e7ee7 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
index 4eec37f20..9190f4052 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h
index 7a43d20d7..726780ff2 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc
index b37141005..b051dc571 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -83,7 +83,7 @@ struct HandLandmarkerOutputs {
   Stream<std::vector<NormalizedLandmarkList>> landmark_lists;
   Stream<std::vector<LandmarkList>> world_landmark_lists;
   Stream<std::vector<NormalizedRect>> hand_rects_next_frame;
-  Stream<std::vector<ClassificationList>> handednesses;
+  Stream<std::vector<ClassificationList>> handedness;
   Stream<std::vector<NormalizedRect>> palm_rects;
   Stream<std::vector<Detection>> palm_detections;
   Stream<Image> image;
@@ -241,7 +241,7 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
         graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
     hand_landmarker_outputs.hand_rects_next_frame >>
         graph[Output<std::vector<NormalizedRect>>(kHandRectNextFrameTag)];
-    hand_landmarker_outputs.handednesses >>
+    hand_landmarker_outputs.handedness >>
         graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
     hand_landmarker_outputs.palm_rects >>
         graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc
index fc73e7787..f08e2b863 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -69,8 +69,8 @@ using ::testing::proto::Partially;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kHandLandmarkerModelBundle[] = "hand_landmarker.task";
-constexpr char kLeftHandsImage[] = "left_hands.jpg";
-constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg";
+constexpr char kRightHandsImage[] = "right_hands.jpg";
+constexpr char kRightHandsRotatedImage[] = "right_hands_rotated.jpg";
 
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kImageName[] = "image_in";
@@ -86,15 +86,15 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
 constexpr char kHandednessName[] = "handedness";
 
 // Expected hand landmarks positions, in text proto format.
-constexpr char kExpectedLeftUpHandLandmarksFilename[] =
-    "expected_left_up_hand_landmarks.prototxt";
-constexpr char kExpectedLeftDownHandLandmarksFilename[] =
-    "expected_left_down_hand_landmarks.prototxt";
+constexpr char kExpectedRightUpHandLandmarksFilename[] =
+    "expected_right_up_hand_landmarks.prototxt";
+constexpr char kExpectedRightDownHandLandmarksFilename[] =
+    "expected_right_down_hand_landmarks.prototxt";
 // Same but for the rotated image.
-constexpr char kExpectedLeftUpHandRotatedLandmarksFilename[] =
-    "expected_left_up_hand_rotated_landmarks.prototxt";
-constexpr char kExpectedLeftDownHandRotatedLandmarksFilename[] =
-    "expected_left_down_hand_rotated_landmarks.prototxt";
+constexpr char kExpectedRightUpHandRotatedLandmarksFilename[] =
+    "expected_right_up_hand_rotated_landmarks.prototxt";
+constexpr char kExpectedRightDownHandRotatedLandmarksFilename[] =
+    "expected_right_down_hand_rotated_landmarks.prototxt";
 
 constexpr float kFullModelFractionDiff = 0.03;  // percentage
 constexpr float kAbsMargin = 0.03;
@@ -141,8 +141,8 @@ class HandLandmarkerTest : public tflite::testing::Test {};
 
 TEST_F(HandLandmarkerTest, Succeeds) {
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage)));
+      Image image, DecodeImageFromFile(
+                       JoinPath("./", kTestDataDirectory, kRightHandsImage)));
   NormalizedRect input_norm_rect;
   input_norm_rect.set_x_center(0.5);
   input_norm_rect.set_y_center(0.5);
@@ -157,8 +157,8 @@ TEST_F(HandLandmarkerTest, Succeeds) {
           .Get<std::vector<NormalizedLandmarkList>>();
   ASSERT_EQ(landmarks.size(), kMaxNumHands);
   std::vector<NormalizedLandmarkList> expected_landmarks = {
-      GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename),
-      GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename)};
+      GetExpectedLandmarkList(kExpectedRightUpHandLandmarksFilename),
+      GetExpectedLandmarkList(kExpectedRightDownHandLandmarksFilename)};
 
   EXPECT_THAT(landmarks[0],
               Approximately(Partially(EqualsProto(expected_landmarks[0])),
@@ -173,7 +173,7 @@ TEST_F(HandLandmarkerTest, SucceedsWithRotation) {
   MP_ASSERT_OK_AND_ASSIGN(
       Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
-                                                kLeftHandsRotatedImage)));
+                                                kRightHandsRotatedImage)));
   NormalizedRect input_norm_rect;
   input_norm_rect.set_x_center(0.5);
   input_norm_rect.set_y_center(0.5);
@@ -189,8 +189,8 @@ TEST_F(HandLandmarkerTest, SucceedsWithRotation) {
           .Get<std::vector<NormalizedLandmarkList>>();
   ASSERT_EQ(landmarks.size(), kMaxNumHands);
   std::vector<NormalizedLandmarkList> expected_landmarks = {
-      GetExpectedLandmarkList(kExpectedLeftUpHandRotatedLandmarksFilename),
-      GetExpectedLandmarkList(kExpectedLeftDownHandRotatedLandmarksFilename)};
+      GetExpectedLandmarkList(kExpectedRightUpHandRotatedLandmarksFilename),
+      GetExpectedLandmarkList(kExpectedRightDownHandRotatedLandmarksFilename)};
 
   EXPECT_THAT(landmarks[0],
Approximately(Partially(EqualsProto(expected_landmarks[0])), diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc index 9d2ae2be8..9ec7f838d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h index 1bca8e66a..caa3c1790 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc index 109749b01..d2a6b0b32 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index bb7d1a905..899e6560d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h new file mode 100644 index 000000000..510820294 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKS_CONNECTIONS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKS_CONNECTIONS_H_
+
+#include <array>
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace hand_landmarker {
+
+static constexpr std::array<std::array<int, 2>, 6> kHandPalmConnections{
+    {{0, 1}, {0, 5}, {9, 13}, {13, 17}, {5, 9}, {0, 17}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandThumbConnections{
+    {{1, 2}, {2, 3}, {3, 4}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandIndexFingerConnections{
+    {{5, 6}, {6, 7}, {7, 8}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandMiddleFingerConnections{
+    {{9, 10}, {10, 11}, {11, 12}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandRingFingerConnections{
+    {{13, 14}, {14, 15}, {15, 16}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandPinkyFingerConnections{
+    {{17, 18}, {18, 19}, {19, 20}}};
+
+static constexpr std::array<std::array<int, 2>, 21> kHandConnections{
+    {{0, 1},   {0, 5},   {9, 13},  {13, 17}, {5, 9},   {0, 17},  {1, 2},
+     {2, 3},   {3, 4},   {5, 6},   {6, 7},   {7, 8},   {9, 10},  {10, 11},
+     {11, 12}, {13, 14}, {14, 15}, {15, 16}, {17, 18}, {18, 19}, {19, 20}}};
+
+}  // namespace hand_landmarker
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKS_CONNECTIONS_H_
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc
index f7fa83a11..51cbc9e89 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
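// [Editor's sketch, not part of the patch] The new header above only
// publishes landmark index pairs; rendering is left to callers. One way to
// walk kHandConnections over a NormalizedLandmarkList (DrawSegment is a
// placeholder for whatever renderer is in use):
//
//   for (const std::array<int, 2>& edge :
//        mediapipe::tasks::vision::hand_landmarker::kHandConnections) {
//     const auto& start = landmarks.landmark(edge[0]);
//     const auto& end = landmarks.landmark(edge[1]);
//     DrawSegment(start.x(), start.y(), end.x(), end.y());
//   }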
@@ -59,7 +59,6 @@ using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
 using ::mediapipe::tasks::components::utils::AllowIf;
-using ::mediapipe::tasks::core::ModelResources;
 using ::mediapipe::tasks::vision::hand_landmarker::proto::
     HandLandmarksDetectorGraphOptions;
 using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
@@ -94,7 +93,7 @@ struct HandLandmarkerOutputs {
   Source<std::vector<NormalizedRect>> hand_rects_next_frame;
   Source<std::vector<bool>> presences;
   Source<std::vector<float>> presence_scores;
-  Source<std::vector<ClassificationList>> handednesses;
+  Source<std::vector<ClassificationList>> handedness;
 };
 
 absl::Status SanityCheckOptions(
@@ -143,8 +142,8 @@ void ConfigureTensorsToHandednessCalculator(
   LabelMapItem right_hand = LabelMapItem();
   right_hand.set_name("Right");
   right_hand.set_display_name("Right");
-  (*options->mutable_label_items())[0] = std::move(left_hand);
-  (*options->mutable_label_items())[1] = std::move(right_hand);
+  (*options->mutable_label_items())[0] = std::move(right_hand);
+  (*options->mutable_label_items())[1] = std::move(left_hand);
 }
 
 void ConfigureHandRectTransformationCalculator(
@@ -479,7 +478,7 @@ class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
         graph[Output<std::vector<bool>>(kPresenceTag)];
     hand_landmark_detection_outputs.presence_scores >>
         graph[Output<std::vector<float>>(kPresenceScoreTag)];
-    hand_landmark_detection_outputs.handednesses >>
+    hand_landmark_detection_outputs.handedness >>
         graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
 
     return graph.GetConfig();
@@ -563,7 +562,7 @@ class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
         /* hand_rects_next_frame= */ hand_rects_next_frame,
         /* presences= */ presences,
         /* presence_scores= */ presence_scores,
-        /* handednesses= */ handednesses,
+        /* handedness= */ handednesses,
     }};
   }
 };
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc
index bbf3a7cde..5af62e11a 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2022 The MediaPipe Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -319,15 +319,15 @@ TEST_P(MultiHandLandmarkerTest, Succeeds) {
 
   const std::vector<bool>& presences =
       (*output_packets)[kPresenceName].Get<std::vector<bool>>();
-  const std::vector<ClassificationList>& handednesses =
+  const std::vector<ClassificationList>& handedness =
      (*output_packets)[kHandednessName].Get<std::vector<ClassificationList>>();
   const std::vector<NormalizedLandmarkList>& landmark_lists =
       (*output_packets)[kLandmarksName]
          .Get<std::vector<NormalizedLandmarkList>>();
 
   EXPECT_THAT(presences, ElementsAreArray(GetParam().expected_presences));
-  EXPECT_THAT(handednesses, Pointwise(Partially(EqualsProto()),
-                                      GetParam().expected_handedness));
+  EXPECT_THAT(handedness, Pointwise(Partially(EqualsProto()),
+                                    GetParam().expected_handedness));
   EXPECT_THAT(
       landmark_lists,
       Pointwise(Approximately(Partially(EqualsProto()), /*margin=*/kAbsMargin,
@@ -342,7 +342,7 @@ INSTANTIATE_TEST_SUITE_P(
             .test_name = "HandLandmarkerLiteModelRightUpHand",
             .input_model_name = kHandLandmarkerLiteModel,
             .test_image_name = kRightHandsImage,
-            .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, 0),
+            .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, 0),
            .expected_presence = true,
            .expected_landmarks =
                GetExpectedLandmarkList(kExpectedRightUpHandLandmarksFilename),
@@ -352,7 +352,7 @@ INSTANTIATE_TEST_SUITE_P(
            .test_name = "HandLandmarkerLiteModelRightDownHand",
            .input_model_name = kHandLandmarkerLiteModel,
            .test_image_name = kRightHandsImage,
-           .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI),
+           .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI),
            .expected_presence = true,
            .expected_landmarks = GetExpectedLandmarkList(
                kExpectedRightDownHandLandmarksFilename),
@@ -362,7 +362,7 @@ INSTANTIATE_TEST_SUITE_P(
            .test_name = "HandLandmarkerFullModelRightUpHand",
            .input_model_name = kHandLandmarkerFullModel,
            .test_image_name = kRightHandsImage,
-           .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, 0),
+           .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, 0),
            .expected_presence = true,
            .expected_landmarks =
                GetExpectedLandmarkList(kExpectedRightUpHandLandmarksFilename),
@@ -372,7 +372,7 @@ INSTANTIATE_TEST_SUITE_P(
            .test_name = "HandLandmarkerFullModelRightDownHand",
            .input_model_name = kHandLandmarkerFullModel,
            .test_image_name = kRightHandsImage,
-           .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI),
+           .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI),
            .expected_presence = true,
            .expected_landmarks = GetExpectedLandmarkList(
                kExpectedRightDownHandLandmarksFilename),
@@ -382,7 +382,7 @@ INSTANTIATE_TEST_SUITE_P(
            .test_name = "HandLandmarkerLiteModelLeftUpHand",
            .input_model_name = kHandLandmarkerLiteModel,
            .test_image_name = kLeftHandsImage,
-           .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, 0),
+           .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, 0),
            .expected_presence = true,
            .expected_landmarks =
                GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename),
@@ -392,7 +392,7 @@ INSTANTIATE_TEST_SUITE_P(
            .test_name = "HandLandmarkerLiteModelLeftDownHand",
            .input_model_name = kHandLandmarkerLiteModel,
            .test_image_name = kLeftHandsImage,
-           .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI),
+           .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI),
            .expected_presence = true,
            .expected_landmarks =
                GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename),
@@ -402,7 +402,7 @@ INSTANTIATE_TEST_SUITE_P(
            .test_name = "HandLandmarkerFullModelLeftUpHand",
            .input_model_name = kHandLandmarkerFullModel,
            .test_image_name = kLeftHandsImage,
-           .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, 0),
+           .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, 0),
            .expected_presence = true,
            .expected_landmarks =
                GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename),
@@ -412,7 +412,7 @@
INSTANTIATE_TEST_SUITE_P( .test_name = "HandLandmarkerFullModelLeftDownHand", .input_model_name = kHandLandmarkerFullModel, .test_image_name = kLeftHandsImage, - .hand_rect = MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI), + .hand_rect = MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI), .expected_presence = true, .expected_landmarks = GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename), @@ -431,8 +431,8 @@ INSTANTIATE_TEST_SUITE_P( .test_image_name = kRightHandsImage, .hand_rects = { - MakeHandRect(0.25, 0.5, 0.5, 1.0, 0), - MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI), + MakeHandRect(0.75, 0.5, 0.5, 1.0, 0), + MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI), }, .expected_presences = {true, true}, .expected_landmark_lists = @@ -449,8 +449,8 @@ INSTANTIATE_TEST_SUITE_P( .test_image_name = kLeftHandsImage, .hand_rects = { - MakeHandRect(0.75, 0.5, 0.5, 1.0, 0), - MakeHandRect(0.25, 0.5, 0.5, 1.0, M_PI), + MakeHandRect(0.25, 0.5, 0.5, 1.0, 0), + MakeHandRect(0.75, 0.5, 0.5, 1.0, M_PI), }, .expected_presences = {true, true}, .expected_landmark_lists = diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD index 945b12f3e..8097d7ab1 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,3 +41,5 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto", ], ) + +# TODO: open source hand joints graph diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index d0edf99c0..1ce2305f3 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index a2d520963..04f62c39c 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 514e601ef..86ac7680c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 7dd410e83..5b885045b 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 9b0c376ae..96050cbd0 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 0adcf842d..5e4363588 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index e8812d9fd..88eeb9e4d 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD index 29638bebd..e58efe2fd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 24b126a35..7a84040fd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index d729eaf1a..7d22302a1 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index 1425e97cc..53d7c7c9d 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h index 9320cbc35..586b8cdca 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 95c4ff379..61d546c7e 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 7a0e9e9dc..6d994789d 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD b/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD index ecf8b0242..297675019 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 24ee866f2..5441aac41 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_generator/BUILD b/mediapipe/tasks/cc/vision/image_generator/BUILD new file mode 100644 index 000000000..71b8230ae --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/BUILD @@ -0,0 +1,136 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +cc_library( + name = "conditioned_image_graph", + srcs = ["conditioned_image_graph.cc"], + deps = [ + "//mediapipe/calculators/core:get_vector_item_calculator", + "//mediapipe/calculators/core:get_vector_item_calculator_cc_proto", + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_connections", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:image_frame_util", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "image_generator_graph", + srcs = ["image_generator_graph.cc"], + deps = [ + 
":conditioned_image_graph", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/image:image_transformation_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:stream_handler_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container", + "//mediapipe/framework/tool:switch_container_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:diffusion_plugins_output_calculator", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "image_generator_result", + hdrs = ["image_generator_result.h"], + deps = ["//mediapipe/framework/formats:image"], +) + +cc_library( + name = "image_generator", + srcs = ["image_generator.cc"], + hdrs = ["image_generator.h"], + deps = [ + ":image_generator_graph", + ":image_generator_result", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter", + 
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) diff --git a/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc new file mode 100644 index 000000000..c85fe981c --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc @@ -0,0 +1,458 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/core/get_vector_item_calculator.h" +#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" +#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/image_frame_util.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +namespace internal { + +// Helper postprocessing calculator for depth condition type to scale raw depth +// inference result to 0-255 uint8. 
+class DepthImagePostprocessingCalculator : public api2::Node { + public: + static constexpr api2::Input<Image> kImageIn{"IMAGE"}; + static constexpr api2::Output<Image> kImageOut{"IMAGE"}; + + MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut); + + absl::Status Process(CalculatorContext* cc) final { + if (kImageIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + Image raw_depth_image = kImageIn(cc).Get(); + cv::Mat raw_depth_mat = mediapipe::formats::MatView( + raw_depth_image.GetImageFrameSharedPtr().get()); + cv::Mat depth_mat; + cv::normalize(raw_depth_mat, depth_mat, 255, 0, cv::NORM_MINMAX); + depth_mat.convertTo(depth_mat, CV_8UC3, 1, 0); + cv::cvtColor(depth_mat, depth_mat, cv::COLOR_GRAY2RGB); + // Acquires the cv::Mat data and assigns it to the image frame. + ImageFrameSharedPtr depth_image_frame_ptr = std::make_shared<ImageFrame>( + mediapipe::ImageFormat::SRGB, depth_mat.cols, depth_mat.rows, + depth_mat.step, depth_mat.data, + [depth_mat](uint8_t[]) { depth_mat.~Mat(); }); + Image depth_image(depth_image_frame_ptr); + kImageOut(cc).Send(depth_image); + return absl::OkStatus(); + } +}; + +// NOLINTBEGIN: Node registration doesn't work when part of calculator name is +// moved to next line. +// clang-format off +MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::DepthImagePostprocessingCalculator); +// clang-format on +// NOLINTEND + +// Calculator to detect edges in the image with OpenCV Canny edge detection. +class CannyEdgeCalculator : public api2::Node { + public: + static constexpr api2::Input<Image> kImageIn{"IMAGE"}; + static constexpr api2::Output<Image> kImageOut{"IMAGE"}; + + MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut); + + absl::Status Process(CalculatorContext* cc) final { + if (kImageIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + Image input_image = kImageIn(cc).Get(); + cv::Mat input_image_mat = + mediapipe::formats::MatView(input_image.GetImageFrameSharedPtr().get()); + const auto& options = cc->Options< + proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>(); + cv::Mat luminance; + cv::cvtColor(input_image_mat, luminance, cv::COLOR_RGB2GRAY); + cv::Mat edges_mat; + cv::Canny(luminance, edges_mat, options.threshold_1(), + options.threshold_2(), options.aperture_size(), + options.l2_gradient()); + cv::normalize(edges_mat, edges_mat, 255, 0, cv::NORM_MINMAX); + edges_mat.convertTo(edges_mat, CV_8UC3, 1, 0); + cv::cvtColor(edges_mat, edges_mat, cv::COLOR_GRAY2RGB); + // Acquires the cv::Mat data and assigns it to the image frame. + ImageFrameSharedPtr edges_image_frame_ptr = std::make_shared<ImageFrame>( + mediapipe::ImageFormat::SRGB, edges_mat.cols, edges_mat.rows, + edges_mat.step, edges_mat.data, + [edges_mat](uint8_t[]) { edges_mat.~Mat(); }); + Image edges_image(edges_image_frame_ptr); + kImageOut(cc).Send(edges_image); + return absl::OkStatus(); + } +}; + +// NOLINTBEGIN: Node registration doesn't work when part of calculator name is +// moved to next line.
+// clang-format off +MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::CannyEdgeCalculator); +// clang-format on +// NOLINTEND + +} // namespace internal + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; + +constexpr absl::string_view kImageTag = "IMAGE"; +constexpr absl::string_view kUImageTag = "UIMAGE"; +constexpr absl::string_view kNormLandmarksTag = "NORM_LANDMARKS"; +constexpr absl::string_view kVectorTag = "VECTOR"; +constexpr absl::string_view kItemTag = "ITEM"; +constexpr absl::string_view kRenderDataTag = "RENDER_DATA"; +constexpr absl::string_view kConfidenceMaskTag = "CONFIDENCE_MASK:0"; + +enum ColorType { + WHITE = 0, + GREEN = 1, + RED = 2, + BLACK = 3, + BLUE = 4, +}; + +mediapipe::Color GetColor(ColorType color_type) { + mediapipe::Color color; + switch (color_type) { + case WHITE: + color.set_b(255); + color.set_g(255); + color.set_r(255); + break; + case GREEN: + color.set_b(0); + color.set_g(255); + color.set_r(0); + break; + case RED: + color.set_b(0); + color.set_g(0); + color.set_r(255); + break; + case BLACK: + color.set_b(0); + color.set_g(0); + color.set_r(0); + break; + case BLUE: + color.set_b(255); + color.set_g(0); + color.set_r(0); + break; + } + return color; +} + +// Get LandmarksToRenderDataCalculatorOptions for rendering face landmarks +// connections. +mediapipe::LandmarksToRenderDataCalculatorOptions +GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>> connections, ColorType color_type) { + mediapipe::LandmarksToRenderDataCalculatorOptions render_options; + render_options.set_thickness(1); + render_options.set_visualize_landmark_depth(false); + render_options.set_render_landmarks(false); + *render_options.mutable_connection_color() = GetColor(color_type); + for (const auto& connection : connections) { + render_options.add_landmark_connections(connection[0]); + render_options.add_landmark_connections(connection[1]); + } + return render_options; +} + +Source<mediapipe::RenderData> GetFaceLandmarksRenderData( + Source<mediapipe::NormalizedLandmarkList> face_landmarks, + const mediapipe::LandmarksToRenderDataCalculatorOptions& + landmarks_to_render_data_options, + Graph& graph) { + auto& landmarks_to_render_data = + graph.AddNode("LandmarksToRenderDataCalculator"); + landmarks_to_render_data + .GetOptions<mediapipe::LandmarksToRenderDataCalculatorOptions>() + .CopyFrom(landmarks_to_render_data_options); + face_landmarks >> landmarks_to_render_data.In(kNormLandmarksTag); + return landmarks_to_render_data.Out(kRenderDataTag) + .Cast<mediapipe::RenderData>(); +} + +// Add FaceLandmarkerGraph to detect the face landmarks in the given face image, +// and generate a face mesh guidance image for the diffusion plugin model. +absl::StatusOr<Source<Image>> GetFaceLandmarksImage( + Source<Image> face_image, + const proto::ConditionedImageGraphOptions::FaceConditionTypeOptions& + face_condition_type_options, + Graph& graph) { + if (face_condition_type_options.face_landmarker_graph_options() + .face_detector_graph_options() + .num_faces() != 1) { + return absl::InvalidArgumentError( + "Only supports face landmarks of a single face as the guidance image."); + } + + // Detect face landmarks.
+ auto& face_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"); + face_landmarker_graph + .GetOptions<face_landmarker::proto::FaceLandmarkerGraphOptions>() + .CopyFrom(face_condition_type_options.face_landmarker_graph_options()); + face_image >> face_landmarker_graph.In(kImageTag); + auto face_landmarks_lists = + face_landmarker_graph.Out(kNormLandmarksTag) + .Cast<std::vector<mediapipe::NormalizedLandmarkList>>(); + + // Get the single face landmarks. + auto& get_vector_item = + graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); + get_vector_item.GetOptions<mediapipe::GetVectorItemCalculatorOptions>() + .set_item_index(0); + face_landmarks_lists >> get_vector_item.In(kVectorTag); + auto single_face_landmarks = + get_vector_item.Out(kItemTag).Cast<mediapipe::NormalizedLandmarkList>(); + + // Convert face landmarks to render data. + auto face_oval = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval + .size()), + ColorType::WHITE), + graph); + auto lips = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips + .size()), + ColorType::WHITE), + graph); + auto left_eye = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye + .size()), + ColorType::GREEN), + graph); + auto left_eye_brow = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksLeftEyeBrow.data(), + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksLeftEyeBrow.size()), + ColorType::GREEN), + graph); + auto left_iris = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris + .size()), + ColorType::GREEN), + graph); + + auto right_eye = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye + .size()), + ColorType::BLUE), + graph); + auto right_eye_brow = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksRightEyeBrow.data(), + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksRightEyeBrow.size()), + ColorType::BLUE), + graph); + auto right_iris = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span<const std::array<int, 2>>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris + .size()), + ColorType::BLUE), + graph); + + // Create a black canvas image with same size as face image.
+ auto& flat_color = graph.AddNode("FlatColorImageCalculator"); + flat_color.GetOptions<mediapipe::FlatColorImageCalculatorOptions>() + .mutable_color() + ->set_r(0); + face_image >> flat_color.In(kImageTag); + auto blank_canvas = flat_color.Out(kImageTag); + + // Draw render data on the canvas image. + auto& annotation_overlay = graph.AddNode("AnnotationOverlayCalculator"); + blank_canvas >> annotation_overlay.In(kUImageTag); + face_oval >> annotation_overlay.In(0); + lips >> annotation_overlay.In(1); + left_eye >> annotation_overlay.In(2); + left_eye_brow >> annotation_overlay.In(3); + left_iris >> annotation_overlay.In(4); + right_eye >> annotation_overlay.In(5); + right_eye_brow >> annotation_overlay.In(6); + right_iris >> annotation_overlay.In(7); + return annotation_overlay.Out(kUImageTag).Cast<Image>(); +} + +absl::StatusOr<Source<Image>> GetDepthImage( + Source<Image> image, + const image_generator::proto::ConditionedImageGraphOptions:: + DepthConditionTypeOptions& depth_condition_type_options, + Graph& graph) { + auto& image_segmenter_graph = graph.AddNode( + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"); + image_segmenter_graph + .GetOptions<image_segmenter::proto::ImageSegmenterGraphOptions>() + .CopyFrom(depth_condition_type_options.image_segmenter_graph_options()); + image >> image_segmenter_graph.In(kImageTag); + auto raw_depth_image = image_segmenter_graph.Out(kConfidenceMaskTag); + + auto& depth_postprocessing = graph.AddNode( + "mediapipe.tasks.vision.image_generator.internal." + "DepthImagePostprocessingCalculator"); + raw_depth_image >> depth_postprocessing.In(kImageTag); + return depth_postprocessing.Out(kImageTag).Cast<Image>(); +} + +absl::StatusOr<Source<Image>> GetEdgeImage( + Source<Image> image, + const image_generator::proto::ConditionedImageGraphOptions:: + EdgeConditionTypeOptions& edge_condition_type_options, + Graph& graph) { + auto& edge_detector = graph.AddNode( + "mediapipe.tasks.vision.image_generator.internal." + "CannyEdgeCalculator"); + edge_detector + .GetOptions< + proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>() + .CopyFrom(edge_condition_type_options); + image >> edge_detector.In(kImageTag); + return edge_detector.Out(kImageTag).Cast<Image>(); +} + +} // namespace + +// A mediapipe.tasks.vision.image_generator.ConditionedImageGraph converts the +// input image to an image of condition type. The output image can be used as +// input for the diffusion model with control plugin. +// Inputs: +// IMAGE - Image +// Conditioned image to generate the image for diffusion plugin model. +// +// Outputs: +// IMAGE - Image +// The guidance image used as input for the diffusion plugin model. +class ConditionedImageGraph : public core::ModelTaskGraph { + public: + absl::StatusOr<CalculatorGraphConfig> GetConfig( + SubgraphContext* sc) override { + Graph graph; + auto& graph_options = + *sc->MutableOptions<proto::ConditionedImageGraphOptions>(); + Source<Image> conditioned_image = graph.In(kImageTag).Cast<Image>(); + // Configure the guidance graph and get the guidance image if the guidance + // graph options are set.
+ switch (graph_options.condition_type_options_case()) { + case proto::ConditionedImageGraphOptions::CONDITION_TYPE_OPTIONS_NOT_SET: + return absl::InvalidArgumentError( + "Conditioned type options is not set."); + break; + case proto::ConditionedImageGraphOptions::kFaceConditionTypeOptions: { + ASSIGN_OR_RETURN( + auto face_landmarks_image, + GetFaceLandmarksImage(conditioned_image, + graph_options.face_condition_type_options(), + graph)); + face_landmarks_image >> graph.Out(kImageTag); + } break; + case proto::ConditionedImageGraphOptions::kDepthConditionTypeOptions: { + ASSIGN_OR_RETURN( + auto depth_image, + GetDepthImage(conditioned_image, + graph_options.depth_condition_type_options(), graph)); + depth_image >> graph.Out(kImageTag); + } break; + case proto::ConditionedImageGraphOptions::kEdgeConditionTypeOptions: { + ASSIGN_OR_RETURN( + auto edges_image, + GetEdgeImage(conditioned_image, + graph_options.edge_condition_type_options(), graph)); + edges_image >> graph.Out(kImageTag); + } break; + } + return graph.GetConfig(); + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_generator::ConditionedImageGraph); + +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc new file mode 100644 index 000000000..c67ae2fe9 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +namespace { + +using ::mediapipe::Image; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kFaceLandmarkerModel[] = "face_landmarker_v2.task"; +constexpr char kDepthModel[] = + "mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt.tflite"; +constexpr char kPortraitImage[] = "portrait.jpg"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageInStream[] = "image_in"; +constexpr char kImageOutStream[] = "image_out"; + +// Helper function to create a TaskRunner for the ConditionedImageGraph.
+absl::StatusOr<std::unique_ptr<TaskRunner>> +CreateConditionedImageGraphTaskRunner( + std::unique_ptr<proto::ConditionedImageGraphOptions> options) { + Graph graph; + auto& conditioned_image_graph = graph.AddNode( + "mediapipe.tasks.vision.image_generator.ConditionedImageGraph"); + conditioned_image_graph.GetOptions<proto::ConditionedImageGraphOptions>() + .Swap(options.get()); + graph.In(kImageTag).Cast<Image>().SetName(kImageInStream) >> + conditioned_image_graph.In(kImageTag); + conditioned_image_graph.Out(kImageTag).SetName(kImageOutStream) >> + graph.Out(kImageTag).Cast<Image>(); + return core::TaskRunner::Create( + graph.GetConfig(), + absl::make_unique<core::MediaPipeBuiltinOpResolver>()); +} + +TEST(ConditionedImageGraphTest, SucceedsFaceLandmarkerConditionType) { + auto options = std::make_unique<proto::ConditionedImageGraphOptions>(); + options->mutable_face_condition_type_options() + ->mutable_face_landmarker_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name( + file::JoinPath("./", kTestDataDirectory, kFaceLandmarkerModel)); + options->mutable_face_condition_type_options() + ->mutable_face_landmarker_graph_options() + ->mutable_face_detector_graph_options() + ->set_num_faces(1); + MP_ASSERT_OK_AND_ASSIGN( + auto runner, CreateConditionedImageGraphTaskRunner(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory, + kPortraitImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}})); + const auto& output_image = output_packets[kImageOutStream].Get<Image>(); + MP_EXPECT_OK(SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), + "face_landmarks_image")); +} + +TEST(ConditionedImageGraphTest, SucceedsDepthConditionType) { + auto options = std::make_unique<proto::ConditionedImageGraphOptions>(); + options->mutable_depth_condition_type_options() + ->mutable_image_segmenter_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(file::JoinPath("./", kTestDataDirectory, kDepthModel)); + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory, + kPortraitImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto runner, CreateConditionedImageGraphTaskRunner(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}})); + const auto& output_image = output_packets[kImageOutStream].Get<Image>(); + MP_EXPECT_OK( + SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "depth_image")); +} + +TEST(ConditionedImageGraphTest, SucceedsEdgeConditionType) { + auto options = std::make_unique<proto::ConditionedImageGraphOptions>(); + auto edge_condition_type_options = + options->mutable_edge_condition_type_options(); + edge_condition_type_options->set_threshold_1(100); + edge_condition_type_options->set_threshold_2(200); + edge_condition_type_options->set_aperture_size(3); + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory, + kPortraitImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto runner, CreateConditionedImageGraphTaskRunner(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}})); + const auto& output_image = output_packets[kImageOutStream].Get<Image>(); + MP_EXPECT_OK( + SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "edges_image")); +} + +} // namespace +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD b/mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD new file mode 100644 index
000000000..fe10affa1 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD @@ -0,0 +1,70 @@ +# Copyright 2022 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +cc_library( + name = "diffuser_gpu_header", + hdrs = ["diffuser_gpu.h"], + visibility = [ + "//mediapipe/tasks/cc/vision/image_generator/diffuser:__pkg__", + ], +) + +mediapipe_proto_library( + name = "stable_diffusion_iterate_calculator_proto", + srcs = ["stable_diffusion_iterate_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "stable_diffusion_iterate_calculator", + srcs = ["stable_diffusion_iterate_calculator.cc"], + deps = [ + ":diffuser_gpu_header", + ":stable_diffusion_iterate_calculator_cc_proto", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/deps:file_helpers", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +cc_library( + name = "diffusion_plugins_output_calculator", + srcs = ["diffusion_plugins_output_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h new file mode 100644 index 000000000..85738b80b --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h @@ -0,0 +1,87 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_ + +#include <limits.h> +#include <stdint.h> + +#ifndef DG_EXPORT +#define DG_EXPORT __attribute__((visibility("default"))) +#endif // DG_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +enum DiffuserModelType { + kDiffuserModelTypeSd1, + kDiffuserModelTypeGldm, + kDiffuserModelTypeDistilledGldm, + kDiffuserModelTypeSd2Base, + kDiffuserModelTypeTigo, +}; + +enum DiffuserPriorityHint { + kDiffuserPriorityHintHigh, + kDiffuserPriorityHintNormal, + kDiffuserPriorityHintLow, +}; + +enum DiffuserPerformanceHint { + kDiffuserPerformanceHintHigh, + kDiffuserPerformanceHintNormal, + kDiffuserPerformanceHintLow, +}; + +typedef struct { + DiffuserPriorityHint priority_hint; + DiffuserPerformanceHint performance_hint; +} DiffuserEnvironmentOptions; + +typedef struct { + DiffuserModelType model_type; + char model_dir[PATH_MAX]; + char lora_dir[PATH_MAX]; + const void* lora_weights_layer_mapping; + int lora_rank; + int seed; + int image_width; + int image_height; + int run_unet_with_plugins; + DiffuserEnvironmentOptions env_options; +} DiffuserConfig; + +typedef struct { + void* diffuser; +} DiffuserContext; + +typedef struct { + int shape[4]; + const float* data; +} DiffuserPluginTensor; + +DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT +DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT + const char*, int, int, float, const void*); +DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT +DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT +DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_ diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc new file mode 100644 index 000000000..a2282b907 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc @@ -0,0 +1,66 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> +#include <utility> +#include <vector> + +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { +namespace api2 { + +// In iteration mode, output the image guidance tensors at the current timestamp +// and advance the output stream timestamp bound by the number of steps. +// Otherwise, output the image guidance tensors at the current timestamp only.
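+// For example, with STEPS = 20 and ITERATION connected, the tensors sent at +// timestamp T are followed by SetNextTimestampBound(T + 20), so downstream +// calculators do not block waiting for per-iteration packets (a summary of +// the contract implemented in Process() below).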
+class DiffusionPluginsOutputCalculator : public Node { + public: + static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"}; + static constexpr Input<int> kStepsIn{"STEPS"}; + static constexpr Input<int>::Optional kIterationIn{"ITERATION"}; + static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kStepsIn, kIterationIn, kTensorsOut); + + absl::Status Process(CalculatorContext* cc) override { + if (kTensorsIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + // Consumes the tensor vector to avoid data copy. + absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> status_or_tensor = + cc->Inputs().Tag("TENSORS").Value().Consume<std::vector<Tensor>>(); + if (!status_or_tensor.ok()) { + return absl::InternalError("Input tensor vector is not consumable."); + } + if (kIterationIn(cc).IsConnected()) { + ABSL_CHECK_EQ(kIterationIn(cc).Get(), 0); + kTensorsOut(cc).Send(std::move(*status_or_tensor.value())); + kTensorsOut(cc).SetNextTimestampBound(cc->InputTimestamp() + + kStepsIn(cc).Get()); + } else { + kTensorsOut(cc).Send(std::move(*status_or_tensor.value())); + } + return absl::OkStatus(); + } +}; + +MEDIAPIPE_REGISTER_NODE(DiffusionPluginsOutputCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc new file mode 100644 index 000000000..f7eb7c1b6 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc @@ -0,0 +1,299 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include <dlfcn.h> + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_helpers.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h" +#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +DiffuserPriorityHint ToDiffuserPriorityHint( + StableDiffusionIterateCalculatorOptions::ClPriorityHint priority) { + switch (priority) { + case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_LOW: + return kDiffuserPriorityHintLow; + case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_NORMAL: + return kDiffuserPriorityHintNormal; + case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_HIGH: + return kDiffuserPriorityHintHigh; + } + return kDiffuserPriorityHintNormal; +} + +DiffuserModelType ToDiffuserModelType( + StableDiffusionIterateCalculatorOptions::ModelType model_type) { + switch (model_type) { + case StableDiffusionIterateCalculatorOptions::DEFAULT: + case StableDiffusionIterateCalculatorOptions::SD_1: + return kDiffuserModelTypeSd1; + } + return kDiffuserModelTypeSd1; +} + +} // namespace + +// Runs diffusion models including, but not limited to, Stable Diffusion & gLDM. +// +// Inputs: +// PROMPT - std::string +// The prompt used to generate the image. +// STEPS - int +// The number of steps to run the UNet. +// ITERATION - int +// The iteration of the current run. +// PLUGIN_TENSORS - std::vector<Tensor> @Optional +// The output tensor vector of the diffusion plugins model. +// PLUGIN_STRENGTH - float @Optional +// The strength of the plugin tensors. +// SHOW_RESULT - bool @Optional +// Whether to show the diffusion result at the current step, regardless +// of what show_every_n_iteration is set to. +// +// Outputs: +// IMAGE - mediapipe::ImageFrame +// The image generated by the Stable Diffusion model from the input prompt. +// The output image is in RGB format.
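+// In iteration mode, IMAGE is only decoded and sent every +// show_every_n_iteration steps, on the last step, or when SHOW_RESULT is +// true; otherwise an empty packet is sent if emit_empty_packet is enabled +// (behavior summarized from the Process() implementation below).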
+// +// Example: +// node { +// calculator: "StableDiffusionIterateCalculator" +// input_stream: "PROMPT:prompt" +// input_stream: "STEPS:steps" +// output_stream: "IMAGE:result" +// options { +// [mediapipe.StableDiffusionIterateCalculatorOptions.ext] { +// base_seed: 0 +// model_type: SD_1 +// } +// } +// } +class StableDiffusionIterateCalculator : public Node { + public: + static constexpr Input<std::string> kPromptIn{"PROMPT"}; + static constexpr Input<int> kStepsIn{"STEPS"}; + static constexpr Input<int>::Optional kIterationIn{"ITERATION"}; + static constexpr Input<int>::Optional kRandSeedIn{"RAND_SEED"}; + static constexpr SideInput<StableDiffusionIterateCalculatorOptions>::Optional + kOptionsIn{"OPTIONS"}; + static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{ + "PLUGIN_TENSORS"}; + static constexpr Input<float>::Optional kPluginStrengthIn{"PLUGIN_STRENGTH"}; + static constexpr Input<bool>::Optional kShowResultIn{"SHOW_RESULT"}; + static constexpr Output<ImageFrame> kImageOut{"IMAGE"}; + MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn, + kPlugInTensorsIn, kPluginStrengthIn, kShowResultIn, + kOptionsIn, kImageOut); + + ~StableDiffusionIterateCalculator() { + if (context_) DiffuserDelete(); + if (handle_) dlclose(handle_); + } + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::vector<DiffuserPluginTensor> GetPluginTensors( + CalculatorContext* cc) const { + if (!kPlugInTensorsIn(cc).IsConnected()) return {}; + std::vector<DiffuserPluginTensor> diffuser_tensors; + diffuser_tensors.reserve(kPlugInTensorsIn(cc)->size()); + for (const auto& mp_tensor : *kPlugInTensorsIn(cc)) { + DiffuserPluginTensor diffuser_tensor; + diffuser_tensor.shape[0] = mp_tensor.shape().dims[0]; + diffuser_tensor.shape[1] = mp_tensor.shape().dims[1]; + diffuser_tensor.shape[2] = mp_tensor.shape().dims[2]; + diffuser_tensor.shape[3] = mp_tensor.shape().dims[3]; + diffuser_tensor.data = mp_tensor.GetCpuReadView().buffer<float>(); + diffuser_tensors.push_back(diffuser_tensor); + } + return diffuser_tensors; + } + + absl::Status LoadDiffuser() { + handle_ = dlopen("libimagegenerator_gpu.so", RTLD_NOW | RTLD_LOCAL); + RET_CHECK(handle_) << dlerror(); + create_ptr_ = reinterpret_cast<decltype(create_ptr_)>( + dlsym(handle_, "DiffuserCreate")); + RET_CHECK(create_ptr_) << dlerror(); + reset_ptr_ = + reinterpret_cast<decltype(reset_ptr_)>(dlsym(handle_, "DiffuserReset")); + RET_CHECK(reset_ptr_) << dlerror(); + iterate_ptr_ = reinterpret_cast<decltype(iterate_ptr_)>( + dlsym(handle_, "DiffuserIterate")); + RET_CHECK(iterate_ptr_) << dlerror(); + decode_ptr_ = reinterpret_cast<decltype(decode_ptr_)>( + dlsym(handle_, "DiffuserDecode")); + RET_CHECK(decode_ptr_) << dlerror(); + delete_ptr_ = reinterpret_cast<decltype(delete_ptr_)>( + dlsym(handle_, "DiffuserDelete")); + RET_CHECK(delete_ptr_) << dlerror(); + return absl::OkStatus(); + } + + DiffuserContext* DiffuserCreate(const DiffuserConfig* a) { + return (*create_ptr_)(a); + } + bool DiffuserReset(const char* a, int b, int c, float d, + const std::vector<DiffuserPluginTensor>* e) { + return (*reset_ptr_)(context_, a, b, c, d, e); + } + bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); } + bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); } + void DiffuserDelete() { (*delete_ptr_)(context_); } + + void* handle_ = nullptr; + DiffuserContext* context_ = nullptr; + DiffuserContext* (*create_ptr_)(const DiffuserConfig*); + int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float, + const void*); + int (*iterate_ptr_)(DiffuserContext*, int, int); + int (*decode_ptr_)(DiffuserContext*, uint8_t*); + void (*delete_ptr_)(DiffuserContext*); + + int show_every_n_iteration_; + bool emit_empty_packet_; +};
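+ +// A minimal sketch of the call sequence this calculator drives through the +// dlopen/dlsym indirection above, assuming libimagegenerator_gpu.so exports +// the C API declared in diffuser_gpu.h; the config fields and literal values +// here are illustrative only: +// +// DiffuserConfig config = {}; +// config.model_type = kDiffuserModelTypeSd1; +// std::strcpy(config.model_dir, "bins/"); +// config.image_width = 512; +// config.image_height = 512; +// DiffuserContext* ctx = DiffuserCreate(&config); +// DiffuserReset(ctx, "a prompt", /*steps=*/20, /*rand_seed=*/0, +// /*plugins_strength=*/1.0f, /*plugin_tensors=*/nullptr); +// for (int i = 0; i < 20; ++i) DiffuserIterate(ctx, /*steps=*/20, i); +// std::vector<uint8_t> rgb(512 * 512 * 3); +// DiffuserDecode(ctx, rgb.data()); +// DiffuserDelete(ctx);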
+ +absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) { + StableDiffusionIterateCalculatorOptions options; + if (kOptionsIn(cc).IsEmpty()) { + options = cc->Options<StableDiffusionIterateCalculatorOptions>(); + } else { + options = kOptionsIn(cc).Get(); + } + show_every_n_iteration_ = options.show_every_n_iteration(); + emit_empty_packet_ = options.emit_empty_packet(); + + MP_RETURN_IF_ERROR(LoadDiffuser()); + + DiffuserConfig config; + config.model_type = ToDiffuserModelType(options.model_type()); + if (options.file_folder().empty()) { + std::strcpy(config.model_dir, "bins/"); // NOLINT + } else { + std::strcpy(config.model_dir, options.file_folder().c_str()); // NOLINT + } + MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir)) + << config.model_dir; + RET_CHECK(options.lora_file_folder().empty() || + options.lora_weights_layer_mapping().empty()) + << "Can't set both lora_file_folder and lora_weights_layer_mapping."; + std::strcpy(config.lora_dir, options.lora_file_folder().c_str()); // NOLINT + std::map<std::string, char*> lora_weights_layer_mapping; + for (auto& layer_name_and_weights : options.lora_weights_layer_mapping()) { + lora_weights_layer_mapping[layer_name_and_weights.first] = + (char*)layer_name_and_weights.second; + } + config.lora_weights_layer_mapping = !lora_weights_layer_mapping.empty() + ? &lora_weights_layer_mapping + : nullptr; + config.lora_rank = options.lora_rank(); + config.seed = options.base_seed(); + config.image_width = options.output_image_width(); + config.image_height = options.output_image_height(); + config.run_unet_with_plugins = kPlugInTensorsIn(cc).IsConnected(); + config.env_options = { + .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()), + .performance_hint = kDiffuserPerformanceHintHigh, + }; + RET_CHECK(options.plugins_strength() >= 0.0f && + options.plugins_strength() <= 1.0f) + << "The value of plugins_strength must be in the range of [0, 1]."; + context_ = DiffuserCreate(&config); + RET_CHECK(context_); + return absl::OkStatus(); +} + +absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) { + const auto& options = + cc->Options().GetExtension(StableDiffusionIterateCalculatorOptions::ext); + const std::string& prompt = *kPromptIn(cc); + const int steps = *kStepsIn(cc); + const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc)) + : options.base_seed(); + float plugins_strength = options.plugins_strength(); + if (kPluginStrengthIn(cc).IsConnected() && !kPluginStrengthIn(cc).IsEmpty()) { + plugins_strength = kPluginStrengthIn(cc).Get(); + RET_CHECK(plugins_strength >= 0.0f && plugins_strength <= 1.0f) + << "The value of plugins_strength must be in the range of [0, 1]."; + } + + if (kIterationIn(cc).IsEmpty()) { + const auto plugin_tensors = GetPluginTensors(cc); + RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, plugins_strength, + &plugin_tensors)); + for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i)); + ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(), + options.output_image_height()); + RET_CHECK(DiffuserDecode(image_out.MutablePixelData())); + kImageOut(cc).Send(std::move(image_out)); + } else { + const int iteration = *kIterationIn(cc); + RET_CHECK_LT(iteration, steps); + + // Extract text embedding on first iteration.
+ if (iteration == 0) { + const auto plugin_tensors = GetPluginTensors(cc); + RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, + plugins_strength, &plugin_tensors)); + } + + RET_CHECK(DiffuserIterate(steps, iteration)); + + bool force_show_result = kShowResultIn(cc).IsConnected() && + !kShowResultIn(cc).IsEmpty() && + kShowResultIn(cc).Get(); + bool show_result = force_show_result || + (iteration + 1) % show_every_n_iteration_ == 0 || + iteration == steps - 1; + // Decode the output and send out the image for visualization. + if (show_result) { + ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(), + options.output_image_height()); + RET_CHECK(DiffuserDecode(image_out.MutablePixelData())); + kImageOut(cc).Send(std::move(image_out)); + } else if (emit_empty_packet_) { + kImageOut(cc).Send(Packet()); + } + } + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(StableDiffusionIterateCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto new file mode 100644 index 000000000..48d7c1a65 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto @@ -0,0 +1,84 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +option java_package = "com.google.mediapipe.calculator.proto"; +option java_outer_classname = "StableDiffusionIterateCalculatorOptionsProto"; + +message StableDiffusionIterateCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional StableDiffusionIterateCalculatorOptions ext = 510855836; + } + + // The random seed that is fed into the calculator to control the randomness + // of the generated image. + optional uint32 base_seed = 1 [default = 0]; + + // The target output image size. Must be a multiple of 8 and larger than 384. + optional int32 output_image_width = 2 [default = 512]; + optional int32 output_image_height = 3 [default = 512]; + + // The folder name must end with '/'. + optional string file_folder = 4 [default = "bins/"]; + + // Note: only one of lora_file_folder and lora_weights_layer_mapping should be + // set. + // The LoRA file folder. The folder name must end with '/'. + optional string lora_file_folder = 9 [default = ""]; + + // The LoRA layer name mapping to the weight buffer position in the file. + map<string, uint64> lora_weights_layer_mapping = 10; + + // The LoRA rank. + optional int32 lora_rank = 12 [default = 4]; + + // Determines how often image decoding runs, i.e. once every n iterations. + // Setting this to 1 runs image decoding for every iteration to display the + // intermediate result, but it will also introduce much higher + // overall latency.
+  // Setting this to be the targeted number of iterations will only run the
+  // image decoding at the end, giving the best overall latency.
+  optional int32 show_every_n_iteration = 5 [default = 1];
+
+  // If set to true, the calculator will perform a GPU-CPU sync and emit an
+  // empty packet. This provides a signal of which iteration the calculator is
+  // currently at, typically used to create a progress bar. Note that this
+  // also introduces overhead, but not significantly based on our experiments
+  // (~1ms).
+  optional bool emit_empty_packet = 6 [default = false];
+
+  enum ClPriorityHint {
+    PRIORITY_HINT_NORMAL = 0;  // Default, must be first.
+    PRIORITY_HINT_LOW = 1;
+    PRIORITY_HINT_HIGH = 2;
+  }
+
+  // OpenCL priority hint. Set this to LOW to yield to other GPU contexts.
+  // This lowers inference speed, but helps keep the UI responsive.
+  optional ClPriorityHint cl_priority_hint = 7;
+
+  enum ModelType {
+    DEFAULT = 0;
+    SD_1 = 1;  // Stable Diffusion v1 models, including SD 1.4 and 1.5.
+  }
+  // Stable Diffusion model type. Defaults to Stable Diffusion v1.
+  optional ModelType model_type = 8 [default = SD_1];
+  // The strength of the diffusion plugin inputs.
+  optional float plugins_strength = 11 [default = 1.0];
+}
diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator.cc b/mediapipe/tasks/cc/vision/image_generator/image_generator.cc
new file mode 100644
index 000000000..e4464d84d
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/image_generator.cc
@@ -0,0 +1,397 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/vision/image_generator/image_generator.h"
+
+#include <memory>
+#include <optional>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
+#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
+#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace image_generator {
+namespace {
+
+using ImageGeneratorGraphOptionsProto = ::mediapipe::tasks::vision::
+    image_generator::proto::ImageGeneratorGraphOptions;
+using ConditionedImageGraphOptionsProto = ::mediapipe::tasks::vision::
+    image_generator::proto::ConditionedImageGraphOptions;
+using ControlPluginGraphOptionsProto = ::mediapipe::tasks::vision::
+    image_generator::proto::ControlPluginGraphOptions;
+using FaceLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
+    face_landmarker::proto::FaceLandmarkerGraphOptions;
+
+constexpr absl::string_view kImageTag = "IMAGE";
+constexpr absl::string_view kImageOutName = "image_out";
+constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
+constexpr absl::string_view kConditionImageName = "condition_image";
+constexpr absl::string_view kSourceConditionImageName =
+    "source_condition_image";
+constexpr absl::string_view kStepsTag = "STEPS";
+constexpr absl::string_view kStepsName = "steps";
+constexpr absl::string_view kIterationTag = "ITERATION";
+constexpr absl::string_view kIterationName = "iteration";
+constexpr absl::string_view kPromptTag = "PROMPT";
+constexpr absl::string_view kPromptName = "prompt";
+constexpr absl::string_view kRandSeedTag = "RAND_SEED";
+constexpr absl::string_view kRandSeedName = "rand_seed";
+constexpr absl::string_view kSelectTag = "SELECT";
+constexpr absl::string_view kSelectName = "select";
+
+constexpr char kImageGeneratorGraphTypeName[] =
+    "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
+
+constexpr char kConditionedImageGraphContainerTypeName[] =
+    "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
+
+// Creates a MediaPipe graph config that contains a subgraph node of
+// "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph".
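+// Roughly, the returned config wraps the task subgraph like this (a sketch of
+// the generated textproto, with the condition-image streams omitted):
+//   input_stream: "STEPS:steps"
+//   input_stream: "ITERATION:iteration"
+//   input_stream: "PROMPT:prompt"
+//   input_stream: "RAND_SEED:rand_seed"
+//   output_stream: "IMAGE:image_out"
+//   node {
+//     calculator: "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph"
+//     ...
+//   }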
+CalculatorGraphConfig CreateImageGeneratorGraphConfig(
+    std::unique_ptr<ImageGeneratorGraphOptionsProto> options,
+    bool use_condition_image) {
+  api2::builder::Graph graph;
+  auto& subgraph = graph.AddNode(kImageGeneratorGraphTypeName);
+  subgraph.GetOptions<ImageGeneratorGraphOptionsProto>().CopyFrom(*options);
+  graph.In(kStepsTag).SetName(kStepsName) >> subgraph.In(kStepsTag);
+  graph.In(kIterationTag).SetName(kIterationName) >>
+      subgraph.In(kIterationTag);
+  graph.In(kPromptTag).SetName(kPromptName) >> subgraph.In(kPromptTag);
+  graph.In(kRandSeedTag).SetName(kRandSeedName) >> subgraph.In(kRandSeedTag);
+  if (use_condition_image) {
+    graph.In(kConditionImageTag).SetName(kConditionImageName) >>
+        subgraph.In(kConditionImageTag);
+    graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag);
+  }
+  subgraph.Out(kImageTag).SetName(kImageOutName) >>
+      graph[api2::Output<Image>::Optional(kImageTag)];
+  return graph.GetConfig();
+}
+
+// Creates a MediaPipe graph config that contains a subgraph node of
+// "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer".
+CalculatorGraphConfig CreateConditionedImageGraphContainerConfig(
+    std::unique_ptr<ImageGeneratorGraphOptionsProto> options) {
+  api2::builder::Graph graph;
+  auto& subgraph = graph.AddNode(kConditionedImageGraphContainerTypeName);
+  subgraph.GetOptions<ImageGeneratorGraphOptionsProto>().CopyFrom(*options);
+  graph.In(kImageTag).SetName(kSourceConditionImageName) >>
+      subgraph.In(kImageTag);
+  graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag);
+  subgraph.Out(kConditionImageTag).SetName(kConditionImageName) >>
+      graph.Out(kConditionImageTag).Cast<Image>();
+  return graph.GetConfig();
+}
+
+absl::Status SetFaceConditionOptionsToProto(
+    FaceConditionOptions& face_condition_options,
+    ControlPluginGraphOptionsProto& options_proto) {
+  // Configure face plugin model.
+  auto plugin_base_options_proto =
+      std::make_unique<tasks::core::proto::BaseOptions>(
+          tasks::core::ConvertBaseOptionsToProto(
+              &(face_condition_options.base_options)));
+  options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
+
+  // Configure face landmarker graph.
+  auto& face_landmarker_options =
+      face_condition_options.face_landmarker_options;
+  auto& face_landmarker_options_proto =
+      *options_proto.mutable_conditioned_image_graph_options()
+           ->mutable_face_condition_type_options()
+           ->mutable_face_landmarker_graph_options();
+
+  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
+      tasks::core::ConvertBaseOptionsToProto(
+          &(face_landmarker_options.base_options)));
+  face_landmarker_options_proto.mutable_base_options()->Swap(
+      base_options_proto.get());
+  face_landmarker_options_proto.mutable_base_options()->set_use_stream_mode(
+      false);
+
+  // Configure face detector options.
+  auto* face_detector_graph_options =
+      face_landmarker_options_proto.mutable_face_detector_graph_options();
+  face_detector_graph_options->set_num_faces(
+      face_landmarker_options.num_faces);
+  face_detector_graph_options->set_min_detection_confidence(
+      face_landmarker_options.min_face_detection_confidence);
+
+  // Configure face landmark detector options.
+  face_landmarker_options_proto.set_min_tracking_confidence(
+      face_landmarker_options.min_tracking_confidence);
+  auto* face_landmarks_detector_graph_options =
+      face_landmarker_options_proto
+          .mutable_face_landmarks_detector_graph_options();
+  face_landmarks_detector_graph_options->set_min_detection_confidence(
+      face_landmarker_options.min_face_presence_confidence);
+  return absl::OkStatus();
+}
+
+absl::Status SetDepthConditionOptionsToProto(
+    DepthConditionOptions& depth_condition_options,
+    ControlPluginGraphOptionsProto& options_proto) {
+  // Configure the depth plugin model.
+  auto plugin_base_options_proto =
+      std::make_unique<tasks::core::proto::BaseOptions>(
+          tasks::core::ConvertBaseOptionsToProto(
+              &(depth_condition_options.base_options)));
+  options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
+
+  auto& image_segmenter_graph_options =
+      *options_proto.mutable_conditioned_image_graph_options()
+           ->mutable_depth_condition_type_options()
+           ->mutable_image_segmenter_graph_options();
+
+  auto depth_base_options_proto =
+      std::make_unique<tasks::core::proto::BaseOptions>(
+          tasks::core::ConvertBaseOptionsToProto(
+              &(depth_condition_options.image_segmenter_options.base_options)));
+  image_segmenter_graph_options.mutable_base_options()->Swap(
+      depth_base_options_proto.get());
+  image_segmenter_graph_options.mutable_base_options()->set_use_stream_mode(
+      false);
+  image_segmenter_graph_options.set_display_names_locale(
+      depth_condition_options.image_segmenter_options.display_names_locale);
+  return absl::OkStatus();
+}
+
+absl::Status SetEdgeConditionOptionsToProto(
+    EdgeConditionOptions& edge_condition_options,
+    ControlPluginGraphOptionsProto& options_proto) {
+  auto plugin_base_options_proto =
+      std::make_unique<tasks::core::proto::BaseOptions>(
+          tasks::core::ConvertBaseOptionsToProto(
+              &(edge_condition_options.base_options)));
+  options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
+
+  auto& edge_options_proto =
+      *options_proto.mutable_conditioned_image_graph_options()
+           ->mutable_edge_condition_type_options();
+  edge_options_proto.set_threshold_1(edge_condition_options.threshold_1);
+  edge_options_proto.set_threshold_2(edge_condition_options.threshold_2);
+  edge_options_proto.set_aperture_size(edge_condition_options.aperture_size);
+  edge_options_proto.set_l2_gradient(edge_condition_options.l2_gradient);
+  return absl::OkStatus();
+}
+
+// Helper holder struct of image generator graph options and condition type
+// index mapping.
+struct ImageGeneratorOptionsProtoAndConditionTypeIndex {
+  std::unique_ptr<ImageGeneratorGraphOptionsProto> options_proto;
+  std::unique_ptr<std::map<ConditionOptions::ConditionType, int>>
+      condition_type_index;
+};
+
+// Converts the user-facing ImageGeneratorOptions struct to the internal
+// ImageGeneratorOptions proto.
+absl::StatusOr<ImageGeneratorOptionsProtoAndConditionTypeIndex>
+ConvertImageGeneratorGraphOptionsProto(
+    ImageGeneratorOptions* image_generator_options,
+    ConditionOptions* condition_options) {
+  ImageGeneratorOptionsProtoAndConditionTypeIndex
+      options_proto_and_condition_index;
+
+  // Configure base image generator options.
+  options_proto_and_condition_index.options_proto =
+      std::make_unique<ImageGeneratorGraphOptionsProto>();
+  auto& options_proto = *options_proto_and_condition_index.options_proto;
+  options_proto.set_text2image_model_directory(
+      image_generator_options->text2image_model_directory);
+  if (image_generator_options->lora_weights_file_path.has_value()) {
+    options_proto.mutable_lora_weights_file()->set_file_name(
+        *image_generator_options->lora_weights_file_path);
+  }
+
+  // Configure optional condition type options.
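+  // Each configured condition type appends one entry to
+  // control_plugin_graphs_options and records its position in
+  // condition_type_index. For example, if FACE and EDGE are both set, the
+  // mapping becomes {FACE: 0, EDGE: 1}, and the SELECT input later picks the
+  // matching plugin graph by that index.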
+  if (condition_options != nullptr) {
+    options_proto_and_condition_index.condition_type_index =
+        std::make_unique<std::map<ConditionOptions::ConditionType, int>>();
+    auto& condition_type_index =
+        *options_proto_and_condition_index.condition_type_index;
+    if (condition_options->face_condition_options.has_value()) {
+      condition_type_index[ConditionOptions::FACE] =
+          condition_type_index.size();
+      auto& face_plugin_graph_options =
+          *options_proto.add_control_plugin_graphs_options();
+      RET_CHECK_OK(SetFaceConditionOptionsToProto(
+          *condition_options->face_condition_options,
+          face_plugin_graph_options));
+    }
+    if (condition_options->depth_condition_options.has_value()) {
+      condition_type_index[ConditionOptions::DEPTH] =
+          condition_type_index.size();
+      auto& depth_plugin_graph_options =
+          *options_proto.add_control_plugin_graphs_options();
+      RET_CHECK_OK(SetDepthConditionOptionsToProto(
+          *condition_options->depth_condition_options,
+          depth_plugin_graph_options));
+    }
+    if (condition_options->edge_condition_options.has_value()) {
+      condition_type_index[ConditionOptions::EDGE] =
+          condition_type_index.size();
+      auto& edge_plugin_graph_options =
+          *options_proto.add_control_plugin_graphs_options();
+      RET_CHECK_OK(SetEdgeConditionOptionsToProto(
+          *condition_options->edge_condition_options,
+          edge_plugin_graph_options));
+    }
+    if (condition_type_index.empty()) {
+      return absl::InvalidArgumentError(
+          "At least one condition type must be set.");
+    }
+  }
+  return options_proto_and_condition_index;
+}
+
+}  // namespace
+
+absl::StatusOr<std::unique_ptr<ImageGenerator>> ImageGenerator::Create(
+    std::unique_ptr<ImageGeneratorOptions> image_generator_options,
+    std::unique_ptr<ConditionOptions> condition_options) {
+  bool use_condition_image = condition_options != nullptr;
+  ASSIGN_OR_RETURN(auto options_proto_and_condition_index,
+                   ConvertImageGeneratorGraphOptionsProto(
+                       image_generator_options.get(), condition_options.get()));
+  std::unique_ptr<ImageGeneratorGraphOptionsProto>
+      options_proto_for_condition_image_graphs_container;
+  if (use_condition_image) {
+    options_proto_for_condition_image_graphs_container =
+        std::make_unique<ImageGeneratorGraphOptionsProto>();
+    options_proto_for_condition_image_graphs_container->CopyFrom(
+        *options_proto_and_condition_index.options_proto);
+  }
+  ASSIGN_OR_RETURN(
+      auto image_generator,
+      (core::VisionTaskApiFactory::Create<ImageGenerator,
+                                          ImageGeneratorGraphOptionsProto>(
+          CreateImageGeneratorGraphConfig(
+              std::move(options_proto_and_condition_index.options_proto),
+              use_condition_image),
+          std::make_unique<core::MediaPipeBuiltinOpResolver>(),
+          core::RunningMode::IMAGE,
+          /*result_callback=*/nullptr)));
+  image_generator->use_condition_image_ = use_condition_image;
+  if (use_condition_image) {
+    image_generator->condition_type_index_ =
+        std::move(options_proto_and_condition_index.condition_type_index);
+    ASSIGN_OR_RETURN(
+        image_generator->condition_image_graphs_container_task_runner_,
+        tasks::core::TaskRunner::Create(
+            CreateConditionedImageGraphContainerConfig(
+                std::move(options_proto_for_condition_image_graphs_container)),
+            absl::make_unique<core::MediaPipeBuiltinOpResolver>()));
+  }
+  image_generator->init_timestamp_ = absl::Now();
+  return image_generator;
+}
+
+absl::StatusOr<Image> ImageGenerator::CreateConditionImage(
+    Image source_condition_image,
+    ConditionOptions::ConditionType condition_type) {
+  if (condition_type_index_->find(condition_type) ==
+      condition_type_index_->end()) {
+    return absl::InvalidArgumentError(
+        "The condition type was not configured during initialization.");
+  }
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      condition_image_graphs_container_task_runner_->Process({
+          {std::string(kSourceConditionImageName),
+           MakePacket<Image>(std::move(source_condition_image))},
+          {std::string(kSelectName),
+           MakePacket<int>(condition_type_index_->at(condition_type))},
+      }));
+  return output_packets.at(std::string(kConditionImageName)).Get<Image>();
+}
+
+absl::StatusOr<ImageGeneratorResult> ImageGenerator::Generate(
+    const std::string& prompt, int iterations, int seed) {
+  if (use_condition_image_) {
+    return absl::InvalidArgumentError(
+        "This ImageGenerator was created with condition options; use the "
+        "Generate() overload that takes a condition image.");
+  }
+  return RunIterations(prompt, iterations, seed, std::nullopt);
+}
+
+absl::StatusOr<ImageGeneratorResult> ImageGenerator::Generate(
+    const std::string& prompt, Image condition_image,
+    ConditionOptions::ConditionType condition_type, int iterations, int seed) {
+  if (!use_condition_image_) {
+    return absl::InvalidArgumentError(
+        "This ImageGenerator was created without condition options; use the "
+        "Generate() overload without a condition image.");
+  }
+  ASSIGN_OR_RETURN(auto plugin_model_image,
+                   CreateConditionImage(condition_image, condition_type));
+  return RunIterations(
+      prompt, iterations, seed,
+      ConditionInputs{plugin_model_image,
+                      condition_type_index_->at(condition_type)});
+}
+
+absl::StatusOr<ImageGeneratorResult> ImageGenerator::RunIterations(
+    const std::string& prompt, int steps, int rand_seed,
+    std::optional<ConditionInputs> condition_inputs) {
+  tasks::core::PacketMap output_packets;
+  ImageGeneratorResult result;
+  auto timestamp = (absl::Now() - init_timestamp_) / absl::Milliseconds(1);
+  for (int i = 0; i < steps; ++i) {
+    tasks::core::PacketMap input_packets;
+    if (i == 0 && condition_inputs.has_value()) {
+      input_packets[std::string(kConditionImageName)] =
+          MakePacket<Image>(condition_inputs->condition_image)
+              .At(Timestamp(timestamp));
+      input_packets[std::string(kSelectName)] =
+          MakePacket<int>(condition_inputs->select).At(Timestamp(timestamp));
+    }
+    input_packets[std::string(kStepsName)] =
+        MakePacket<int>(steps).At(Timestamp(timestamp));
+    input_packets[std::string(kIterationName)] =
+        MakePacket<int>(i).At(Timestamp(timestamp));
+    input_packets[std::string(kPromptName)] =
+        MakePacket<std::string>(prompt).At(Timestamp(timestamp));
+    input_packets[std::string(kRandSeedName)] =
+        MakePacket<int>(rand_seed).At(Timestamp(timestamp));
+    ASSIGN_OR_RETURN(output_packets, ProcessImageData(input_packets));
+    timestamp += 1;
+  }
+  result.generated_image =
+      output_packets.at(std::string(kImageOutName)).Get<Image>();
+  if (condition_inputs.has_value()) {
+    result.condition_image = condition_inputs->condition_image;
+  }
+  return result;
+}
+
+}  // namespace image_generator
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator.h b/mediapipe/tasks/cc/vision/image_generator/image_generator.h
new file mode 100644
index 000000000..52599c02f
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/image_generator.h
@@ -0,0 +1,157 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
+#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
+
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/tasks/cc/core/base_options.h"
+#include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h"
+#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace image_generator {
+
+// Options for drawing the face landmarks image.
+struct FaceConditionOptions {
+  // The base options for the plugin model.
+  tasks::core::BaseOptions base_options;
+
+  // Face landmarker options used to detect face landmarks in the condition
+  // image.
+  face_landmarker::FaceLandmarkerOptions face_landmarker_options;
+};
+
+// Options for detecting the edge image.
+struct EdgeConditionOptions {
+  // The base options for the plugin model.
+  tasks::core::BaseOptions base_options;
+
+  // These parameters are used to configure the Canny edge algorithm of
+  // OpenCV.
+  // See more details:
+  // https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
+
+  // First threshold for the hysteresis procedure.
+  float threshold_1 = 100;
+
+  // Second threshold for the hysteresis procedure.
+  float threshold_2 = 200;
+
+  // Aperture size for the Sobel operator. Typical range is 3~7.
+  int aperture_size = 3;
+
+  // A flag, indicating whether a more accurate L2 norm should be used to
+  // calculate the image gradient magnitude ( L2gradient=true ), or whether
+  // the default L1 norm is enough ( L2gradient=false ).
+  bool l2_gradient = false;
+};
+
+// Options for detecting the depth image.
+struct DepthConditionOptions {
+  // The base options for the plugin model.
+  tasks::core::BaseOptions base_options;
+
+  // Image segmenter options used to detect depth in the condition image.
+  image_segmenter::ImageSegmenterOptions image_segmenter_options;
+};
+
+struct ConditionOptions {
+  enum ConditionType { FACE, EDGE, DEPTH };
+  std::optional<FaceConditionOptions> face_condition_options;
+  std::optional<EdgeConditionOptions> edge_condition_options;
+  std::optional<DepthConditionOptions> depth_condition_options;
+};
+
+// Note: The API is experimental and subject to change.
+// The options for configuring a mediapipe image generator task.
+struct ImageGeneratorOptions {
+  // The text to image model directory storing the model weights.
+  std::string text2image_model_directory;
+
+  // The path to the LoRA weights file.
+  std::optional<std::string> lora_weights_file_path;
+};
+
+class ImageGenerator : tasks::vision::core::BaseVisionTaskApi {
+ public:
+  using BaseVisionTaskApi::BaseVisionTaskApi;
+
+  // Creates an ImageGenerator from the provided options.
+  // image_generator_options: options to create the image generator.
+  // condition_options: optional options if plugin models are used to generate
+  // an image based on the condition image.
+  static absl::StatusOr<std::unique_ptr<ImageGenerator>> Create(
+      std::unique_ptr<ImageGeneratorOptions> image_generator_options,
+      std::unique_ptr<ConditionOptions> condition_options = nullptr);
+
+  // Creates the condition image of the specified condition type from the
+  // source condition image.
+  // Currently supports face landmarks, depth image, and edge image as the
+  // condition image.
+  absl::StatusOr<Image> CreateConditionImage(
+      Image source_condition_image,
+      ConditionOptions::ConditionType condition_type);
+
+  // Generates an image with the given number of iterations and random seed.
+  // Only valid when the ImageGenerator is created without condition options.
+  absl::StatusOr<ImageGeneratorResult> Generate(const std::string& prompt,
+                                                int iterations, int seed = 0);
+
+  // Generates an image based on the condition image with the given number of
+  // iterations and random seed.
+  // A detailed introduction to the condition image:
+  // https://ai.googleblog.com/2023/06/on-device-diffusion-plugins-for.html
+  absl::StatusOr<ImageGeneratorResult> Generate(
+      const std::string& prompt, Image condition_image,
+      ConditionOptions::ConditionType condition_type, int iterations,
+      int seed = 0);
+
+ private:
+  struct ConditionInputs {
+    Image condition_image;
+    int select;
+  };
+
+  bool use_condition_image_ = false;
+
+  absl::Time init_timestamp_;
+
+  std::unique_ptr<tasks::core::TaskRunner>
+      condition_image_graphs_container_task_runner_;
+
+  std::unique_ptr<std::map<ConditionOptions::ConditionType, int>>
+      condition_type_index_;
+
+  absl::StatusOr<ImageGeneratorResult> RunIterations(
+      const std::string& prompt, int steps, int rand_seed,
+      std::optional<ConditionInputs> condition_inputs);
+};
+
+}  // namespace image_generator
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc b/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc
new file mode 100644
index 000000000..efbfd86e9
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc
@@ -0,0 +1,385 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/status_macros.h"
+#include "mediapipe/framework/tool/switch_container.pb.h"
+#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/model_task_graph.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h"
+#include "mediapipe/util/graph_builder_utils.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace image_generator {
+
+namespace {
+
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+
+constexpr int kPluginsOutputSize = 512;
+constexpr absl::string_view kTensorsTag = "TENSORS";
+constexpr absl::string_view kImageTag = "IMAGE";
+constexpr absl::string_view kImageCpuTag = "IMAGE_CPU";
+constexpr absl::string_view kStepsTag = "STEPS";
+constexpr absl::string_view kIterationTag = "ITERATION";
+constexpr absl::string_view kPromptTag = "PROMPT";
+constexpr absl::string_view kRandSeedTag = "RAND_SEED";
+constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
+constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
+constexpr absl::string_view kSelectTag = "SELECT";
+constexpr absl::string_view kShowResultTag = "SHOW_RESULT";
+constexpr absl::string_view kMetadataFilename = "metadata";
+constexpr absl::string_view kLoraRankStr = "lora_rank";
+
+struct ImageGeneratorInputs {
+  Source<std::string> prompt;
+  Source<int> steps;
+  Source<int> iteration;
+  Source<int> rand_seed;
+  std::optional<Source<Image>> condition_image;
+  std::optional<Source<int>> select_condition_type;
+  std::optional<Source<bool>> show_result;
+};
+
+struct ImageGeneratorOutputs {
+  Source<Image> generated_image;
+};
+
+}  // namespace
+
+// A container graph containing several ConditionedImageGraph nodes, from
+// which the specified condition type is chosen.
+// Inputs:
+//   IMAGE - Image
+//     The source condition image, used to generate the condition image.
+//   SELECT - int
+//     The index of the selected conditioned image graph.
+// Outputs:
+//   CONDITION_IMAGE - Image
+//     The condition image created from the specified condition type.
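+// For example, with two contained graphs (face at index 0, edge at index 1),
+// sending SELECT=1 routes IMAGE through the edge ConditionedImageGraph only;
+// ImageGenerator::CreateConditionImage() in image_generator.cc is the caller
+// that supplies this index.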
+class ConditionedImageGraphContainer : public core::ModelTaskGraph {
+ public:
+  absl::StatusOr<CalculatorGraphConfig> GetConfig(
+      SubgraphContext* sc) override {
+    Graph graph;
+    auto& graph_options =
+        *sc->MutableOptions<proto::ImageGeneratorGraphOptions>();
+    auto source_condition_image = graph.In(kImageTag).Cast<Image>();
+    auto select_condition_type = graph.In(kSelectTag).Cast<int>();
+    auto& switch_container = graph.AddNode("SwitchContainer");
+    auto& switch_options =
+        switch_container.GetOptions<mediapipe::SwitchContainerOptions>();
+    for (auto& control_plugin_graph_options :
+         *graph_options.mutable_control_plugin_graphs_options()) {
+      auto& node = *switch_options.add_contained_node();
+      node.set_calculator(
+          "mediapipe.tasks.vision.image_generator.ConditionedImageGraph");
+      node.mutable_node_options()->Add()->PackFrom(
+          control_plugin_graph_options.conditioned_image_graph_options());
+    }
+    source_condition_image >> switch_container.In(kImageTag);
+    select_condition_type >> switch_container.In(kSelectTag);
+    auto condition_image = switch_container.Out(kImageTag).Cast<Image>();
+    condition_image >> graph.Out(kConditionImageTag);
+    return graph.GetConfig();
+  }
+};
+
+// clang-format off
+REGISTER_MEDIAPIPE_GRAPH(
+  ::mediapipe::tasks::vision::image_generator::ConditionedImageGraphContainer); // NOLINT
+// clang-format on
+
+// A helper graph to convert the condition image to a Tensor using the control
+// plugin model.
+// Inputs:
+//   CONDITION_IMAGE - Image
+//     The condition image input to the control plugin model.
+// Outputs:
+//   PLUGIN_TENSORS - std::vector<Tensor>
+//     The output tensors from the control plugin model. The tensors are used
+//     as inputs to the image generation model.
+class ControlPluginGraph : public core::ModelTaskGraph {
+ public:
+  absl::StatusOr<CalculatorGraphConfig> GetConfig(
+      SubgraphContext* sc) override {
+    Graph graph;
+    auto& graph_options =
+        *sc->MutableOptions<proto::ControlPluginGraphOptions>();
+
+    auto condition_image = graph.In(kConditionImageTag).Cast<Image>();
+
+    // Convert Image to ImageFrame.
+    auto& from_image = graph.AddNode("FromImageCalculator");
+    condition_image >> from_image.In(kImageTag);
+    auto image_frame = from_image.Out(kImageCpuTag);
+
+    // Convert ImageFrame to Tensor.
+    auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
+    auto& image_to_tensor_options =
+        image_to_tensor.GetOptions<mediapipe::ImageToTensorCalculatorOptions>();
+    image_to_tensor_options.set_output_tensor_width(kPluginsOutputSize);
+    image_to_tensor_options.set_output_tensor_height(kPluginsOutputSize);
+    image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1);
+    image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1);
+    image_to_tensor_options.set_keep_aspect_ratio(true);
+    image_frame >> image_to_tensor.In(kImageTag);
+
+    // Create the plugin model resource.
+    ASSIGN_OR_RETURN(
+        const core::ModelResources* plugin_model_resources,
+        CreateModelResources(
+            sc,
+            std::make_unique<core::proto::ExternalFile>(
+                *graph_options.mutable_base_options()->mutable_model_asset())));
+
+    // Add control plugin model inference.
+    auto& plugins_inference =
+        AddInference(*plugin_model_resources,
+                     graph_options.base_options().acceleration(), graph);
+    image_to_tensor.Out(kTensorsTag) >> plugins_inference.In(kTensorsTag);
+    // The plugins model is not runnable on OpenGL. Error message:
+    //   TfLiteGpuDelegate Prepare: Batch size mismatch, expected 1 but got 64
+    //   Node number 67 (TfLiteGpuDelegate) failed to prepare.
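+    // To avoid the GPU prepare failure quoted above, inference for the plugin
+    // model is therefore pinned to the CPU XNNPACK delegate below.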
+    plugins_inference.GetOptions<mediapipe::InferenceCalculatorOptions>()
+        .mutable_delegate()
+        ->mutable_xnnpack();
+    plugins_inference.Out(kTensorsTag).Cast<std::vector<Tensor>>() >>
+        graph.Out(kPluginTensorsTag);
+    return graph.GetConfig();
+  }
+};
+
+REGISTER_MEDIAPIPE_GRAPH(
+    ::mediapipe::tasks::vision::image_generator::ControlPluginGraph);
+
+// A "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph" performs
+// image generation from a text prompt and an optional condition image.
+//
+// Inputs:
+//   PROMPT - std::string
+//     The prompt describing the image to be generated.
+//   STEPS - int
+//     The total steps to generate the image.
+//   ITERATION - int
+//     The current iteration in the generating steps. Must be less than STEPS.
+//   RAND_SEED - int
+//     The random seed input to the image generation model.
+//   CONDITION_IMAGE - Image
+//     The condition image used as a guidance for the image generation. Only
+//     valid if control plugin graph options are set in the graph options.
+//   SELECT - int
+//     The index of the selected control plugin graph.
+//   SHOW_RESULT - bool @Optional
+//     Whether to show the diffusion result at the current step. If this
+//     stream is non-empty, it takes effect regardless of
+//     show_every_n_iteration in the options.
+//
+// Outputs:
+//   IMAGE - Image
+//     The generated image.
+//   STEPS - int @Optional
+//     The total steps to generate the image. The same as STEPS input.
+//   ITERATION - int @Optional
+//     The current iteration in the generating steps. The same as ITERATION
+//     input.
+//   SHOW_RESULT - bool @Optional
+//     Whether to show the diffusion result at the current step. The same as
+//     input SHOW_RESULT.
+class ImageGeneratorGraph : public core::ModelTaskGraph {
+ public:
+  absl::StatusOr<CalculatorGraphConfig> GetConfig(
+      SubgraphContext* sc) override {
+    Graph graph;
+    auto* subgraph_options =
+        sc->MutableOptions<proto::ImageGeneratorGraphOptions>();
+    std::optional<const core::ModelAssetBundleResources*> lora_resources;
+    // Create LoRA weights asset bundle resources.
+    if (subgraph_options->has_lora_weights_file()) {
+      auto external_file = std::make_unique<core::proto::ExternalFile>();
+      external_file->Swap(subgraph_options->mutable_lora_weights_file());
+      ASSIGN_OR_RETURN(lora_resources, CreateModelAssetBundleResources(
+                                           sc, std::move(external_file)));
+    }
+    std::optional<Source<Image>> condition_image;
+    std::optional<Source<int>> select_condition_type;
+    if (!subgraph_options->control_plugin_graphs_options().empty()) {
+      condition_image = graph.In(kConditionImageTag).Cast<Image>();
+      select_condition_type = graph.In(kSelectTag).Cast<int>();
+    }
+    std::optional<Source<bool>> show_result;
+    if (HasInput(sc->OriginalNode(), kShowResultTag)) {
+      show_result = graph.In(kShowResultTag).Cast<bool>();
+    }
+    ASSIGN_OR_RETURN(
+        auto outputs,
+        BuildImageGeneratorGraph(
+            *sc->MutableOptions<proto::ImageGeneratorGraphOptions>(),
+            lora_resources,
+            ImageGeneratorInputs{
+                /*prompt=*/graph.In(kPromptTag).Cast<std::string>(),
+                /*steps=*/graph.In(kStepsTag).Cast<int>(),
+                /*iteration=*/graph.In(kIterationTag).Cast<int>(),
+                /*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
+                /*condition_image=*/condition_image,
+                /*select_condition_type=*/select_condition_type,
+                /*show_result=*/show_result,
+            },
+            graph));
+    outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
+
+    // Optional outputs to provide the current iteration.
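+    // The iteration state is simply echoed through a PassThroughCalculator so
+    // a client can derive progress without extra bookkeeping, e.g.
+    // (a sketch): progress = (iteration + 1) / static_cast<float>(steps);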
+    auto& pass_through = graph.AddNode("PassThroughCalculator");
+    graph.In(kIterationTag) >> pass_through.In(0);
+    graph.In(kStepsTag) >> pass_through.In(1);
+    pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
+    pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
+    if (HasOutput(sc->OriginalNode(), kShowResultTag)) {
+      graph.In(kShowResultTag) >> pass_through.In(2);
+      pass_through.Out(2) >> graph[Output<bool>::Optional(kShowResultTag)];
+    }
+    return graph.GetConfig();
+  }
+
+  absl::StatusOr<ImageGeneratorOutputs> BuildImageGeneratorGraph(
+      proto::ImageGeneratorGraphOptions& subgraph_options,
+      std::optional<const core::ModelAssetBundleResources*> lora_resources,
+      ImageGeneratorInputs inputs, Graph& graph) {
+    auto& stable_diff = graph.AddNode("StableDiffusionIterateCalculator");
+    if (inputs.condition_image.has_value()) {
+      // Add switch container for multiple control plugin graphs.
+      auto& switch_container = graph.AddNode("SwitchContainer");
+      auto& switch_options =
+          switch_container.GetOptions<mediapipe::SwitchContainerOptions>();
+      for (auto& control_plugin_graph_options :
+           *subgraph_options.mutable_control_plugin_graphs_options()) {
+        auto& node = *switch_options.add_contained_node();
+        node.set_calculator(
+            "mediapipe.tasks.vision.image_generator.ControlPluginGraph");
+        node.mutable_node_options()->Add()->PackFrom(
+            control_plugin_graph_options);
+      }
+      *inputs.condition_image >> switch_container.In(kConditionImageTag);
+      *inputs.select_condition_type >> switch_container.In(kSelectTag);
+      auto plugin_tensors = switch_container.Out(kPluginTensorsTag);
+
+      // Additional diffusion plugins calculator to pass tensors to the
+      // diffusion iterator.
+      auto& plugins_output = graph.AddNode("DiffusionPluginsOutputCalculator");
+      plugin_tensors >> plugins_output.In(kTensorsTag);
+      inputs.steps >> plugins_output.In(kStepsTag);
+      inputs.iteration >> plugins_output.In(kIterationTag);
+      plugins_output.Out(kTensorsTag) >> stable_diff.In(kPluginTensorsTag);
+    }
+
+    inputs.prompt >> stable_diff.In(kPromptTag);
+    inputs.steps >> stable_diff.In(kStepsTag);
+    inputs.iteration >> stable_diff.In(kIterationTag);
+    inputs.rand_seed >> stable_diff.In(kRandSeedTag);
+    if (inputs.show_result.has_value()) {
+      *inputs.show_result >> stable_diff.In(kShowResultTag);
+    }
+    mediapipe::StableDiffusionIterateCalculatorOptions& options =
+        stable_diff
+            .GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
+    if (subgraph_options.has_stable_diffusion_iterate_options()) {
+      options = subgraph_options.stable_diffusion_iterate_options();
+    } else {
+      options.set_base_seed(0);
+      options.set_output_image_height(kPluginsOutputSize);
+      options.set_output_image_width(kPluginsOutputSize);
+      options.set_file_folder(subgraph_options.text2image_model_directory());
+      options.set_show_every_n_iteration(100);
+      options.set_emit_empty_packet(true);
+    }
+    if (lora_resources.has_value()) {
+      auto& lora_layer_weights_mapping =
+          *options.mutable_lora_weights_layer_mapping();
+      for (const auto& file_path : (*lora_resources)->ListFiles()) {
+        auto basename = file::Basename(file_path);
+        ASSIGN_OR_RETURN(auto file_content,
+                         (*lora_resources)->GetFile(std::string(file_path)));
+        if (file_path == kMetadataFilename) {
+          MP_RETURN_IF_ERROR(
+              ParseLoraMetadataAndConfigOptions(file_content, options));
+        } else {
+          lora_layer_weights_mapping[basename] =
+              reinterpret_cast<uint64_t>(file_content.data());
+        }
+      }
+    }
+
+    auto& to_image = graph.AddNode("ToImageCalculator");
+    stable_diff.Out(kImageTag) >> to_image.In(kImageCpuTag);
+
+    return {{to_image.Out(kImageTag).Cast<Image>()}};
+  }
+
+ private:
+  absl::Status ParseLoraMetadataAndConfigOptions(
+      absl::string_view contents,
+      mediapipe::StableDiffusionIterateCalculatorOptions& options) {
+    std::vector<absl::string_view> lines =
+        absl::StrSplit(contents, '\n', absl::SkipEmpty());
+    for (const auto& line : lines) {
+      std::vector<absl::string_view> values = absl::StrSplit(line, ',');
+      if (values[0] == kLoraRankStr) {
+        int lora_rank;
+        if (values.size() != 2 || !absl::SimpleAtoi(values[1], &lora_rank)) {
+          return absl::InvalidArgumentError(
+              absl::StrCat("Error parsing LoRA weights metadata. ", line));
+        }
+        options.set_lora_rank(lora_rank);
+      }
+    }
+    return absl::OkStatus();
+  }
+};
+
+REGISTER_MEDIAPIPE_GRAPH(
+    ::mediapipe::tasks::vision::image_generator::ImageGeneratorGraph);
+
+}  // namespace image_generator
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator_result.h b/mediapipe/tasks/cc/vision/image_generator/image_generator_result.h
new file mode 100644
index 000000000..7b7054d74
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/image_generator_result.h
@@ -0,0 +1,41 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
+
+#include <optional>
+
+#include "mediapipe/framework/formats/image.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace image_generator {
+
+// The result of an ImageGenerator task.
+struct ImageGeneratorResult {
+  // The generated image.
+  Image generated_image;
+
+  // The condition_image used in the plugin model, only available if the
+  // condition type is set in ImageGeneratorOptions.
+  std::optional<Image> condition_image = std::nullopt;
+};
+
+}  // namespace image_generator
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/BUILD b/mediapipe/tasks/cc/vision/image_generator/proto/BUILD
new file mode 100644
index 000000000..971bb7f07
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/proto/BUILD
@@ -0,0 +1,53 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
+
+package(default_visibility = [
+    "//mediapipe/tasks:internal",
+])
+
+licenses(["notice"])
+
+mediapipe_proto_library(
+    name = "conditioned_image_graph_options_proto",
+    srcs = ["conditioned_image_graph_options.proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_proto",
+        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_proto",
+    ],
+)
+
+mediapipe_proto_library(
+    name = "control_plugin_graph_options_proto",
+    srcs = ["control_plugin_graph_options.proto"],
+    deps = [
+        ":conditioned_image_graph_options_proto",
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/tasks/cc/core/proto:base_options_proto",
+    ],
+)
+
+mediapipe_proto_library(
+    name = "image_generator_graph_options_proto",
+    srcs = ["image_generator_graph_options.proto"],
+    deps = [
+        ":control_plugin_graph_options_proto",
+        "//mediapipe/tasks/cc/core/proto:external_file_proto",
+        "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_proto",
+    ],
+)
diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto
new file mode 100644
index 000000000..8d0798d76
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto
@@ -0,0 +1,66 @@
+
+/* Copyright 2023 The MediaPipe Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+syntax = "proto3";
+
+package mediapipe.tasks.vision.image_generator.proto;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto";
+import "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
+option java_outer_classname = "ConditionedImageGraphOptionsProto";
+
+message ConditionedImageGraphOptions {
+  // For the conditioned image graph based on face landmarks.
+  message FaceConditionTypeOptions {
+    // Options for the face landmarker used in the face landmarks type graph.
+    face_landmarker.proto.FaceLandmarkerGraphOptions
+        face_landmarker_graph_options = 1;
+  }
+
+  // For the conditioned image graph based on edge detection.
+  message EdgeConditionTypeOptions {
+    // These parameters are used to configure the Canny edge algorithm of
+    // OpenCV.
+    // See more details:
+    // https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
+
+    // First threshold for the hysteresis procedure.
+    float threshold_1 = 1;
+
+    // Second threshold for the hysteresis procedure.
+    float threshold_2 = 2;
+
+    // Aperture size for the Sobel operator. Typical range is 3~7.
+    int32 aperture_size = 3;
+
+    // A flag, indicating whether a more accurate L2 norm should be used to
+    // calculate the image gradient magnitude ( L2gradient=true ), or whether
+    // the default L1 norm is enough ( L2gradient=false ).
+    bool l2_gradient = 4;
+  }
+
+  // For the conditioned image graph based on a depth map.
+  message DepthConditionTypeOptions {
+    // Options for the image segmenter used in the depth condition type graph.
+    image_segmenter.proto.ImageSegmenterGraphOptions
+        image_segmenter_graph_options = 1;
+  }
+
+  // The options for configuring the conditioned image graph.
+  oneof condition_type_options {
+    FaceConditionTypeOptions face_condition_type_options = 2;
+    EdgeConditionTypeOptions edge_condition_type_options = 3;
+    DepthConditionTypeOptions depth_condition_type_options = 4;
+  }
+}
diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto
new file mode 100644
index 000000000..52d94efb3
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto
@@ -0,0 +1,34 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package mediapipe.tasks.vision.image_generator.proto;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/tasks/cc/core/proto/base_options.proto";
+import "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
+option java_outer_classname = "ControlPluginGraphOptionsProto";
+
+message ControlPluginGraphOptions {
+  // The base options for the control plugin model.
+  core.proto.BaseOptions base_options = 1;
+
+  // The ConditionedImageGraphOptions used to generate the control plugin
+  // model's input image.
+  proto.ConditionedImageGraphOptions conditioned_image_graph_options = 2;
+}
diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto
new file mode 100644
index 000000000..5bbf8de15
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto
@@ -0,0 +1,39 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package mediapipe.tasks.vision.image_generator.proto;
+
+import "mediapipe/tasks/cc/core/proto/external_file.proto";
+import "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto";
+import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
+option java_outer_classname = "ImageGeneratorGraphOptionsProto";
+
+message ImageGeneratorGraphOptions {
+  // The directory containing the model weights of the text-to-image model.
+  string text2image_model_directory = 1;
+
+  // An optional LoRA weights file. If set, the diffusion model will be created
+  // with LoRA weights.
+  core.proto.ExternalFile lora_weights_file = 2;
+
+  repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
+
+  mediapipe.StableDiffusionIterateCalculatorOptions
+      stable_diffusion_iterate_options = 4;
+}
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD
index ee1cd3693..fa67d9af3 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD
+++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,6 +37,7 @@ cc_library(
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/port:status",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
@@ -63,6 +64,8 @@ cc_library(
         "//mediapipe/calculators/image:image_properties_calculator",
         "//mediapipe/calculators/image:image_transformation_calculator",
         "//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
+        "//mediapipe/calculators/image:set_alpha_calculator",
+        "//mediapipe/calculators/image:set_alpha_calculator_cc_proto",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
         "//mediapipe/calculators/tensor:inference_calculator",
@@ -93,6 +96,7 @@ cc_library(
         "//mediapipe/util:graph_builder_utils",
         "//mediapipe/util:label_map_cc_proto",
         "//mediapipe/util:label_map_util",
+        "@com_google_absl//absl/log:absl_log",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD
index c621016dc..b32c5d052 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD
+++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD
@@ -1,4 +1,4 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -50,7 +50,12 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - ], + ] + select({ + "//conditions:default": [], + "//mediapipe:android": [ + ":segmentation_postprocessor_gl", + ], + }), alwayslink = 1, ) @@ -72,6 +77,29 @@ cc_library( "//mediapipe/tasks/cc/vision/utils:image_utils", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + ] + select({ + "//conditions:default": [], + "//mediapipe:android": [ + "ssbo_to_texture_converter", + ], + }), +) + +cc_library( + name = "ssbo_to_texture_converter", + srcs = ["ssbo_to_texture_converter.cc"], + hdrs = ["ssbo_to_texture_converter.h"], + tags = [ + "nomac", + "notap", + ], + deps = [ + "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gl_base", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util", ], ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc index 5b212069f..b1791fc0a 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc @@ -5,6 +5,7 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "mediapipe/framework/port/status_macros.h" @@ -16,6 +17,14 @@ namespace mediapipe { namespace tasks { namespace { +// On most platforms, glGetUniformLocation returns -1 for an error status, but +// on web we'll see 0 instead. +#ifdef __EMSCRIPTEN__ +const GLint kUniformErrorStatus = 0; +#else +const GLint kUniformErrorStatus = -1; +#endif // __EMSCRIPTEN__ + using mediapipe::kBasicSquareVertices; using mediapipe::kBasicTextureVertices; using mediapipe::kBasicVertexShader; @@ -188,7 +197,7 @@ void main() { // Special argmax shader for N=1 classes. We don't need to worry about softmax // activation (it is assumed softmax requires N > 1 classes), but this should // occur after SIGMOID activation if specified. Instead of a true argmax, we -// simply use 0.5 as the cutoff, assigning 1 (foreground) or 0 (background) +// simply use 0.5 as the cutoff, assigning 0 (foreground) or 255 (background) // based on whether the confidence value reaches this cutoff or not, // respectively. static constexpr char kArgmaxOneClassShader[] = R"( @@ -199,12 +208,12 @@ uniform sampler2D input_texture; void main() { float input_val = texture2D(input_texture, sample_coordinate).x; // Category is just value rounded to nearest integer; then we map to either - // 0 or 1/255 accordingly. If the input has been activated properly, then the + // 0 or 1 accordingly. If the input has been activated properly, then the // values should always be in the range [0, 1]. But just in case it hasn't, to // avoid category overflow issues when the activation function is not properly // chosen, we add an extra clamp here, as performance hit is minimal. 
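// Worked example of the new mapping below: input_val = 0.9 (foreground)
// yields floor(1.5 - 0.9) = 0.0, while input_val = 0.1 (background) yields
// floor(1.5 - 0.1) = 1.0, which a uint8 readback then reports as 255.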
- float category = clamp(floor(input_val + 0.5), 0.0, 1.0); - gl_FragColor = vec4(category / 255.0, 0.0, 0.0, 1.0); + float category = clamp(floor(1.5 - input_val), 0.0, 1.0); + gl_FragColor = vec4(category, 0.0, 0.0, 1.0); })"; // Softmax is in 3 steps: @@ -341,7 +350,7 @@ absl::Status SegmentationPostprocessorGl::CreateBasicFragmentShaderProgram( for (const auto& uniform_name : uniform_names) { shader_struct_ptr->uniforms[uniform_name] = glGetUniformLocation(shader_struct_ptr->program, uniform_name.c_str()); - RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > 0) + RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > kUniformErrorStatus) << uniform_name << " uniform not found for " << program_name << " program"; } @@ -359,19 +368,20 @@ absl::Status SegmentationPostprocessorGl::GlInit( // TODO: We could skip this entirely if no confidence masks // are being produced AND num_classes > 1, but num_classes is only // known at runtime, so this would take a little extra refactoring. - LOG(INFO) << "SIGMOID activation function chosen on GPU"; + ABSL_LOG(INFO) << "SIGMOID activation function chosen on GPU"; activation_fn = "vec4 out_value = 1.0 / (exp(-in_value) + 1.0);"; break; case SegmenterOptions::SOFTMAX: if (produce_confidence_masks) { - LOG(INFO) << "SOFTMAX activation function chosen on GPU"; + ABSL_LOG(INFO) << "SOFTMAX activation function chosen on GPU"; } else { - LOG(INFO) << "SOFTMAX activation function chosen on GPU, but only " - << "category mask produced, so not applying."; + ABSL_LOG(INFO) + << "SOFTMAX activation function chosen on GPU, but only " + << "category mask produced, so not applying."; } break; case SegmenterOptions::NONE: - LOG(INFO) << "NONE activation function chosen on GPU"; + ABSL_LOG(INFO) << "NONE activation function chosen on GPU"; break; } @@ -427,10 +437,10 @@ absl::Status SegmentationPostprocessorGl::GlInit( // Get split program uniform locations. split_texture_uniform_ = glGetUniformLocation(split_program_, "input_texture"); - RET_CHECK(split_texture_uniform_ > 0) + RET_CHECK(split_texture_uniform_ > kUniformErrorStatus) << "split input_texture uniform not found."; split_x_offset_uniform_ = glGetUniformLocation(split_program_, "x_offset"); - RET_CHECK(split_x_offset_uniform_ > 0) + RET_CHECK(split_x_offset_uniform_ > kUniformErrorStatus) << "split x_offset uniform not found."; // TODO: If ES3.0+ only, switch to VAO for handling attributes. @@ -445,10 +455,24 @@ absl::Status SegmentationPostprocessorGl::GlInit( kBasicTextureVertices, GL_STATIC_DRAW); glBindBuffer(GL_ARRAY_BUFFER, 0); + +#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING + MP_RETURN_IF_ERROR(ssbo_to_texture_converter_.Init()); +#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING + return absl::OkStatus(); }); } +// On Android, the extensions are prefixed by GL_, whereas on web they are not. 
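+// For example, probing for "EXT_color_buffer_float" must test
+// "GL_EXT_color_buffer_float" on Android but "EXT_color_buffer_float" under
+// Emscripten; this helper lets callers pass the unprefixed name everywhere.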
+bool SegmentationPostprocessorGl::HasGlExtension(std::string const& extension) { +#ifdef __EMSCRIPTEN__ + return helper_.GetGlContext().HasGlExtension(extension); +#else + return helper_.GetGlContext().HasGlExtension("GL_" + extension); +#endif // __EMSCRIPTEN__ +} + std::vector<std::unique_ptr<Image>> SegmentationPostprocessorGl::GetSegmentationResultGpu( const Shape& input_shape, const Shape& output_shape, const Tensor& tensor, @@ -459,18 +483,35 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu( produce_category_mask, &image_outputs]() -> absl::Status { // Get Tensor input and image output parameters + const int width = input_shape.width; // Slice width from shape + const int height = input_shape.height; // Slice height from shape + const int num_outputs = input_shape.channels; // One output per channel + const int num_chunks = (input_shape.channels + 3) / 4; // ceil(channels/4) + const int output_width = output_shape.width; // Final output width + const int output_height = output_shape.height; // Final output height int input_width, input_height; - if (!tensor.ready_as_opengl_texture_2d()) { - LOG(WARNING) << "Tensor wasn't ready on GPU; using slow workaround."; + if (!tensor.ready_on_gpu()) { + ABSL_LOG(WARNING) << "Tensor wasn't ready on GPU; using slow workaround."; (void)tensor.GetCpuReadView(); } +#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING + // If our Tensor is an SSBO, then it's also linearized, so we convert to a + // kAligned 2d texture using a special converter and then proceed as before. + GLuint ssbo_tex_id; + ASSIGN_OR_RETURN(ssbo_tex_id, + ssbo_to_texture_converter_.ConvertTensorToGlTexture( + tensor, width, height, num_outputs)); + std::tie(input_width, input_height) = + ssbo_to_texture_converter_.GetTextureSize(); +#else const auto layout = tensor.GetOpenGlTexture2dReadView().GetLayoutDimensions( tensor.shape(), &input_width, &input_height); if (layout != Tensor::OpenGlTexture2dView::Layout::kAligned) { - LOG(ERROR) << "Tensor layout not kAligned! Cannot handle."; + ABSL_LOG(ERROR) << "Tensor layout not kAligned! Cannot handle."; } +#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING // Optimization: Only apply SOFTMAX when producing confidence masks, since // SOFTMAX errors out when num_classes = 1, so we don't have to worry about @@ -486,14 +527,12 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu( // (3) blending // Otherwise, we just try for F16. See b/277656755 for more information. // TODO: In the future, separate these 3 different restrictions. - // TODO: Also, we should extend this logic to non-web platforms. - static bool can_use_f32 = - helper_.GetGlContext().HasGlExtension("EXT_color_buffer_float") && - helper_.GetGlContext().HasGlExtension("OES_texture_float_linear") && - helper_.GetGlContext().HasGlExtension("EXT_float_blend"); + // TODO: Also, we should extend this logic to all platforms. + static bool can_use_f32 = HasGlExtension("EXT_color_buffer_float") && + HasGlExtension("OES_texture_float_linear") && + HasGlExtension("EXT_float_blend"); static bool can_use_f16_backup = - helper_.GetGlContext().HasGlExtension("EXT_color_buffer_half_float"); - + HasGlExtension("EXT_color_buffer_half_float"); RET_CHECK(can_use_f32 || can_use_f16_backup) << "Segmentation postprocessing error: GPU does not fully support " << "4-channel float32 or float16 formats."; @@ -510,15 +549,6 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu( const GpuBufferFormat final_output_format = can_use_f32 ?
GpuBufferFormat::kGrayFloat32 : GpuBufferFormat::kGrayHalf16; - const Tensor::OpenGlTexture2dView read_view = - tensor.GetOpenGlTexture2dReadView(); - - const int width = input_shape.width; // Slice width from shape - const int height = input_shape.height; // Slice height from chape - const int num_outputs = input_shape.channels; // One output per channel - const int num_chunks = (input_shape.channels + 3) / 4; // ceil(channels/4) - const int output_width = output_shape.width; // Final output width - const int output_height = output_shape.height; // Final output height // We disable blending or else our alpha channel may destroy our other // channels' data. @@ -540,9 +570,16 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu( input_width, input_height, activation_output_format); helper_.BindFramebuffer(activated_texture); - // All our input source textures are just simple GL_TEXTURE_2D types. + // All our input source textures will be just simple GL_TEXTURE_2D types. glActiveTexture(GL_TEXTURE1); + +#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING + glBindTexture(GL_TEXTURE_2D, ssbo_tex_id); +#else + const Tensor::OpenGlTexture2dView read_view = + tensor.GetOpenGlTexture2dReadView(); glBindTexture(GL_TEXTURE_2D, read_view.name()); +#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING // Render glClear(GL_COLOR_BUFFER_BIT); @@ -818,7 +855,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu( }); if (!status.ok()) { - LOG(ERROR) << "Error with rendering: " << status; + ABSL_LOG(ERROR) << "Error with rendering: " << status; } return image_outputs; @@ -841,6 +878,10 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() { glDeleteProgram(softmax_max_shader_.program); glDeleteProgram(softmax_transform_and_sum_shader_.program); glDeleteProgram(softmax_normalization_shader_.program); + +#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING + ssbo_to_texture_converter_.Close(); +#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING }); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h index 791f42037..2b98bbde6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h @@ -21,6 +21,14 @@ #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" +// On Android with compute shaders we include the SSBO-to-texture converter +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && \ + defined(__ANDROID__) +#define TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING 1 +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h" +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && + // defined(__ANDROID__) + namespace mediapipe { namespace tasks { @@ -45,6 +53,7 @@ class SegmentationPostprocessorGl { }; absl::Status GlInit(const bool produce_confidence_masks); + bool HasGlExtension(std::string const& extension); absl::Status CreateBasicFragmentShaderProgram( std::string const& program_name, std::string const& fragment_shader_source, @@ -69,6 +78,10 @@ class SegmentationPostprocessorGl { GlShader softmax_max_shader_; GlShader softmax_transform_and_sum_shader_; GlShader softmax_normalization_shader_; + +#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING + 
SsboToTextureConverter ssbo_to_texture_converter_; +#endif }; } // namespace tasks diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.cc new file mode 100644 index 000000000..d7a8bb249 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.cc @@ -0,0 +1,162 @@ +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h" + +#include "tensorflow/lite/delegates/gpu/gl/converters/util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" + +// Quick compile-time warning to ensure usage on the proper platform. +#if !(MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31) +#warning "SsboToTextureConverter should be used with OpenGL ES 3.1 or above" +#endif + +namespace mediapipe { +namespace tasks { +namespace { + +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; + +constexpr int kWorkgroupSize = 8; // Block size for GPU shader. +const tflite::gpu::uint3 workgroup_size = {kWorkgroupSize, kWorkgroupSize, 1}; + +// "Delinearization" shader: +// Example data using n=5 channels: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14 --> +// 0,1,2,3 | 4,X,X,X | 5,6,7,8 | 9,X,X,X | 10,11,12,13 | 14,X,X,X +const char delinearization_shader_source[] = R"( +precision highp float; +layout(rgba32f, binding = 0) writeonly uniform highp image2D output_texture; + +uniform ivec2 out_size; +uniform int num_channels; +uniform int num_channels_padded; // ^ rounded up to nearest multiple of 4 + +layout(std430, binding = 2) readonly buffer B0 { + float elements[]; +} input_data; // data tensor + +void main() { + int out_width = out_size.x; + int out_height = out_size.y; + + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= out_width || gid.y >= out_height) { return; } + int linear_index_pixels = gid.y * out_width + gid.x; + int linear_index = linear_index_pixels * 4; + + int num_completed_chunks = linear_index / num_channels_padded; + int offset = linear_index % num_channels_padded; + int data_index = num_completed_chunks * num_channels + offset; + + // Early exit if fully outside buffer + int data_size = input_data.elements.length(); + if (data_index >= data_size) return; + + // We add some extra logic here just to ensure we don't overrun buffer and get + // undefined behavior. TODO: Come up with nicer way around this if + // we end up needing this sort of patch more frequently. + float x = input_data.elements[data_index]; + float y = 0.0; + float z = 0.0; + float w = 0.0; + if (data_index + 3 < data_size) { + w = input_data.elements[data_index + 3]; + z = input_data.elements[data_index + 2]; + y = input_data.elements[data_index + 1]; + } else if (data_index + 2 < data_size) { + z = input_data.elements[data_index + 2]; + y = input_data.elements[data_index + 1]; + } else if (data_index + 1 < data_size) { + y = input_data.elements[data_index + 1]; + } + + ivec2 output_coordinate = ivec2(gid.x, gid.y); + vec4 out_value = vec4(x, y, z, w); + imageStore(output_texture, output_coordinate, out_value); +})"; + +// Commonly used to compute the number of blocks to launch in a kernel. 
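+// e.g. NumGroups(20, 8) == 3, i.e. integer ceil(size / group_size). For the
+// n=5 example above (padded to 8), output pixels 0..5 start reading at data
+// indices 0, 4, 5, 9, 10, and 14, matching the chunk layout shown there.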
+int NumGroups(const int size, const int group_size) { // NOLINT + return (size + group_size - 1) / group_size; +} + +} // namespace + +absl::Status SsboToTextureConverter::Init() { + GlShader delinearization_shader; + std::string delinearization_shader_source_with_headers = + absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size), + delinearization_shader_source); + MP_RETURN_IF_ERROR(GlShader::CompileShader( + GL_COMPUTE_SHADER, delinearization_shader_source_with_headers, + &delinearization_shader)); + delinearization_program_ = absl::make_unique<GlProgram>(); + MP_RETURN_IF_ERROR(GlProgram::CreateWithShader( + delinearization_shader, delinearization_program_.get())); + return absl::OkStatus(); +} + +void SsboToTextureConverter::Close() { delinearization_program_.reset(); } + +std::pair<uint32_t, uint32_t> +SsboToTextureConverter::GetTextureSize() { + return std::make_pair(texture_width_, texture_height_); +} + +absl::StatusOr<GLuint> SsboToTextureConverter::ConvertTensorToGlTexture( + const Tensor& tensor, const uint32_t width, const uint32_t height, + const uint32_t channels) { + // The tflite::gpu:: namespace looks like it's much simpler and older; it + // doesn't tap into any memory pools, and doesn't allow linearF32 filtering + // where available, for example. The key difference is that it uses + // glTexStorage2D for allocation instead of glTexImage2D, which is necessary + // in order to create an immutable format (as required by glBindImageTexture). + // MP will automatically use this for RGBA16F but not RGBA32F textures + // currently, oddly enough. So options are: + // (1) extend MP to similarly handle RGBA32F + // (2) just make our own texture here and keep reusing, recreating if the size + // changes, which should generally not happen. (This is ok because we use + // the texture immediately and never output it from the calculator). + // (3) Change glBindImageTexture call to alternative so we can just use + // existing MP glTexImage2D storage creation? This seems less than + // ideal since it's rather nice to keep the above program in compute + // shader format. + // TODO: To be safe for this initial implementation, we go with + // option #2, as it's simplest/easiest, but this should be cleaned up later. + const uint32_t num_pixels_per_element = ((channels + 3) / 4); + const uint32_t padded_channels = 4 * num_pixels_per_element; + const uint32_t texture_width = width * num_pixels_per_element; + const uint32_t texture_height = height; + if (texture_width != texture_width_ || texture_height != texture_height_) { + // tflite::gpu::gl::GlTexture autoreleases, so we don't have to worry about + // freeing memory.
+ MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture( + tflite::gpu::DataType::FLOAT32, {texture_width, texture_height}, + &out_texture_)); + texture_width_ = texture_width; + texture_height_ = texture_height; + } + + glBindImageTexture(0 /* output index */, out_texture_.id(), 0, GL_FALSE, 0, + GL_WRITE_ONLY, GL_RGBA32F); + auto read_view = tensor.GetOpenGlBufferReadView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2 /* input index */, + read_view.name()); + + glUseProgram(delinearization_program_->id()); + glUniform2i(glGetUniformLocation(delinearization_program_->id(), "out_size"), + texture_width, texture_height); + glUniform1i( + glGetUniformLocation(delinearization_program_->id(), "num_channels"), + channels); + glUniform1i(glGetUniformLocation(delinearization_program_->id(), + "num_channels_padded"), + padded_channels); + + const tflite::gpu::uint3 workgroups = { + NumGroups(texture_width, kWorkgroupSize), + NumGroups(texture_height, kWorkgroupSize), 1}; + MP_RETURN_IF_ERROR(delinearization_program_->Dispatch(workgroups)); + return out_texture_.id(); +} + +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h b/mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h new file mode 100644 index 000000000..401928912 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h @@ -0,0 +1,55 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_ + +#include + +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/gpu/gl_base.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" + +namespace mediapipe { +namespace tasks { + +// Helper class for converting Android and Linux Tensors from OpenGL ES >=3.1 +// SSBO objects into OpenGL ES <=3.0 2D textures. Cannot be used with other +// Tensor backends. 
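+//
+// A hedged usage sketch (assumes a current GL context; error macros as used
+// elsewhere in MediaPipe):
+//
+//   SsboToTextureConverter converter;
+//   MP_RETURN_IF_ERROR(converter.Init());
+//   ASSIGN_OR_RETURN(GLuint tex_id,
+//                    converter.ConvertTensorToGlTexture(tensor, width,
+//                                                       height, channels));
+//   std::pair<uint32_t, uint32_t> tex_size = converter.GetTextureSize();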
+class SsboToTextureConverter { + public: + SsboToTextureConverter() = default; + ~SsboToTextureConverter() = default; + absl::Status Init(); + void Close(); + absl::StatusOr<GLuint> ConvertTensorToGlTexture(const Tensor& tensor, + const uint32_t width, + const uint32_t height, + const uint32_t channels); + + // Should only be called after ConvertTensorToGlTexture + std::pair<uint32_t, uint32_t> GetTextureSize(); + + private: + uint32_t texture_width_; + uint32_t texture_height_; + tflite::gpu::gl::GlTexture out_texture_; + std::unique_ptr<tflite::gpu::gl::GlProgram> delinearization_program_; +}; + +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_ diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index e6d9ec8af..810acd93a 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,9 +43,18 @@ limitations under the License. #include "mediapipe/util/label_map.pb.h" #ifdef __EMSCRIPTEN__ -#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h" +#define TASK_SEGMENTATION_USE_GL_POSTPROCESSING 1 +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && \ + !MEDIAPIPE_USING_SWIFTSHADER && defined(MEDIAPIPE_ANDROID) +#define TASK_SEGMENTATION_USE_GL_POSTPROCESSING 1 +#else +#undef TASK_SEGMENTATION_USE_GL_POSTPROCESSING #endif // __EMSCRIPTEN__ +#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h" +#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING + // TODO: consolidate TensorToSegmentationCalculator. namespace mediapipe { namespace tasks { @@ -61,6 +70,8 @@ using ::mediapipe::tasks::vision::GetImageLikeTensorShape; using ::mediapipe::tasks::vision::Shape; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +constexpr uint8_t kUnLabeledPixelValue = 255; + void StableSoftmax(absl::Span<const float> values, absl::Span<float> activated_values) { float max_value = *std::max_element(values.begin(), values.end()); @@ -153,9 +164,11 @@ Image ProcessForCategoryMaskCpu(const Shape& input_shape, } if (input_channels == 1) { // if the input tensor is a single mask, it is assumed to be a binary - // foreground segmentation mask. For such a mask, we make foreground - // category 1, and background category 0. - pixel = static_cast<uint8_t>(confidence_scores[0] > 0.5f); + // foreground segmentation mask. For such a mask, instead of a true + // argmax, we simply use 0.5 as the cutoff, assigning 0 (foreground) or + // 255 (background) based on whether the confidence value reaches this + // cutoff or not, respectively. + pixel = confidence_scores[0] > 0.5f ?
0 : kUnLabeledPixelValue; } else { const int maximum_category_idx = std::max_element(confidence_scores.begin(), confidence_scores.end()) - @@ -287,8 +300,11 @@ class TensorsToSegmentationCalculator : public Node { static constexpr Output<Image>::Multiple kConfidenceMaskOut{ "CONFIDENCE_MASK"}; static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"}; + static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{ + "QUALITY_SCORES"}; MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut, - kConfidenceMaskOut, kCategoryMaskOut); + kConfidenceMaskOut, kCategoryMaskOut, + kQualityScoresOut); static absl::Status UpdateContract(CalculatorContract* cc); @@ -301,19 +317,19 @@ class TensorsToSegmentationCalculator : public Node { const float* tensors_buffer); TensorsToSegmentationCalculatorOptions options_; -#ifdef __EMSCRIPTEN__ +#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING SegmentationPostprocessorGl postprocessor_; -#endif // __EMSCRIPTEN__ +#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING }; // static absl::Status TensorsToSegmentationCalculator::UpdateContract( CalculatorContract* cc) { -#ifdef __EMSCRIPTEN__ +#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING return SegmentationPostprocessorGl::UpdateContract(cc); #else return absl::OkStatus(); -#endif // __EMSCRIPTEN__ +#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING } absl::Status TensorsToSegmentationCalculator::Open( @@ -333,20 +349,41 @@ absl::Status TensorsToSegmentationCalculator::Open( "connected."); } } -#ifdef __EMSCRIPTEN__ +#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); -#endif // __EMSCRIPTEN__ +#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING return absl::OkStatus(); } absl::Status TensorsToSegmentationCalculator::Process( mediapipe::CalculatorContext* cc) { - RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1) - << "Expect a vector of single Tensor."; - const auto& input_tensor = kTensorsIn(cc).Get()[0]; + const auto& input_tensors = kTensorsIn(cc).Get(); + if (input_tensors.size() != 1 && input_tensors.size() != 2) { + return absl::InvalidArgumentError( + "Expect input tensor vector of size 1 or 2."); + } + const auto& input_tensor = *input_tensors.rbegin(); ASSIGN_OR_RETURN(const Shape input_shape, GetImageLikeTensorShape(input_tensor)); + // TODO: should use tensor signature to get the correct output + // tensor. + if (input_tensors.size() == 2) { + const auto& quality_tensor = input_tensors[0]; + const float* quality_score_buffer = + quality_tensor.GetCpuReadView().buffer<float>(); + const std::vector<float> quality_scores( + quality_score_buffer, + quality_score_buffer + + (quality_tensor.bytes() / quality_tensor.element_size())); + kQualityScoresOut(cc).Send(quality_scores); + } else { + // If the input_tensors don't contain quality scores, send the default + // quality scores as 1. + const std::vector<float> quality_scores(input_shape.channels, 1.0f); + kQualityScoresOut(cc).Send(quality_scores); + } + // Category mask does not require activation function. if (options_.segmenter_options().output_type() == SegmenterOptions::CONFIDENCE_MASK && @@ -362,11 +399,11 @@ absl::Status TensorsToSegmentationCalculator::Process( } // Use GPU postprocessing on web when Tensor is there already.
-#ifdef __EMSCRIPTEN__ +#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING Shape output_shape = {/* height= */ output_height, /* width= */ output_width, /* channels= */ input_shape.channels}; - if (input_tensor.ready_as_opengl_texture_2d()) { + if (input_tensor.ready_on_gpu()) { bool produce_category_mask = options_.segmenter_options().output_type() == SegmenterOptions::CATEGORY_MASK || cc->Outputs().HasTag("CATEGORY_MASK"); @@ -400,7 +437,7 @@ absl::Status TensorsToSegmentationCalculator::Process( } return absl::OkStatus(); } -#endif // __EMSCRIPTEN__ +#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING // Otherwise, use CPU postprocessing. const float* tensors_buffer = input_tensor.GetCpuReadView().buffer(); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index dbaf34db0..b728f5046 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc index d6a2f3fd9..7b8fb9c9e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 33c868e05..74d8047de 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,12 +16,14 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include +#include #include "absl/strings/str_format.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -41,11 +43,15 @@ constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kCategoryMaskStreamName[] = "category_mask"; +constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kOutputSizeStreamName[] = "output_size"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kQualityScoresStreamName[] = "quality_scores"; +constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -68,6 +74,7 @@ CalculatorGraphConfig CreateGraphConfig( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); + graph.In(kOutputSizeTag).SetName(kOutputSizeStreamName); if (output_confidence_masks) { task_subgraph.Out(kConfidenceMasksTag) .SetName(kConfidenceMasksStreamName) >> @@ -77,14 +84,18 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); } + task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >> + graph.Out(kQualityScoresTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { return tasks::core::AddFlowLimiterCalculator( - graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag); + graph, task_subgraph, {kImageTag, kNormRectTag, kOutputSizeTag}, + kConfidenceMasksTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); + graph.In(kOutputSizeTag) >> task_subgraph.In(kOutputSizeTag); return graph.GetConfig(); } @@ -172,9 +183,13 @@ absl::StatusOr> ImageSegmenter::Create( category_mask = status_or_packets.value()[kCategoryMaskStreamName].Get(); } + const std::vector& quality_scores = + status_or_packets.value()[kQualityScoresStreamName] + .Get>(); Packet image_packet = status_or_packets.value()[kImageOutStreamName]; result_callback( - {{confidence_masks, category_mask}}, image_packet.Get(), + {{confidence_masks, category_mask, quality_scores}}, + image_packet.Get(), image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); }; } @@ -203,6 +218,16 @@ absl::StatusOr> ImageSegmenter::Create( absl::StatusOr ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { + return Segment(image, { + /*output_width=*/image.width(), + /*output_height=*/image.height(), + std::move(image_processing_options), + }); +} + +absl::StatusOr ImageSegmenter::Segment( + mediapipe::Image image, SegmentationOptions segmentation_options) { + 
MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options)); if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -210,14 +235,19 @@ absl::StatusOr ImageSegmenter::Segment( MediaPipeTasksStatus::kRunnerUnexpectedInputError); } ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image, - /*roi_allowed=*/false)); + ConvertToNormalizedRect( + segmentation_options.image_processing_options, image, + /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, - MakePacket(std::move(norm_rect))}})); + MakePacket(std::move(norm_rect))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(segmentation_options.output_width, + segmentation_options.output_height))}})); std::optional> confidence_masks; if (output_confidence_masks_) { confidence_masks = @@ -227,12 +257,26 @@ absl::StatusOr ImageSegmenter::Segment( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); } - return {{confidence_masks, category_mask}}; + const std::vector& quality_scores = + output_packets[kQualityScoresStreamName].Get>(); + return {{confidence_masks, category_mask, quality_scores}}; } absl::StatusOr ImageSegmenter::SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { + return SegmentForVideo(image, timestamp_ms, + { + /*output_width=*/image.width(), + /*output_height=*/image.height(), + std::move(image_processing_options), + }); +} + +absl::StatusOr ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, + SegmentationOptions segmentation_options) { + MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options)); if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -240,8 +284,9 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( MediaPipeTasksStatus::kRunnerUnexpectedInputError); } ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image, - /*roi_allowed=*/false)); + ConvertToNormalizedRect( + segmentation_options.image_processing_options, image, + /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -250,6 +295,11 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kNormRectStreamName, MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(segmentation_options.output_width, + segmentation_options.output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); std::optional> confidence_masks; if (output_confidence_masks_) { @@ -260,12 +310,26 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get(); } - return {{confidence_masks, category_mask}}; + const std::vector& quality_scores = + output_packets[kQualityScoresStreamName].Get>(); + return {{confidence_masks, category_mask, quality_scores}}; } absl::Status ImageSegmenter::SegmentAsync( Image image, int64_t timestamp_ms, std::optional image_processing_options) { + return SegmentAsync(image, timestamp_ms, + { + /*output_width=*/image.width(), + /*output_height=*/image.height(), + std::move(image_processing_options), + }); +} + +absl::Status ImageSegmenter::SegmentAsync( + Image 
image, int64_t timestamp_ms, + SegmentationOptions segmentation_options) { + MP_RETURN_IF_ERROR(ValidateSegmentationOptions(segmentation_options)); if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -273,14 +337,20 @@ absl::Status ImageSegmenter::SegmentAsync( MediaPipeTasksStatus::kRunnerUnexpectedInputError); } ASSIGN_OR_RETURN(NormalizedRect norm_rect, - ConvertToNormalizedRect(image_processing_options, image, - /*roi_allowed=*/false)); + ConvertToNormalizedRect( + segmentation_options.image_processing_options, image, + /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kNormRectStreamName, MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(segmentation_options.output_width, + segmentation_options.output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 352d6b273..82bb3a3a6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -67,6 +67,22 @@ struct ImageSegmenterOptions { result_callback = nullptr; }; +// Options for configuring runtime behavior of ImageSegmenter. +struct SegmentationOptions { + // The width of the output segmentation masks. + int output_width; + + // The height of the output segmentation masks. + int output_height; + + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + std::optional image_processing_options; +}; + // Performs segmentation on images. // // The API expects a TFLite model with mandatory TFLite Model Metadata. @@ -102,17 +118,46 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // // The image can be of any size with format RGB or RGBA. // + // The output size is the same as the input image size. + // // The optional 'image_processing_options' parameter can be used to specify // the rotation to apply to the image before performing segmentation, by // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - absl::StatusOr Segment( mediapipe::Image image, std::optional image_processing_options = std::nullopt); + // Performs image segmentation on the provided single image. + // Only use this method when the ImageSegmenter is created with the image + // running mode. + // + // The image can be of any size with format RGB or RGBA. + absl::StatusOr Segment( + mediapipe::Image image, SegmentationOptions segmentation_options); + + // Performs image segmentation on the provided video frame. 
+ // Only use this method when the ImageSegmenter is created with the video + // running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + // + // The output size is the same as the input image size. + // + // The optional 'image_processing_options' parameter can be used + // to specify the rotation to apply to the image before performing + // segmentation, by setting its 'rotation_degrees' field. Note that specifying + // a region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + absl::StatusOr SegmentForVideo( + mediapipe::Image image, int64_t timestamp_ms, + std::optional image_processing_options = + std::nullopt); + // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video // running mode. @@ -120,16 +165,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - // - // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing segmentation, by - // setting its 'rotation_degrees' field. Note that specifying a - // region-of-interest using the 'region_of_interest' field is NOT supported - // and will result in an invalid argument error being returned. absl::StatusOr SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, - std::optional image_processing_options = - std::nullopt); + SegmentationOptions segmentation_options); // Sends live image data to perform image segmentation, and the results will // be available via the "result_callback" provided in the @@ -141,13 +179,15 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // sent to the image segmenter. The input timestamps must be monotonically // increasing. // + // The output size is the same as the input image size. + // // The optional 'image_processing_options' parameter can be used to specify // the rotation to apply to the image before performing segmentation, by // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // - // The "result_callback" prvoides + // The "result_callback" provides // - An ImageSegmenterResult. // - The const reference to the corresponding input image that the image // segmentation runs on. Note that the const reference to the image will @@ -158,6 +198,26 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { std::optional image_processing_options = std::nullopt); + // Sends live image data to perform image segmentation, and the results will + // be available via the "result_callback" provided in the + // ImageSegmenterOptions. Only use this method when the ImageSegmenter is + // created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the image segmenter. The input timestamps must be monotonically + // increasing. 
+ // + // The "result_callback" provides + // - An ImageSegmenterResult. + // - The const reference to the corresponding input image that the image + // segmentation runs on. Note that the const reference to the image will + // no longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms, + SegmentationOptions segmentation_options); + // Shuts down the ImageSegmenter when all works are done. absl::Status Close() { return runner_->Close(); } @@ -174,6 +234,14 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { std::vector labels_; bool output_confidence_masks_; bool output_category_mask_; + + absl::Status ValidateSegmentationOptions(const SegmentationOptions& options) { + if (options.output_width <= 0 || options.output_height <= 0) { + return absl::InvalidArgumentError( + "Both output_width and output_height must be larger than 0."); + } + return absl::OkStatus(); + } }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 840e7933a..b49f22ca0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/image/image_transformation_calculator.pb.h" +#include "mediapipe/calculators/image/set_alpha_calculator.pb.h" #include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" @@ -80,6 +83,8 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kSizeTag[] = "SIZE"; +constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; // Struct holding the different output streams produced by the image segmenter @@ -89,6 +94,7 @@ struct ImageSegmenterOutputs { std::optional>> confidence_masks; std::optional> category_mask; // The same as the input image, mainly used for live stream mode. + std::optional>> quality_scores; Source image; }; @@ -179,7 +185,7 @@ absl::Status ConfigureTensorsToSegmentationCalculator( } } if (!found_activation_in_metadata) { - LOG(WARNING) + ABSL_LOG(WARNING) << "No activation type is found in model metadata. 
Use NONE for " "ImageSegmenterGraph."; } @@ -190,19 +196,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator( "Segmentation tflite models are assumed to have a single subgraph.", MediaPipeTasksStatus::kInvalidArgumentError); } - const auto* primary_subgraph = (*model.subgraphs())[0]; - if (primary_subgraph->outputs()->size() != 1) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Segmentation tflite models are assumed to have a single output.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - ASSIGN_OR_RETURN( *options->mutable_label_items(), - GetLabelItemsIfAny(*metadata_extractor, - *metadata_extractor->GetOutputTensorMetadata()->Get(0), - segmenter_option.display_names_locale())); + GetLabelItemsIfAny( + *metadata_extractor, + **metadata_extractor->GetOutputTensorMetadata()->crbegin(), + segmenter_option.display_names_locale())); return absl::OkStatus(); } @@ -212,10 +211,16 @@ absl::StatusOr GetOutputTensor( const tflite::Model& model = *model_resources.GetTfLiteModel(); const auto* primary_subgraph = (*model.subgraphs())[0]; const auto* output_tensor = - (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]]; + (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()]; return output_tensor; } +uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) { + const tflite::Model& model = *model_resources.GetTfLiteModel(); + const auto* primary_subgraph = (*model.subgraphs())[0]; + return primary_subgraph->outputs()->size(); +} + // Get the input tensor from the tflite model of given model resources. absl::StatusOr GetInputTensor( const core::ModelResources& model_resources) { @@ -249,7 +254,8 @@ void ConfigureTensorConverterCalculator( // the tflite model. absl::StatusOr ConvertImageToTensors( Source image_in, Source norm_rect_in, bool use_gpu, - const core::ModelResources& model_resources, Graph& graph) { + bool is_hair_segmentation, const core::ModelResources& model_resources, + Graph& graph) { ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor, GetInputTensor(model_resources)); if (tflite_input_tensor->shape()->size() != 4) { @@ -294,9 +300,17 @@ absl::StatusOr ConvertImageToTensors( // Convert from Image to legacy ImageFrame or GpuBuffer. auto& from_image = graph.AddNode("FromImageCalculator"); image_on_device >> from_image.In(kImageTag); - auto image_cpu_or_gpu = + Source image_cpu_or_gpu = from_image.Out(use_gpu ? kImageGpuTag : kImageCpuTag); + if (is_hair_segmentation) { + auto& set_alpha = graph.AddNode("SetAlphaCalculator"); + set_alpha.GetOptions() + .set_alpha_value(0); + image_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag); + image_cpu_or_gpu = set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); + } + // Resize the input image to the model input size. auto& image_transformation = graph.AddNode("ImageTransformationCalculator"); ConfigureImageTransformationCalculator( @@ -344,6 +358,9 @@ absl::StatusOr ConvertImageToTensors( // Describes image rotation and region of image to perform detection // on. // @Optional: rect covering the whole image is used if not specified. +// OUTPUT_SIZE - std::pair @Optional +// The output size of the mask, in width and height. If not specified, the +// output size of the input image is used. 
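+// (For example, a std::pair<int, int> packet of {512, 288} requests masks
+// that are 512 pixels wide and 288 pixels tall.)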
// // Outputs: // CONFIDENCE_MASK - mediapipe::Image @Multiple @@ -388,11 +405,16 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { if (!options.segmenter_options().has_output_type()) { MP_RETURN_IF_ERROR(SanityCheck(sc)); } + std::optional<Source<std::pair<int, int>>> output_size; + if (HasInput(sc->OriginalNode(), kOutputSizeTag)) { + output_size = graph.In(kOutputSizeTag).Cast<std::pair<int, int>>(); + } ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( options, *model_resources, graph[Input<Image>(kImageTag)], - graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); + graph[Input<NormalizedRect>::Optional(kNormRectTag)], output_size, + graph)); // TODO: remove deprecated output type support. if (options.segmenter_options().has_output_type()) { @@ -423,6 +445,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)]; } } + if (output_streams.quality_scores) { + *output_streams.quality_scores >> + graph[Output<std::vector<float>>::Optional(kQualityScoresTag)]; + } output_streams.image >> graph[Output<Image>(kImageTag)]; return graph.GetConfig(); } @@ -453,7 +479,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source<Image> image_in, - Source<NormalizedRect> norm_rect_in, Graph& graph) { + Source<NormalizedRect> norm_rect_in, + std::optional<Source<std::pair<int, int>>> output_size, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -461,28 +488,51 @@ bool use_gpu = components::processors::DetermineImagePreprocessingGpuBackend( task_options.base_options().acceleration()); - ASSIGN_OR_RETURN(auto image_and_tensors, - ConvertImageToTensors(image_in, norm_rect_in, use_gpu, - model_resources, graph)); - // Adds inference subgraph and connects its input stream to the output - // tensors produced by the ImageToTensorCalculator. - auto& inference = AddInference( - model_resources, task_options.base_options().acceleration(), graph); - image_and_tensors.tensors >> inference.In(kTensorsTag); - // Adds segmentation calculators for output streams. + // Adds segmentation calculators for output streams. Add this calculator + // first to get the labels. auto& tensor_to_images = graph.AddNode("mediapipe.tasks.TensorsToSegmentationCalculator"); RET_CHECK_OK(ConfigureTensorsToSegmentationCalculator( task_options, model_resources, &tensor_to_images .GetOptions<TensorsToSegmentationCalculatorOptions>())); + const auto& tensor_to_images_options = + tensor_to_images.GetOptions<TensorsToSegmentationCalculatorOptions>(); + + // TODO: remove special logic for hair segmentation model. + // The alpha channel of the hair segmentation model indicates the area of + // interest. The model was designed for live stream mode, so that the mask of + // the previous frame is used as the indicator for the next frame. For the + // first frame, it expects the alpha channel to be empty. To consolidate + // IMAGE, VIDEO and LIVE_STREAM mode in mediapipe tasks, here we forcibly set + // the alpha channel to be empty if we find the model is the hair segmentation + // model.
+ bool is_hair_segmentation = false; + if (tensor_to_images_options.label_items_size() == 2 && + tensor_to_images_options.label_items().at(1).name() == "hair") { + is_hair_segmentation = true; + } + + ASSIGN_OR_RETURN( + auto image_and_tensors, + ConvertImageToTensors(image_in, norm_rect_in, use_gpu, + is_hair_segmentation, model_resources, graph)); + // Adds inference subgraph and connects its input stream to the output + // tensors produced by the ImageToTensorCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + image_and_tensors.tensors >> inference.In(kTensorsTag); inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag); - // Adds image property calculator for output size. - auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); - image_in >> image_properties.In("IMAGE"); - image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); + if (output_size.has_value()) { + *output_size >> tensor_to_images.In(kOutputSizeTag); + } else { + // Adds image property calculator for output size. + auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_properties.In(kImageTag); + image_properties.Out(kSizeTag) >> tensor_to_images.In(kOutputSizeTag); + } // Exports multiple segmented masks. // TODO: remove deprecated output type support. @@ -501,9 +551,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { tensor_to_images[Output::Multiple(kSegmentationTag)][i])); } } + auto quality_scores = + tensor_to_images[Output>(kQualityScoresTag)]; return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks, /*confidence_masks=*/std::nullopt, /*category_mask=*/std::nullopt, + /*quality_scores=*/quality_scores, /*image=*/image_and_tensors.image}; } else { std::optional>> confidence_masks; @@ -523,9 +576,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { if (output_category_mask_) { category_mask = tensor_to_images[Output(kCategoryMaskTag)]; } + auto quality_scores = + tensor_to_images[Output>(kQualityScoresTag)]; return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt, /*confidence_masks=*/confidence_masks, /*category_mask=*/category_mask, + /*quality_scores=*/quality_scores, /*image=*/image_and_tensors.image}; } } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h index f14ee7a90..7f159cc39 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,10 @@ struct ImageSegmenterResult { // A category mask of uint8 image in GRAY8 format where each pixel represents // the class which the pixel in the original image was predicted to belong to. std::optional category_mask; + // The quality scores of the result masks, in the range of [0, 1]. Defaults to + // `1` if the model doesn't output quality scores. Each element corresponds to + // the score of the category in the model outputs. 
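+ // For example, a model without a quality head that reports two categories
+ // yields quality_scores == {1.0f, 1.0f} (the calculator sends a default
+ // score of 1 per category).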
+ std::vector<float> quality_scores; }; } // namespace image_segmenter diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index d6a2f3fd9..7b8fb9c9e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ limitations under the License. #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/test_util.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" @@ -425,6 +426,28 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsSelfieSegmentationSingleLabel) { + auto options = std::make_unique<ImageSegmenterOptions>(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentation); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, + ImageSegmenter::Create(std::move(options))); + ASSERT_EQ(segmenter->GetLabels().size(), 1); + EXPECT_EQ(segmenter->GetLabels()[0], "selfie"); + MP_ASSERT_OK(segmenter->Close()); +} + +TEST_F(ImageModeTest, SucceedsSelfieSegmentationLandscapeSingleLabel) { + auto options = std::make_unique<ImageSegmenterOptions>(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, + ImageSegmenter::Create(std::move(options))); + ASSERT_EQ(segmenter->GetLabels().size(), 1); + EXPECT_EQ(segmenter->GetLabels()[0], "selfie"); + MP_ASSERT_OK(segmenter->Close()); +} + TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); @@ -464,6 +487,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); + MP_EXPECT_OK( + SavePngTestOutput(*result.category_mask->GetImageFrameSharedPtr(), + "portrait_selfie_segmentation_expected_category_mask")); cv::Mat selfie_mask = mediapipe::formats::MatView( result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( @@ -471,7 +497,7 @@ "portrait_selfie_segmentation_expected_category_mask.jpg"), cv::IMREAD_GRAYSCALE); EXPECT_THAT(selfie_mask, - SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1)); } TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { @@ -487,6 +513,9 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { EXPECT_TRUE(result.category_mask.has_value()); MP_ASSERT_OK(segmenter->Close()); + MP_EXPECT_OK(SavePngTestOutput( + *result.category_mask->GetImageFrameSharedPtr(), + "portrait_selfie_segmentation_landscape_expected_category_mask")); cv::Mat selfie_mask = mediapipe::formats::MatView(
result.category_mask->GetImageFrameSharedPtr().get()); cv::Mat expected_mask = cv::imread( @@ -495,7 +524,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { "portrait_selfie_segmentation_landscape_expected_category_mask.jpg"), cv::IMREAD_GRAYSCALE); EXPECT_THAT(selfie_mask, - SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 1)); } TEST_F(ImageModeTest, SucceedsHairSegmentation) { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index 9523dd679..54dc399d3 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 5c7d2ec71..8866b3951 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto index b1ec529d0..86b9b39ad 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD index 8552383ac..177cbf43a 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ cc_library( name = "interactive_segmenter", srcs = ["interactive_segmenter.cc"], hdrs = ["interactive_segmenter.h"], + visibility = ["//visibility:public"], deps = [ ":interactive_segmenter_graph", "//mediapipe/framework:calculator_cc_proto", @@ -51,6 +52,7 @@ cc_library( name = "interactive_segmenter_graph", srcs = ["interactive_segmenter_graph.cc"], deps = [ + "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:set_alpha_calculator", "//mediapipe/calculators/util:annotation_overlay_calculator", "//mediapipe/calculators/util:flat_color_image_calculator", @@ -59,6 +61,7 @@ cc_library( "//mediapipe/calculators/util:to_image_calculator", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:node", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index 9d7111e75..38bbf3baf 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -50,21 +51,24 @@ constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kRoiStreamName[] = "roi_in"; constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kQualityScoresStreamName[] = "quality_scores"; constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"}; constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"}; constexpr absl::string_view kImageTag{"IMAGE"}; constexpr absl::string_view kRoiTag{"ROI"}; constexpr absl::string_view kNormRectTag{"NORM_RECT"}; +constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"}; constexpr absl::string_view kSubgraphTypeName{ "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; +using components::containers::NormalizedKeypoint; + using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult; -using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; @@ -89,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >> graph.Out(kCategoryMaskTag); } + task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >> + graph.Out(kQualityScoresTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); graph.In(kImageTag) >> 
task_subgraph.In(kImageTag); @@ -116,7 +122,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) { case RegionOfInterest::Format::kUnspecified: return absl::InvalidArgumentError( "RegionOfInterest format not specified"); - case RegionOfInterest::Format::kKeyPoint: + case RegionOfInterest::Format::kKeyPoint: { RET_CHECK(roi.keypoint.has_value()); auto* annotation = result.add_render_annotations(); annotation->mutable_color()->set_r(255); @@ -125,6 +131,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) { point->set_x(roi.keypoint->x); point->set_y(roi.keypoint->y); return result; + } + case RegionOfInterest::Format::kScribble: { + RET_CHECK(roi.scribble.has_value()); + auto* annotation = result.add_render_annotations(); + annotation->mutable_color()->set_r(255); + for (const NormalizedKeypoint& keypoint : *(roi.scribble)) { + auto* point = annotation->mutable_scribble()->add_point(); + point->set_normalized(true); + point->set_x(keypoint.x); + point->set_y(keypoint.y); + } + return result; + } } return absl::UnimplementedError("Unrecognized format"); } @@ -186,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment( if (output_category_mask_) { category_mask = output_packets[kCategoryMaskStreamName].Get<Image>(); } - return {{confidence_masks, category_mask}}; + const std::vector<float>& quality_scores = + output_packets[kQualityScoresStreamName].Get<std::vector<float>>(); + return {{confidence_masks, category_mask, quality_scores}}; } } // namespace interactive_segmenter diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h index 350777f31..ad8a558df 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,6 +53,7 @@ struct RegionOfInterest { enum class Format { kUnspecified = 0, // Format not specified. kKeyPoint = 1, // Using keypoint to represent ROI. + kScribble = 2, // Using scribble to represent ROI. }; // Specifies the format used to specify the region-of-interest. Note that @@ -61,8 +62,13 @@ struct RegionOfInterest { Format format = Format::kUnspecified; // Represents the ROI in keypoint format, this should be non-nullopt if - // `format` is `KEYPOINT`. + // `format` is `kKeyPoint`. std::optional<components::containers::NormalizedKeypoint> keypoint; + + // Represents the ROI in scribble format, this should be non-nullopt if + // `format` is `kScribble`. + std::optional<std::vector<components::containers::NormalizedKeypoint>> + scribble; }; // Performs interactive segmentation on images. diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc index b907e2156..5ae2792fe 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
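Note: with the kScribble format and the quality_scores output wired through above, a caller can exercise the new API roughly as follows (a sketch: the model path is a placeholder, and the assumption that quality_scores carries one score per output mask follows the field comment added earlier in this diff):

// Sketch: segment with a scribble ROI and read back the quality scores.
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
                        InteractiveSegmenter::Create(std::move(options)));
RegionOfInterest roi;
roi.format = RegionOfInterest::Format::kScribble;
roi.scribble = std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.44, 0.70},
                                               NormalizedKeypoint{0.44, 0.71},
                                               NormalizedKeypoint{0.44, 0.72}};
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image, roi));
for (float score : result.quality_scores) {
  LOG(INFO) << "mask quality: " << score;  // assumed: one score per mask
}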
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <algorithm> #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" @@ -35,6 +37,51 @@ namespace mediapipe { namespace tasks { namespace vision { namespace interactive_segmenter { +namespace internal { + +// A calculator to add thickness to the render data according to the image size, +// so that the render data is scale invariant to the image size. If the render +// data already has thickness, it will be kept as is. +class AddThicknessToRenderDataCalculator : public api2::Node { + public: + static constexpr api2::Input<Image> kImageIn{"IMAGE"}; + static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{ "RENDER_DATA"}; + static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{ "RENDER_DATA"}; + + static constexpr int kModelInputTensorWidth = 512; + static constexpr int kModelInputTensorHeight = 512; + + MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut); + + absl::Status Process(CalculatorContext* cc) final { + mediapipe::RenderData render_data = kRenderDataIn(cc).Get(); + Image image = kImageIn(cc).Get(); + double thickness = std::max( + std::max(image.width() / static_cast<double>(kModelInputTensorWidth), + image.height() / static_cast<double>(kModelInputTensorHeight)), + 1.0); + + for (auto& annotation : *render_data.mutable_render_annotations()) { + if (!annotation.has_thickness()) { + annotation.set_thickness(thickness); + } + } + kRenderDataOut(cc).Send(render_data); + return absl::OkStatus(); + } +}; + +// NOLINTBEGIN: Node registration doesn't work when part of calculator name is +// moved to next line. +// clang-format off +MEDIAPIPE_REGISTER_NODE( + ::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator); +// clang-format on +// NOLINTEND + +} // namespace internal namespace { @@ -58,6 +105,8 @@ constexpr absl::string_view kAlphaTag{"ALPHA"}; constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"}; constexpr absl::string_view kNormRectTag{"NORM_RECT"}; constexpr absl::string_view kRoiTag{"ROI"}; +constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"}; +constexpr absl::string_view kRenderDataTag{"RENDER_DATA"}; // Updates the graph to return `roi` stream which has same dimension as // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is @@ -68,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<mediapipe::RenderData> roi, bool use_gpu, const absl::string_view image_tag_with_suffix = use_gpu ? kImageGpuTag : kImageCpuTag; + // Adds thickness to the render data so that the render data is scale + // invariant to the input image size. + auto& add_thickness = graph.AddNode( + "mediapipe::tasks::vision::interactive_segmenter::internal::" + "AddThicknessToRenderDataCalculator"); + image >> add_thickness.In(kImageTag); + roi >> add_thickness.In(kRenderDataTag); + auto roi_with_thickness = add_thickness.Out(kRenderDataTag); + // Generates a blank canvas with same size as input image. auto& flat_color = graph.AddNode("FlatColorImageCalculator"); auto& flat_color_options = flat_color.GetOptions<FlatColorImageCalculatorOptions>(); // SetAlphaCalculator only takes 1st channel.
flat_color_options.mutable_color()->set_r(0); - image >> flat_color.In(kImageTag)[0]; - auto blank_canvas = flat_color.Out(kImageTag)[0]; + image >> flat_color.In(kImageTag); + auto blank_canvas = flat_color.Out(kImageTag); auto& from_mp_image = graph.AddNode("FromImageCalculator"); blank_canvas >> from_mp_image.In(kImageTag); @@ -84,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<mediapipe::RenderData> roi, bool use_gpu, auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator"); blank_canvas_in_cpu_or_gpu >> roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag); - roi >> roi_to_alpha.In(0); + roi_with_thickness >> roi_to_alpha.In(0); auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); return alpha; @@ -162,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph { image >> from_mp_image.In(kImageTag); auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix); + // Creates an RGBA image with model input tensor size. auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph); auto& set_alpha = graph.AddNode("SetAlphaCalculator"); @@ -200,6 +259,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph { graph[Output<Image>(kCategoryMaskTag)]; } } + image_segmenter.Out(kQualityScoresTag) >> + graph[Output<std::vector<float>>::Optional(kQualityScoresTag)]; image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)]; return graph.GetConfig(); diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc index c761678d0..2bb06428e 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,9 +18,12 @@ limitations under the License. #include <cstdint> #include <memory> #include <string> +#include <variant> +#include <vector> #include "absl/flags/flag.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" @@ -31,6 +34,7 @@ limitations under the License. #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/test_util.h" #include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" @@ -67,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"}; // Golden mask for the dogs in cats_and_dogs.jpg.
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"}; constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"}; +constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"}; +constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"}; +constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"}; +constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"}; constexpr float kGoldenMaskSimilarity = 0.97; @@ -179,22 +187,47 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) { struct InteractiveSegmenterTestParams { std::string test_name; RegionOfInterest::Format format; - NormalizedKeypoint roi; + std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi; + absl::string_view input_image_file; absl::string_view golden_mask_file; float similarity_threshold; }; -using SucceedSegmentationWithRoi = - ::testing::TestWithParam<InteractiveSegmenterTestParams>; +class SucceedSegmentationWithRoi + : public ::testing::TestWithParam<InteractiveSegmenterTestParams> { + public: + absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() { + const InteractiveSegmenterTestParams& params = GetParam(); + + RegionOfInterest interaction_roi; + interaction_roi.format = params.format; + switch (params.format) { + case (RegionOfInterest::Format::kKeyPoint): { + interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi); + break; + } + case (RegionOfInterest::Format::kScribble): { + interaction_roi.scribble = + std::get<std::vector<NormalizedKeypoint>>(params.roi); + break; + } + default: { + return absl::InvalidArgumentError("Unknown ROI format"); + } + } + + return interaction_roi; + } +}; TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { + MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi, + TestParamsToTaskOptions()); const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); - RegionOfInterest interaction_roi; - interaction_roi.format = params.format; - interaction_roi.keypoint = params.roi; + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + params.input_image_file))); auto options = std::make_unique<InteractiveSegmenterOptions>(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -217,16 +250,25 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { EXPECT_THAT(actual_mask, SimilarToUint8Mask(expected_mask, params.similarity_threshold, kGoldenMaskMagnificationFactor)); + + cv::Mat visualized_mask; + actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255); + ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8, + visualized_mask.cols, visualized_mask.rows, + visualized_mask.step, visualized_mask.data, + [visualized_mask](uint8_t[]) {}); + MP_EXPECT_OK(SavePngTestOutput( + visualized_image, absl::StrFormat("%s_category_mask", params.test_name))); } TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { - const auto& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi, + TestParamsToTaskOptions()); + const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( - Image image, - DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); - RegionOfInterest interaction_roi; - interaction_roi.format = params.format; - interaction_roi.keypoint = params.roi; + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + params.input_image_file))); auto options = std::make_unique<InteractiveSegmenterOptions>(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -248,16 +290,44 @@
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { result.confidence_masks->at(1).GetImageFrameSharedPtr().get()); EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float, params.similarity_threshold)); + cv::Mat visualized_mask; + actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255); + ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8, + visualized_mask.cols, visualized_mask.rows, + visualized_mask.step, visualized_mask.data, + [visualized_mask](uint8_t[]) {}); + MP_EXPECT_OK(SavePngTestOutput( + visualized_image, + absl::StrFormat("%s_confidence_mask", params.test_name))); } INSTANTIATE_TEST_SUITE_P( SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, ::testing::ValuesIn( - {{"PointToDog1", RegionOfInterest::Format::kKeyPoint, - NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, + {// Keypoint input. + {"PointToDog1", RegionOfInterest::Format::kKeyPoint, + NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1, + 0.84f}, {"PointToDog2", RegionOfInterest::Format::kKeyPoint, - NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, - kGoldenMaskSimilarity}}), + NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2, + kGoldenMaskSimilarity}, + {"PenguinsSmall", RegionOfInterest::Format::kKeyPoint, + NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask, + 0.9f}, + {"PenguinsLarge", RegionOfInterest::Format::kKeyPoint, + NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask, + 0.9f}, + // Scribble input. + {"ScribbleToDog1", RegionOfInterest::Format::kScribble, + std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.44, 0.70}, + NormalizedKeypoint{0.44, 0.71}, + NormalizedKeypoint{0.44, 0.72}}, + kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f}, + {"ScribbleToDog2", RegionOfInterest::Format::kScribble, + std::vector<NormalizedKeypoint>{NormalizedKeypoint{0.66, 0.66}, + NormalizedKeypoint{0.66, 0.67}, + NormalizedKeypoint{0.66, 0.68}}, + kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}), [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>& info) { return info.param.test_name; }); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 0238449c7..40d7ab50b 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
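Note on AddThicknessToRenderDataCalculator (added earlier in this diff): the thickness is derived from the ratio between the input image and the 512x512 model input tensor, which is what lets the PenguinsSmall/PenguinsLarge cases above pass with the same normalized ROI. The formula, worked through (a restatement of the calculator's Process body, not additional behavior):

#include <algorithm>

// thickness = max(max(width / 512.0, height / 512.0), 1.0)
// 2048x1536 input -> max(max(4.0, 3.0), 1.0) = 4.0
//  256x256  input -> max(max(0.5, 0.5), 1.0) = 1.0 (clamped, never below 1)
double ScribbleThickness(int width, int height) {
  return std::max(std::max(width / 512.0, height / 512.0), 1.0);
}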
@@ -54,12 +54,7 @@ cc_library( name = "object_detector_graph", srcs = ["object_detector_graph.cc"], deps = [ - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:detection_transformation_calculator", "//mediapipe/calculators/util:detections_deduplicate_calculator", @@ -71,19 +66,15 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:detection_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index fe0651e1e..152ee3273 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
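Note: the trimmed dependency list above mirrors the graph rewrite later in this diff: score calibration, tensor-to-detection conversion, and label mapping all move behind the single DetectionPostprocessingGraph subgraph. Condensed, the new wiring is (abridged from object_detector_graph.cc below, not additional code):

auto& postprocessing = graph.AddNode(
    "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
components::processors::proto::DetectorOptions detector_options;
detector_options.set_max_results(task_options.max_results());
detector_options.set_score_threshold(task_options.score_threshold());
MP_RETURN_IF_ERROR(
    components::processors::ConfigureDetectionPostprocessingGraph(
        model_resources, detector_options,
        postprocessing.GetOptions<components::processors::proto::
                                      DetectionPostprocessingGraphOptions>()));
model_output_tensors >> postprocessing.In("TENSORS");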
@@ -129,9 +137,17 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; - Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + if (detections_packet.IsEmpty()) { + Packet empty_packet = + status_or_packets.value()[kDetectionsOutStreamName]; + result_callback( + {ConvertToDetectionResult({})}, image_packet.Get<Image>(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } result_callback(ConvertToDetectionResult( detections_packet.Get<std::vector<Detection>>()), image_packet.Get<Image>(), @@ -165,6 +173,9 @@ absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket<Image>(std::move(image))}, {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); + if (output_packets[kDetectionsOutStreamName].IsEmpty()) { + return {ConvertToDetectionResult({})}; + } return ConvertToDetectionResult( output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>()); } @@ -190,6 +201,9 @@ absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo( {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kDetectionsOutStreamName].IsEmpty()) { + return {ConvertToDetectionResult({})}; + } return ConvertToDetectionResult( output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>()); } diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 6113ed105..de2c0dbaf 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -99,7 +112,20 @@ struct ObjectDetectorOptions { // - only RGB inputs are supported (`channels` is required to be 3). // - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. -// Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e: +// The model must have either 2 or 4 output tensors. +// The 2 output tensors must represent locations and scores, respectively. +// (kTfLiteFloat32) +// - locations tensor of size `[num_results x num_coords]`. The num_coords is +// the number of coordinates a location result represents. Usually in the +// form: [4 + 2 * keypoint_num], where 4 location values encode the bounding +// box (y_center, x_center, height, width) and the additional keypoints are in +// (y, x) order. +// (kTfLiteFloat32) +// - scores tensor of size `[num_results x num_classes]`. The values of a +// result represent the classification probability of the class at that +// index, as denoted in the label file of the corresponding tensor +// metadata in the model file. +// The 4 output tensors must come from a `DetectionPostProcess` op, i.e.: // (kTfLiteFloat32) // - locations tensor of size `[num_results x 4]`, the inner array // representing bounding boxes in the form [top, left, right, bottom].
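Note on the documentation block above: for the 2-tensor (no in-graph NMS) models, each row of the locations tensor packs a center-encoded box followed by optional keypoints. A small sketch of decoding one such row under the documented layout (the struct and helper names are illustrative, not part of the API):

// Decodes one `num_coords` row laid out as
// [y_center, x_center, height, width, kp0_y, kp0_x, ...] into corner form.
struct Box {
  float ymin, xmin, ymax, xmax;
};
Box DecodeLocationRow(const float* row) {
  const float yc = row[0], xc = row[1], h = row[2], w = row[3];
  return {yc - h / 2.f, xc - w / 2.f, yc + h / 2.f, xc + w / 2.f};
}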
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index e0eb32fdd..e2b374970 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <limits> -#include <optional> -#include <utility> #include <vector> #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "mediapipe/calculators/core/split_vector_calculator.pb.h" -#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" -#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator.pb.h" @@ -31,19 +25,15 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" -#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" +#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" -#include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_cc" -#include "mediapipe/util/label_map.pb.h" -#include "mediapipe/util/label_map_util.h" namespace mediapipe { namespace tasks { @@ -56,42 +46,18 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::metadata::ModelMetadataExtractor; -using ::tflite::BoundingBoxProperties; -using ::tflite::ContentProperties; -using ::tflite::ContentProperties_BoundingBoxProperties; -using ::tflite::EnumNameContentProperties; -using ::tflite::ProcessUnit; -using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions; -using ::tflite::TensorMetadata; -using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>; -constexpr int kDefaultLocationsIndex = 0; -constexpr int kDefaultCategoriesIndex = 1; -constexpr int kDefaultScoresIndex = 2; -constexpr int kDefaultNumResultsIndex = 3; - -constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); - -constexpr char 
kLocationTensorName[] = "location"; -constexpr char kCategoryTensorName[] = "category"; -constexpr char kScoreTensorName[] = "score"; -constexpr char kNumberOfDetectionsTensorName[] = "number of detections"; - -constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageTag[] = "IMAGE"; -constexpr char kIndicesTag[] = "INDICES"; constexpr char kMatrixTag[] = "MATRIX"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; -constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorTag[] = "TENSORS"; // Struct holding the different output streams produced by the object detection @@ -101,34 +67,6 @@ struct ObjectDetectionOutputStreams { Source<Image> image; }; -// Parameters used for configuring the post-processing calculators. -struct PostProcessingSpecs { - // The maximum number of detection results to return. - int max_results; - // Indices of the output tensors to match the output tensors to the correct - // index order of the output tensors: [location, categories, scores, - // num_detections]. - std::vector<int> output_tensor_indices; - // For each pack of 4 coordinates returned by the model, this denotes the - // order in which to get the left, top, right and bottom coordinates. - std::vector<int> bounding_box_corners_order; - // This is populated by reading the label files from the TFLite Model - // Metadata: if no such files are available, this is left empty and the - // ObjectDetector will only be able to populate the `index` field of the - // detection results. - LabelItems label_items; - // Score threshold. Detections with a confidence below this value are - // discarded. If none is provided via metadata or options, -FLT_MAX is set as - // default value. - float score_threshold; - // Set of category indices to be allowed/denied. - absl::flat_hash_set<int> allow_or_deny_categories; - // Indicates `allow_or_deny_categories` is an allowlist or a denylist. - bool is_allowlist; - // Score calibration options, if any. - std::optional<ScoreCalibrationCalculatorOptions> score_calibration_options; -}; - absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) { if (options.max_results() == 0) { return CreateStatusWithPayload( @@ -147,310 +85,6 @@ absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) { return absl::OkStatus(); } -absl::StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties( - const TensorMetadata& tensor_metadata) { - if (tensor_metadata.content() == nullptr || - tensor_metadata.content()->content_properties() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Expected BoundingBoxProperties for tensor %s, found none.", - tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - ContentProperties type = tensor_metadata.content()->content_properties_type(); - if (type != ContentProperties_BoundingBoxProperties) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Expected BoundingBoxProperties for tensor %s, found %s.", - tensor_metadata.name() ? 
tensor_metadata.name()->str() : "#0", - EnumNameContentProperties(type)), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - const BoundingBoxProperties* properties = - tensor_metadata.content()->content_properties_as_BoundingBoxProperties(); - - // Mobile SSD only supports "BOUNDARIES" bounding box type. - if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s", - tflite::EnumNameBoundingBoxType(properties->type())), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - // Mobile SSD only supports "RATIO" coordinates type. - if (properties->coordinate_type() != tflite::CoordinateType_RATIO) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Mobile SSD only supports CoordinateType RATIO, found %s", - tflite::EnumNameCoordinateType(properties->coordinate_type())), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - // Index is optional, but must contain 4 values if present. - if (properties->index() != nullptr && properties->index()->size() != 4) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Expected BoundingBoxProperties index to contain 4 values, found " - "%d", - properties->index()->size()), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - return properties; -} - -absl::StatusOr<LabelItems> GetLabelItemsIfAny( - const ModelMetadataExtractor& metadata_extractor, - const TensorMetadata& tensor_metadata, absl::string_view locale) { - const std::string labels_filename = - ModelMetadataExtractor::FindFirstAssociatedFileName( - tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS); - if (labels_filename.empty()) { - LabelItems empty_label_items; - return empty_label_items; - } - ASSIGN_OR_RETURN(absl::string_view labels_file, - metadata_extractor.GetAssociatedFile(labels_filename)); - const std::string display_names_filename = - ModelMetadataExtractor::FindFirstAssociatedFileName( - tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS, - locale); - absl::string_view display_names_file; - if (!display_names_filename.empty()) { - ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( - display_names_filename)); - } - return mediapipe::BuildLabelMapFromFiles(labels_file, display_names_file); -} - -absl::StatusOr<float> GetScoreThreshold( - const ModelMetadataExtractor& metadata_extractor, - const TensorMetadata& tensor_metadata) { - ASSIGN_OR_RETURN( - const ProcessUnit* score_thresholding_process_unit, - metadata_extractor.FindFirstProcessUnit( - tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); - if (score_thresholding_process_unit == nullptr) { - return kDefaultScoreThreshold; - } - return score_thresholding_process_unit->options_as_ScoreThresholdingOptions() - ->global_score_threshold(); -} - -absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny( - const ObjectDetectorOptionsProto& config, const LabelItems& label_items) { - absl::flat_hash_set<int> category_indices; - // Exit early if no denylist/allowlist. 
- if (config.category_denylist_size() == 0 && - config.category_allowlist_size() == 0) { - return category_indices; - } - if (label_items.empty()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Using `category_allowlist` or `category_denylist` requires " - "labels to be present in the TFLite Model Metadata but none was found.", - MediaPipeTasksStatus::kMetadataMissingLabelsError); - } - const auto& category_list = config.category_allowlist_size() > 0 - ? config.category_allowlist() - : config.category_denylist(); - for (const auto& category_name : category_list) { - int index = -1; - for (int i = 0; i < label_items.size(); ++i) { - if (label_items.at(i).name() == category_name) { - index = i; - break; - } - } - // Ignores duplicate or unknown categories. - if (index < 0) { - continue; - } - category_indices.insert(index); - } - return category_indices; -} - -absl::StatusOr<std::optional<ScoreCalibrationCalculatorOptions>> -GetScoreCalibrationOptionsIfAny( - const ModelMetadataExtractor& metadata_extractor, - const TensorMetadata& tensor_metadata) { - // Get ScoreCalibrationOptions, if any. - ASSIGN_OR_RETURN( - const ProcessUnit* score_calibration_process_unit, - metadata_extractor.FindFirstProcessUnit( - tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions)); - if (score_calibration_process_unit == nullptr) { - return std::nullopt; - } - auto* score_calibration_options = - score_calibration_process_unit->options_as_ScoreCalibrationOptions(); - // Get corresponding AssociatedFile. - auto score_calibration_filename = - metadata_extractor.FindFirstAssociatedFileName( - tensor_metadata, - tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); - if (score_calibration_filename.empty()) { - return CreateStatusWithPayload( - absl::StatusCode::kNotFound, - "Found ScoreCalibrationOptions but missing required associated " - "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", - MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError); - } - ASSIGN_OR_RETURN( - absl::string_view score_calibration_file, - metadata_extractor.GetAssociatedFile(score_calibration_filename)); - ScoreCalibrationCalculatorOptions score_calibration_calculator_options; - MP_RETURN_IF_ERROR(ConfigureScoreCalibration( - score_calibration_options->score_transformation(), - score_calibration_options->default_score(), score_calibration_file, - &score_calibration_calculator_options)); - return score_calibration_calculator_options; -} - -std::vector<int> GetOutputTensorIndices( - const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* - tensor_metadatas) { - std::vector<int> output_indices = { - core::FindTensorIndexByMetadataName(tensor_metadatas, - kLocationTensorName), - core::FindTensorIndexByMetadataName(tensor_metadatas, - kCategoryTensorName), - core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName), - core::FindTensorIndexByMetadataName(tensor_metadatas, - kNumberOfDetectionsTensorName)}; - // locations, categories, scores, and number of detections - for (int i = 0; i < 4; i++) { - int output_index = output_indices[i]; - // If tensor name is not found, set the default output indices. - if (output_index == -1) { - LOG(WARNING) << absl::StrFormat( - "You don't seem to be matching tensor names in metadata list. 
The " "tensor name \"%s\" at index %d in the model metadata doesn't " "match " "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].", - tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName, - kCategoryTensorName, kScoreTensorName, kNumberOfDetectionsTensorName); - output_indices = {kDefaultLocationsIndex, kDefaultCategoriesIndex, - kDefaultScoresIndex, kDefaultNumResultsIndex}; - return output_indices; - } - } - return output_indices; -} - -// Builds PostProcessingSpecs from ObjectDetectorOptionsProto and model metadata -// for configuring the post-processing calculators. -absl::StatusOr<PostProcessingSpecs> BuildPostProcessingSpecs( - const ObjectDetectorOptionsProto& options, - const ModelMetadataExtractor* metadata_extractor) { - // Checks output tensor metadata is present and consistent with model. - auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata(); - if (output_tensors_metadata == nullptr || - output_tensors_metadata->size() != 4) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat("Mismatch between number of output tensors (4) and " "output tensors metadata (%d).", - output_tensors_metadata == nullptr - ? 0 - : output_tensors_metadata->size()), - MediaPipeTasksStatus::kMetadataInconsistencyError); - } - PostProcessingSpecs specs; - specs.max_results = options.max_results(); - specs.output_tensor_indices = GetOutputTensorIndices(output_tensors_metadata); - // Extracts mandatory BoundingBoxProperties and performs sanity checks on the - // fly. - ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties, - GetBoundingBoxProperties(*output_tensors_metadata->Get( - specs.output_tensor_indices[0]))); - if (bounding_box_properties->index() == nullptr) { - specs.bounding_box_corners_order = {0, 1, 2, 3}; - } else { - auto bounding_box_index = bounding_box_properties->index(); - specs.bounding_box_corners_order = { - bounding_box_index->Get(0), - bounding_box_index->Get(1), - bounding_box_index->Get(2), - bounding_box_index->Get(3), - }; - } - // Builds label map (if available) from metadata. - ASSIGN_OR_RETURN(specs.label_items, - GetLabelItemsIfAny(*metadata_extractor, - *output_tensors_metadata->Get( - specs.output_tensor_indices[1]), - options.display_names_locale())); - // Obtains allow/deny categories. - specs.is_allowlist = !options.category_allowlist().empty(); - ASSIGN_OR_RETURN( - specs.allow_or_deny_categories, - GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items)); - // Sets score threshold. - if (options.has_score_threshold()) { - specs.score_threshold = options.score_threshold(); - } else { - ASSIGN_OR_RETURN(specs.score_threshold, - GetScoreThreshold(*metadata_extractor, - *output_tensors_metadata->Get( - specs.output_tensor_indices[2]))); - } - // Builds score calibration options (if available) from metadata. - ASSIGN_OR_RETURN( - specs.score_calibration_options, - GetScoreCalibrationOptionsIfAny( - *metadata_extractor, - *output_tensors_metadata->Get(specs.output_tensor_indices[2]))); - return specs; -} - -// Fills in the TensorsToDetectionsCalculatorOptions based on -// PostProcessingSpecs. 
-void ConfigureTensorsToDetectionsCalculator( - const PostProcessingSpecs& specs, - mediapipe::TensorsToDetectionsCalculatorOptions* options) { - options->set_num_classes(specs.label_items.size()); - options->set_num_coords(4); - options->set_min_score_thresh(specs.score_threshold); - if (specs.max_results != -1) { - options->set_max_results(specs.max_results); - } - if (specs.is_allowlist) { - options->mutable_allow_classes()->Assign( - specs.allow_or_deny_categories.begin(), - specs.allow_or_deny_categories.end()); - } else { - options->mutable_ignore_classes()->Assign( - specs.allow_or_deny_categories.begin(), - specs.allow_or_deny_categories.end()); - } - - const auto& output_indices = specs.output_tensor_indices; - // Assigns indices to each the model output tensor. - auto* tensor_mapping = options->mutable_tensor_mapping(); - tensor_mapping->set_detections_tensor_index(output_indices[0]); - tensor_mapping->set_classes_tensor_index(output_indices[1]); - tensor_mapping->set_scores_tensor_index(output_indices[2]); - tensor_mapping->set_num_detections_tensor_index(output_indices[3]); - - // Assigns the bounding box corner order. - auto box_boundaries_indices = options->mutable_box_boundaries_indices(); - box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]); - box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]); - box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]); - box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]); -} - } // namespace // A "mediapipe.tasks.vision.ObjectDetectorGraph" performs object detection. @@ -530,7 +164,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { const core::ModelResources& model_resources, Source<Image> image_in, Source<NormalizedRect> norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); - // Checks that the model has 4 outputs. auto& model = *model_resources.GetTfLiteModel(); if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( @@ -539,13 +172,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { model.subgraphs()->size()), MediaPipeTasksStatus::kInvalidArgumentError); } - if (model.subgraphs()->Get(0)->outputs()->size() != 4) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat("Expected a model with 4 output tensors, found %d.", - model.subgraphs()->Get(0)->outputs()->size()), - MediaPipeTasksStatus::kInvalidArgumentError); - } // Checks that metadata is available. auto* metadata_extractor = model_resources.GetMetadataExtractor(); if (metadata_extractor->GetModelMetadata() == nullptr || @@ -577,70 +203,36 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { auto& inference = AddInference( model_resources, task_options.base_options().acceleration(), graph); preprocessing.Out(kTensorTag) >> inference.In(kTensorTag); - - // Adds post processing calculators. - ASSIGN_OR_RETURN( - auto post_processing_specs, - BuildPostProcessingSpecs(task_options, metadata_extractor)); - // Calculators to perform score calibration, if specified in the metadata. - TensorsSource calibrated_tensors = + TensorsSource model_output_tensors = inference.Out(kTensorTag).Cast<std::vector<Tensor>>(); - if (post_processing_specs.score_calibration_options.has_value()) { - // Split tensors. 
- auto* split_tensor_vector_node = - &graph.AddNode("SplitTensorVectorCalculator"); - auto& split_tensor_vector_options = - split_tensor_vector_node - ->GetOptions<mediapipe::SplitVectorCalculatorOptions>(); - for (int i = 0; i < 4; ++i) { - auto* range = split_tensor_vector_options.add_ranges(); - range->set_begin(i); - range->set_end(i + 1); - } - calibrated_tensors >> split_tensor_vector_node->In(0); - // Add score calibration calculator. - auto* score_calibration_node = - &graph.AddNode("ScoreCalibrationCalculator"); - score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>() - .CopyFrom(*post_processing_specs.score_calibration_options); - split_tensor_vector_node->Out( - post_processing_specs.output_tensor_indices[1]) >> - score_calibration_node->In(kIndicesTag); - split_tensor_vector_node->Out( - post_processing_specs.output_tensor_indices[2]) >> - score_calibration_node->In(kScoresTag); - - // Re-concatenate tensors. - auto* concatenate_tensor_vector_node = - &graph.AddNode("ConcatenateTensorVectorCalculator"); - for (int i = 0; i < 4; ++i) { - if (i == post_processing_specs.output_tensor_indices[2]) { - score_calibration_node->Out(kCalibratedScoresTag) >> - concatenate_tensor_vector_node->In(i); - } else { - split_tensor_vector_node->Out(i) >> - concatenate_tensor_vector_node->In(i); - } - } - calibrated_tensors = - concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>(); - } - // Calculator to convert output tensors to a detection proto vector. - // Connects TensorsToDetectionsCalculator's input stream to the output - // tensors produced by the inference subgraph. - auto& tensors_to_detections = - graph.AddNode("TensorsToDetectionsCalculator"); - ConfigureTensorsToDetectionsCalculator( - post_processing_specs, - &tensors_to_detections - .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()); - calibrated_tensors >> tensors_to_detections.In(kTensorTag); + // Add Detection postprocessing graph to convert tensors to detections. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.DetectionPostprocessingGraph"); + components::processors::proto::DetectorOptions detector_options; + detector_options.set_max_results(task_options.max_results()); + detector_options.set_score_threshold(task_options.score_threshold()); + detector_options.set_display_names_locale( + task_options.display_names_locale()); + detector_options.mutable_category_allowlist()->CopyFrom( + task_options.category_allowlist()); + detector_options.mutable_category_denylist()->CopyFrom( + task_options.category_denylist()); + // TODO: expose min suppression threshold in + // ObjectDetectorOptions. + detector_options.set_min_suppression_threshold(0.3); + MP_RETURN_IF_ERROR( + components::processors::ConfigureDetectionPostprocessingGraph( + model_resources, detector_options, + postprocessing + .GetOptions<components::processors::proto::DetectionPostprocessingGraphOptions>())); + model_output_tensors >> postprocessing.In(kTensorTag); + auto detections = postprocessing.Out(kDetectionsTag); // Calculator to project detections back to the original coordinate system. auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); - tensors_to_detections.Out(kDetectionsTag) >> - detection_projection.In(kDetectionsTag); + detections >> detection_projection.In(kDetectionsTag); preprocessing.Out(kMatrixTag) >> detection_projection.In(kProjectionMatrixTag); @@ -652,22 +244,13 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { detection_transformation.In(kDetectionsTag); preprocessing.Out(kImageSizeTag) >> detection_transformation.In(kImageSizeTag); - - // Calculator to assign detection labels. 
- auto& detection_label_id_to_text = - graph.AddNode("DetectionLabelIdToTextCalculator"); - auto& detection_label_id_to_text_opts = - detection_label_id_to_text - .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>(); - *detection_label_id_to_text_opts.mutable_label_items() = - std::move(post_processing_specs.label_items); - detection_transformation.Out(kPixelDetectionsTag) >> - detection_label_id_to_text.In(""); + auto detections_in_pixel = + detection_transformation.Out(kPixelDetectionsTag); // Deduplicate Detections with same bounding box coordinates. auto& detections_deduplicate = graph.AddNode("DetectionsDeduplicateCalculator"); - detection_label_id_to_text.Out("") >> detections_deduplicate.In(""); + detections_in_pixel >> detections_deduplicate.In(""); // Outputs the labeled detections and the processed image as the subgraph // output streams. diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index c992cf67e..e66fc19bb 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -76,15 +76,18 @@ using ::testing::HasSubstr; using ::testing::Optional; using DetectionProto = mediapipe::Detection; -constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; -constexpr char kMobileSsdWithMetadata[] = - "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; -constexpr char kMobileSsdWithDummyScoreCalibration[] = +constexpr absl::string_view kTestDataDirectory{ + "/mediapipe/tasks/testdata/vision/"}; +constexpr absl::string_view kMobileSsdWithMetadata{ + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"}; +constexpr absl::string_view kMobileSsdWithDummyScoreCalibration{ "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration." - "tflite"; + "tflite"}; // The model has different output tensor order. -constexpr char kEfficientDetWithMetadata[] = - "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"; +constexpr absl::string_view kEfficientDetWithMetadata{ + "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"}; +constexpr absl::string_view kEfficientDetWithoutNms{ + "efficientdet_lite0_fp16_no_nms.tflite"}; // Checks that the two provided `Detection` proto vectors are equal, with a // tolerance on floating-point scores to account for numerical instabilities. 
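Note: the score tolerance mentioned in the comment above matters more on the new no-NMS path, where scores come out of dequantized tensors rather than a DetectionPostProcess op. A minimal sketch of that kind of comparison (the file's real helper is ExpectApproximatelyEqual; the body below is illustrative only):

// Illustrative: labels and boxes compared exactly, scores within a tolerance.
void ExpectDetectionNear(const mediapipe::Detection& actual,
                         const mediapipe::Detection& expected,
                         float score_tolerance) {
  EXPECT_EQ(actual.label(0), expected.label(0));
  EXPECT_NEAR(actual.score(0), expected.score(0), score_tolerance);
  EXPECT_EQ(actual.location_data().bounding_box().xmin(),
            expected.location_data().bounding_box().xmin());
}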
@@ -451,6 +454,67 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) { })pb")})); } +TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "cats_and_dogs.jpg"))); + auto options = std::make_unique<ObjectDetectorOptions>(); + options->max_results = 4; + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); + MP_ASSERT_OK(object_detector->Close()); + ExpectApproximatelyEqual( + results, + ConvertToDetectionResult( + {ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "dog" + score: 0.733542 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 636 ymin: 160 width: 282 height: 451 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.699751 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 870 ymin: 411 width: 208 height: 187 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "dog" + score: 0.682425 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 386 ymin: 216 width: 256 height: 376 } + })pb"), + ParseTextProtoOrDie<DetectionProto>(R"pb( + label: "cat" + score: 0.646585 + location_data { + format: BOUNDING_BOX + bounding_box { xmin: 83 ymin: 399 width: 347 height: 198 } + })pb")})); +} + +TEST_F(ImageModeTest, SucceedsNoObjectDetected) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "cats_and_dogs.jpg"))); + auto options = std::make_unique<ObjectDetectorOptions>(); + options->max_results = 4; + options->score_threshold = 1.0f; + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); + MP_ASSERT_OK(object_detector->Close()); + EXPECT_THAT(results.detections, testing::IsEmpty()); +} + TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( "./", kTestDataDirectory, diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/BUILD b/mediapipe/tasks/cc/vision/object_detector/proto/BUILD index edcaff52f..863d69df0 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index 3f6932f8f..471d8d527 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
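Note: the efficientdet_lite0_fp16_no_nms tests above exercise the path where non-maximum suppression runs inside DetectionPostprocessingGraph, governed by the min_suppression_threshold of 0.3 set earlier in this diff. For orientation, the greedy IoU-based suppression such a threshold controls looks roughly like this (a sketch, not MediaPipe's implementation; assumes a Box type and an IoU() helper):

#include <algorithm>
#include <numeric>
#include <vector>

// Keep boxes in descending score order, dropping any box whose IoU with an
// already-kept box exceeds `threshold` (e.g. 0.3).
std::vector<int> GreedyNms(const std::vector<Box>& boxes,
                           const std::vector<float>& scores, float threshold) {
  std::vector<int> order(boxes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return scores[a] > scores[b]; });
  std::vector<int> kept;
  for (int i : order) {
    bool suppressed = false;
    for (int j : kept) {
      if (IoU(boxes[i], boxes[j]) > threshold) {
        suppressed = true;
        break;
      }
    }
    if (!suppressed) kept.push_back(i);
  }
  return kept;
}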
diff --git a/mediapipe/tasks/cc/vision/pose_detector/BUILD b/mediapipe/tasks/cc/vision/pose_detector/BUILD index 1fc4b5ba2..4f9bc5944 100644 --- a/mediapipe/tasks/cc/vision/pose_detector/BUILD +++ b/mediapipe/tasks/cc/vision/pose_detector/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc index 32c125ce2..f1554f8df 100644 --- a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc index 800c05846..4d15583af 100644 --- a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -114,16 +115,18 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner( Detection GetExpectedPoseDetectionResult(absl::string_view file_name) { Detection detection; - CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), - &detection, Defaults())) + ABSL_CHECK_OK( + GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) << "Expected pose detection result does not exist."; return detection; } NormalizedRect GetExpectedExpandedPoseRect(absl::string_view file_name) { NormalizedRect expanded_rect; - CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), - &expanded_rect, Defaults())) + ABSL_CHECK_OK( + GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &expanded_rect, Defaults())) << "Expected expanded pose rect does not exist."; return expanded_rect; } diff --git a/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD b/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD index 287ed0183..53e7d5a55 100644 --- a/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD +++ b/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
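Note: the CHECK_OK -> ABSL_CHECK_OK change above is part of the migration to the absl-prefixed assertion macros from absl/log/absl_check.h; the semantics are unchanged (abort with the status appended if it is not OK). Usage as in the test:

#include "absl/log/absl_check.h"

// Dies with the streamed message if the golden proto cannot be loaded.
ABSL_CHECK_OK(GetTextProto(path, &detection, Defaults()))
    << "Expected pose detection result does not exist.";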
diff --git a/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto b/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto index 693f95262..531bcc7e9 100644 --- a/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index 19f546257..f9bdb5613 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -155,3 +155,13 @@ cc_library( "//mediapipe/tasks/cc/components/containers:landmark", ], ) + +cc_library( + name = "pose_landmarks_connections", + hdrs = ["pose_landmarks_connections.h"], +) + +cc_library( + name = "pose_landmark", + hdrs = ["pose_landmark.h"], +) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmark.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmark.h new file mode 100644 index 000000000..36c628145 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmark.h @@ -0,0 +1,68 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARK_H_ + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +static constexpr int kNumPoseLandmarks = 33; + +// BlazePose 33 landmark names. 
+enum class PoseLandmark { + kNose = 0, + kLeftEyeInner, + kLeftEye, + kLeftEyeOuter, + kRightEyeInner, + kRightEye, + kRightEyeOuter, + kLeftEar, + kRightEar, + kMouthLeft, + kMouthRight, + kLeftShoulder, + kRightShoulder, + kLeftElbow, + kRightElbow, + kLeftWrist, + kRightWrist, + kLeftPinky1, + kRightPinky1, + kLeftIndex1, + kRightIndex1, + kLeftThumb2, + kRightThumb2, + kLeftHip, + kRightHip, + kLeftKnee, + kRightKnee, + kLeftAnkle, + kRightAnkle, + kLeftHeel, + kRightHeel, + kLeftFootIndex, + kRightFootIndex, +}; + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARK_H_ diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc index 01c86c122..797e71488 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -63,8 +63,6 @@ constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kNormLandmarksStreamName[] = "norm_landmarks"; constexpr char kPoseWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kPoseWorldLandmarksStreamName[] = "world_landmarks"; -constexpr char kPoseAuxiliaryLandmarksTag[] = "AUXILIARY_LANDMARKS"; -constexpr char kPoseAuxiliaryLandmarksStreamName[] = "auxiliary_landmarks"; constexpr int kMicroSecondsPerMilliSecond = 1000; // Creates a MediaPipe graph config that contains a subgraph node of @@ -83,9 +81,6 @@ CalculatorGraphConfig CreateGraphConfig( graph.Out(kNormLandmarksTag); subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >> graph.Out(kPoseWorldLandmarksTag); - subgraph.Out(kPoseAuxiliaryLandmarksTag) - .SetName(kPoseAuxiliaryLandmarksStreamName) >> - graph.Out(kPoseAuxiliaryLandmarksTag); subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (output_segmentation_masks) { subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> @@ -163,8 +158,6 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create( status_or_packets.value()[kNormLandmarksStreamName]; Packet pose_world_landmarks_packet = status_or_packets.value()[kPoseWorldLandmarksStreamName]; - Packet pose_auxiliary_landmarks_packet = - status_or_packets.value()[kPoseAuxiliaryLandmarksStreamName]; std::optional<std::vector<Image>> segmentation_mask = std::nullopt; if (output_segmentation_masks) { segmentation_mask = segmentation_mask_packet.Get<std::vector<Image>>(); @@ -175,9 +168,7 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create( /* pose_landmarks= */ pose_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(), /* pose_world_landmarks= */ - pose_world_landmarks_packet.Get<std::vector<LandmarkList>>(), - pose_auxiliary_landmarks_packet - .Get<std::vector<NormalizedLandmarkList>>()), + pose_world_landmarks_packet.Get<std::vector<LandmarkList>>()), image_packet.Get<Image>(), pose_landmarks_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond; @@ -234,10 +225,7 @@ absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::Detect( .Get<std::vector<NormalizedLandmarkList>>(), /* pose_world_landmarks */ output_packets[kPoseWorldLandmarksStreamName] - .Get<std::vector<LandmarkList>>(), - /*pose_auxiliary_landmarks= */ - output_packets[kPoseAuxiliaryLandmarksStreamName] - .Get<std::vector<NormalizedLandmarkList>>()); + .Get<std::vector<LandmarkList>>()); } absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo( @@ -277,10 +265,7 @@ absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo( .Get<std::vector<NormalizedLandmarkList>>(), /* pose_world_landmarks */
output_packets[kPoseWorldLandmarksStreamName] - .Get<std::vector<LandmarkList>>(), - /* pose_auxiliary_landmarks= */ - output_packets[kPoseAuxiliaryLandmarksStreamName] - .Get<std::vector<NormalizedLandmarkList>>()); + .Get<std::vector<LandmarkList>>()); } absl::Status PoseLandmarker::DetectAsync( diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h index 058ab0b1e..314356aa0 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index 456a6efd1..7889212e8 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -108,9 +108,18 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->mutable_model_asset(), is_copy); } - pose_detector_graph_options->mutable_base_options() - ->mutable_acceleration() - ->CopyFrom(options->base_options().acceleration()); + if (options->base_options().acceleration().has_gpu()) { + core::proto::Acceleration gpu_accel; + gpu_accel.mutable_gpu()->set_use_advanced_gpu_api(true); + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(gpu_accel); + + } else { + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + } pose_detector_graph_options->mutable_base_options()->set_use_stream_mode( options->base_options().use_stream_mode()); auto* pose_landmarks_detector_graph_options = diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc index 6a4cc93b6..a2206c336 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "absl/flags/flag.h" #include "absl/status/statusor.h" @@ -24,12 +25,15 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/tool/test_util.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" @@ -76,8 +80,11 @@ constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectName[] = "norm_rect"; constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; constexpr char kNormLandmarksName[] = "norm_landmarks"; +constexpr char kSegmentationMaskTag[] = "SEGMENTATION_MASK"; +constexpr char kSegmentationMaskName[] = "segmentation_mask"; constexpr float kLiteModelFractionDiff = 0.05; // percentage +constexpr float kGoldenMaskSimilarity = .98; template ProtoT GetExpectedProto(absl::string_view filename) { @@ -125,6 +132,9 @@ absl::StatusOr> CreatePoseLandmarkerGraphTaskRunner( pose_landmarker.Out(kNormLandmarksTag).SetName(kNormLandmarksName) >> graph[Output>(kNormLandmarksTag)]; + pose_landmarker.Out(kSegmentationMaskTag).SetName(kSegmentationMaskName) >> + graph[Output>(kSegmentationMaskTag)]; + return TaskRunner::Create( graph.GetConfig(), absl::make_unique()); @@ -145,6 +155,21 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width, class PoseLandmarkerGraphTest : public testing::TestWithParam {}; +// Convert pixels from float range [0,1] to uint8 range [0,255]. 
+ImageFrame CreateUint8ImageFrame(const Image& image) { + auto* image_frame_ptr = image.GetImageFrameSharedPtr().get(); + ImageFrame output_image_frame(ImageFormat::GRAY8, image_frame_ptr->Width(), + image_frame_ptr->Height(), 1); + float* pixelData = + reinterpret_cast<float*>(image_frame_ptr->MutablePixelData()); + uint8_t* uint8PixelData = output_image_frame.MutablePixelData(); + const int total_pixels = image_frame_ptr->Width() * image_frame_ptr->Height(); + for (int i = 0; i < total_pixels; ++i) { + uint8PixelData[i] = static_cast<uint8_t>(pixelData[i] * 255.0f); + } + return output_image_frame; +} + TEST_P(PoseLandmarkerGraphTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, @@ -167,6 +192,53 @@ TEST_P(PoseLandmarkerGraphTest, Succeeds) { GetParam().landmarks_diff_threshold), *GetParam().expected_landmarks_list)); } + + const std::vector<Image>& segmentation_masks = + (*output_packets)[kSegmentationMaskName].Get<std::vector<Image>>(); + + EXPECT_EQ(segmentation_masks.size(), 1); + + const Image& segmentation_mask = segmentation_masks[0]; + const ImageFrame segmentation_mask_image_frame = + CreateUint8ImageFrame(segmentation_mask); + + auto expected_image_frame = LoadTestPng( + JoinPath("./", kTestDataDirectory, "pose_segmentation_mask_golden.png"), + ImageFormat::GRAY8); + + ASSERT_EQ(segmentation_mask_image_frame.Width(), + expected_image_frame->Width()); + ASSERT_EQ(segmentation_mask_image_frame.Height(), + expected_image_frame->Height()); + ASSERT_EQ(segmentation_mask_image_frame.Format(), + expected_image_frame->Format()); + ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(), + expected_image_frame->NumberOfChannels()); + ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(), + expected_image_frame->ByteDepth()); + ASSERT_EQ(segmentation_mask_image_frame.NumberOfChannels(), 1); + ASSERT_EQ(segmentation_mask_image_frame.ByteDepth(), 1); + int consistent_pixels = 0; + int num_pixels = segmentation_mask_image_frame.Width() * + segmentation_mask_image_frame.Height(); + for (int i = 0; i < segmentation_mask_image_frame.Height(); ++i) { + for (int j = 0; j < segmentation_mask_image_frame.Width(); ++j) { + consistent_pixels += + (segmentation_mask_image_frame + .PixelData()[segmentation_mask_image_frame.WidthStep() * i + + j] == + expected_image_frame + ->PixelData()[expected_image_frame->WidthStep() * i + j]); + } + } + + EXPECT_GE(static_cast<float>(consistent_pixels) / num_pixels, + kGoldenMaskSimilarity); + + // For visual comparison of segmentation mask output. + MP_ASSERT_OK_AND_ASSIGN(auto output_path, + SavePngTestOutput(segmentation_mask_image_frame, + "segmentation_mask_output")); } INSTANTIATE_TEST_SUITE_P( diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc index 6222bbd68..da4c630b3 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
@@ -27,15 +27,12 @@ namespace pose_landmarker { PoseLandmarkerResult ConvertToPoseLandmarkerResult( std::optional<std::vector<Image>> segmentation_masks, const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto, - const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto, - const std::vector<mediapipe::NormalizedLandmarkList>& - pose_auxiliary_landmarks_proto) { + const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto) { PoseLandmarkerResult result; result.segmentation_masks = segmentation_masks; result.pose_landmarks.resize(pose_landmarks_proto.size()); result.pose_world_landmarks.resize(pose_world_landmarks_proto.size()); - result.pose_auxiliary_landmarks.resize(pose_auxiliary_landmarks_proto.size()); std::transform(pose_landmarks_proto.begin(), pose_landmarks_proto.end(), result.pose_landmarks.begin(), components::containers::ConvertToNormalizedLandmarks); @@ -43,10 +40,6 @@ PoseLandmarkerResult ConvertToPoseLandmarkerResult( pose_world_landmarks_proto.end(), result.pose_world_landmarks.begin(), components::containers::ConvertToLandmarks); - std::transform(pose_auxiliary_landmarks_proto.begin(), - pose_auxiliary_landmarks_proto.end(), - result.pose_auxiliary_landmarks.begin(), - components::containers::ConvertToNormalizedLandmarks); return result; } diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h index 07adb87f5..27314b6c6 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_RESULT_H_ #include @@ -37,21 +37,16 @@ struct PoseLandmarkerResult { std::vector<components::containers::NormalizedLandmarks> pose_landmarks; // Detected pose landmarks in world coordinates. std::vector<components::containers::Landmarks> pose_world_landmarks; - // Detected auxiliary landmarks, used for deriving ROI for next frame. - std::vector<components::containers::NormalizedLandmarks> - pose_auxiliary_landmarks; }; PoseLandmarkerResult ConvertToPoseLandmarkerResult( std::optional<std::vector<Image>> segmentation_mask, const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto, - const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto, - const std::vector<mediapipe::NormalizedLandmarkList>& - pose_auxiliary_landmarks_proto); + const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto); } // namespace pose_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKER_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc index 10e0d61a3..05e83b655 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -47,13 +47,6 @@ TEST(ConvertFromProto, Succeeds) { landmark_proto.set_y(5.2); landmark_proto.set_z(4.3); - mediapipe::NormalizedLandmarkList auxiliary_landmark_list_proto; - mediapipe::NormalizedLandmark& auxiliary_landmark_proto = - *auxiliary_landmark_list_proto.add_landmark(); - auxiliary_landmark_proto.set_x(0.5); - auxiliary_landmark_proto.set_y(0.5); - auxiliary_landmark_proto.set_z(0.5); - std::vector<Image> segmentation_masks_lists = {segmentation_mask}; std::vector<mediapipe::NormalizedLandmarkList> normalized_landmarks_lists = { @@ -62,12 +55,9 @@ TEST(ConvertFromProto, Succeeds) { std::vector<mediapipe::LandmarkList> world_landmarks_lists = { world_landmark_list_proto}; - std::vector<mediapipe::NormalizedLandmarkList> auxiliary_landmarks_lists = { - auxiliary_landmark_list_proto}; - PoseLandmarkerResult pose_landmarker_result = ConvertToPoseLandmarkerResult( segmentation_masks_lists, normalized_landmarks_lists, - world_landmarks_lists, auxiliary_landmarks_lists); + world_landmarks_lists); EXPECT_EQ(pose_landmarker_result.pose_landmarks.size(), 1); EXPECT_EQ(pose_landmarker_result.pose_landmarks[0].landmarks.size(), 1); @@ -82,14 +72,6 @@ TEST(ConvertFromProto, Succeeds) { testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2), testing::FloatEq(4.3), std::nullopt, std::nullopt, std::nullopt)); - - EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks.size(), 1); - EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks.size(), - 1); - EXPECT_THAT(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks[0], - testing::FieldsAre(testing::FloatEq(0.5), testing::FloatEq(0.5), - testing::FloatEq(0.5), std::nullopt, - std::nullopt, std::nullopt)); } } // namespace pose_landmarker diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc index 062d0746d..239851b5f 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/flags/flag.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -38,7 +39,7 @@
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" #include "util/tuple/dump_vars.h" namespace mediapipe { @@ -105,17 +106,17 @@ MATCHER_P2(LandmarksMatches, expected_landmarks, toleration, "") { for (int i = 0; i < arg.size(); i++) { for (int j = 0; j < arg[i].landmarks.size(); j++) { if (arg[i].landmarks.size() != expected_landmarks[i].landmarks.size()) { - LOG(INFO) << "sizes not equal"; + ABSL_LOG(INFO) << "sizes not equal"; return false; } if (std::abs(arg[i].landmarks[j].x - expected_landmarks[i].landmarks[j].x) > toleration || std::abs(arg[i].landmarks[j].y - expected_landmarks[i].landmarks[j].y) > toleration) { - LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].x, - expected_landmarks[i].landmarks[j].x); - LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].y, - expected_landmarks[i].landmarks[j].y); + ABSL_LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].x, + expected_landmarks[i].landmarks[j].x); + ABSL_LOG(INFO) << DUMP_VARS(arg[i].landmarks[j].y, + expected_landmarks[i].landmarks[j].y); return false; } } @@ -316,7 +317,7 @@ TEST_P(VideoModeTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN(pose_landmarker_results, pose_landmarker->DetectForVideo(image, i)); } - LOG(INFO) << i; + ABSL_LOG(INFO) << i; ExpectPoseLandmarkerResultsCorrect( pose_landmarker_results, expected_results, kLandmarksOnVideoAbsMargin); } diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h new file mode 100644 index 000000000..4b79215a4 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKS_CONNECTIONS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKS_CONNECTIONS_H_ + +#include <array> + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +static constexpr std::array<std::array<int, 2>, 34> kPoseLandmarksConnections{{ + {1, 2}, {0, 1}, {2, 3}, {3, 7}, {0, 4}, {4, 5}, {5, 6}, + {6, 8}, {9, 10}, {11, 12}, {11, 13}, {13, 15}, {15, 17}, {15, 19}, + {15, 21}, {17, 19}, {12, 14}, {14, 16}, {16, 18}, {16, 20}, {16, 22}, + {18, 20}, {11, 23}, {12, 24}, {23, 24}, {23, 25}, {24, 26}, {25, 27}, + {26, 28}, {27, 29}, {28, 30}, {29, 31}, {30, 32}, {27, 31}, +}}; + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKS_CONNECTIONS_H_ diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc index f8488db02..e8397192b 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc index d5108decf..08842b211 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD index a2ad7b0b1..869a1ea60 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/proto/BUILD @@ -1,4 +1,4 @@ -# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# Copyright 2023 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto index bde314bad..d11395669 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto index 3b88491bb..9eb835d6a 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto +++ b/mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2023 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/BUILD b/mediapipe/tasks/cc/vision/utils/BUILD index 7e5a4dc8c..bb84cf3f1 100644 --- a/mediapipe/tasks/cc/vision/utils/BUILD +++ b/mediapipe/tasks/cc/vision/utils/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ cc_library_with_tflite( "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -60,6 +61,7 @@ cc_test_with_tflite( "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -110,3 +112,23 @@ cc_test( "//mediapipe/tasks/cc/components/containers:rect", ], ) + +cc_library( + name = "data_renderer", + srcs = ["data_renderer.cc"], + hdrs = ["data_renderer.h"], + deps = [ + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator_cc_proto", + "//mediapipe/calculators/util:rect_to_render_data_calculator_cc_proto", + "//mediapipe/calculators/util:rect_to_render_scale_calculator", + "//mediapipe/calculators/util:rect_to_render_scale_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/types:span", + ], +) diff --git a/mediapipe/tasks/cc/vision/utils/data_renderer.cc b/mediapipe/tasks/cc/vision/utils/data_renderer.cc new file mode 100644 index 000000000..aeefbba2f --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/data_renderer.cc @@ -0,0 +1,88 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" + +#include <optional> +#include <utility> +#include <vector> + +#include "absl/types/span.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h" +#include "mediapipe/calculators/util/rect_to_render_scale_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe::tasks::vision::utils { + +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; + +Stream<Image> Render(Stream<Image> image, + absl::Span<Stream<RenderData>> render_data_list, + Graph& graph) { + auto& annotation_overlay = graph.AddNode("AnnotationOverlayCalculator"); + image >> annotation_overlay.In("UIMAGE"); + for (int i = 0; i < render_data_list.size(); ++i) { + render_data_list[i] >> annotation_overlay.In(i); + } + return annotation_overlay.Out("UIMAGE").Cast<Image>(); +} + +Stream<RenderData> RenderLandmarks( + Stream<NormalizedLandmarkList> landmarks, + std::optional<Stream<float>> render_scale, + const mediapipe::LandmarksToRenderDataCalculatorOptions& renderer_options, + Graph& graph) { + auto& landmarks_render = graph.AddNode("LandmarksToRenderDataCalculator"); + landmarks_render + .GetOptions<mediapipe::LandmarksToRenderDataCalculatorOptions>() + .CopyFrom(renderer_options); + landmarks >> landmarks_render.In("NORM_LANDMARKS"); + if (render_scale.has_value()) { + *render_scale >> landmarks_render.In("RENDER_SCALE"); + } + auto render_data = landmarks_render.Out("RENDER_DATA"); + return render_data.Cast<RenderData>(); +} + +Stream<float> GetRenderScale(Stream<std::pair<int, int>> image_size, + Stream<NormalizedRect> roi, float multiplier, + Graph& graph) { + auto& to_render_scale = graph.AddNode("RectToRenderScaleCalculator"); + to_render_scale.GetOptions<mediapipe::RectToRenderScaleCalculatorOptions>() + .set_multiplier(multiplier); + roi >> to_render_scale.In("NORM_RECT"); + image_size >> to_render_scale.In("IMAGE_SIZE"); + return to_render_scale.Out("RENDER_SCALE").Cast<float>(); +} + +Stream<RenderData> RenderRect( + Stream<NormalizedRect> rect, + const mediapipe::RectToRenderDataCalculatorOptions& renderer_options, + Graph& graph) { + auto& rect_render = graph.AddNode("RectToRenderDataCalculator"); + rect_render.GetOptions<mediapipe::RectToRenderDataCalculatorOptions>() + .CopyFrom(renderer_options); + rect >> rect_render.In("NORM_RECT"); + auto render_data = rect_render.Out("RENDER_DATA"); + return render_data.Cast<RenderData>(); +} + +} // namespace mediapipe::tasks::vision::utils diff --git a/mediapipe/tasks/cc/vision/utils/data_renderer.h b/mediapipe/tasks/cc/vision/utils/data_renderer.h new file mode 100644 index 000000000..f58f94ee8 --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/data_renderer.h @@ -0,0 +1,69 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_DATA_RENDERER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_UTILS_DATA_RENDERER_H_ + +#include <optional> +#include <utility> + +#include "absl/types/span.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe::tasks::vision::utils { + +// Adds a node to the provided graph that renders the render_data_list on the +// given image, and returns the rendered image. +api2::builder::Stream<Image> Render( + api2::builder::Stream<Image> image, + absl::Span<api2::builder::Stream<RenderData>> render_data_list, + api2::builder::Graph& graph); + +// Adds a node to the provided graph that infers the render scale from the image +// size and the object RoI. It will give you bigger rendered primitives for +// bigger/closer objects and smaller primitives for smaller/farther objects. The +// primitives scale is proportional to `roi_size * multiplier`. +// +// See more details in +// mediapipe/calculators/util/rect_to_render_scale_calculator.cc +api2::builder::Stream<float> GetRenderScale( + api2::builder::Stream<std::pair<int, int>> image_size, + api2::builder::Stream<NormalizedRect> roi, float multiplier, + api2::builder::Graph& graph); + +// Adds a node to the provided graph that gets the landmarks render data +// according to the renderer_options. +api2::builder::Stream<RenderData> RenderLandmarks( + api2::builder::Stream<NormalizedLandmarkList> landmarks, + std::optional<api2::builder::Stream<float>> render_scale, + const mediapipe::LandmarksToRenderDataCalculatorOptions& renderer_options, + api2::builder::Graph& graph); + +// Adds a node to the provided graph that gets the rect render data according to +// the renderer_options. +api2::builder::Stream<RenderData> RenderRect( + api2::builder::Stream<NormalizedRect> rect, + const mediapipe::RectToRenderDataCalculatorOptions& renderer_options, + api2::builder::Graph& graph); + +} // namespace mediapipe::tasks::vision::utils + +#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_DATA_RENDERER_H_ diff --git a/mediapipe/tasks/cc/vision/utils/data_renderer_test.cc b/mediapipe/tasks/cc/vision/utils/data_renderer_test.cc new file mode 100644 index 000000000..b42c335b2 --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/data_renderer_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" + +#include <utility> +#include <vector> + +#include "absl/types/span.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe::tasks::vision::utils { +namespace { + +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::EqualsProto; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; + +TEST(DataRenderer, Render) { + Graph graph; + Stream<Image> image_in = graph.In("IMAGE").Cast<Image>(); + Stream<RenderData> render_data_in = + graph.In("RENDER_DATA").Cast<RenderData>(); + std::vector<Stream<RenderData>> render_data_list = {render_data_in}; + Stream<Image> image_out = + Render(image_in, absl::Span<Stream<RenderData>>(render_data_list), graph); + image_out.SetName("image_out"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + node { + calculator: "AnnotationOverlayCalculator" + input_stream: "__stream_1" + input_stream: "UIMAGE:__stream_0" + output_stream: "UIMAGE:image_out" + } + input_stream: "IMAGE:__stream_0" + input_stream: "RENDER_DATA:__stream_1" + )pb"))); +} + +TEST(DataRenderer, RenderLandmarks) { + Graph graph; + Stream<NormalizedLandmarkList> rect = + graph.In("NORM_LANDMARKS").Cast<NormalizedLandmarkList>(); + Stream<RenderData> render_data = + RenderLandmarks(rect, std::nullopt, {}, graph); + render_data.SetName("render_data"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + node { + calculator: "LandmarksToRenderDataCalculator" + input_stream: "NORM_LANDMARKS:__stream_0" + output_stream: "RENDER_DATA:render_data" + options { + [mediapipe.LandmarksToRenderDataCalculatorOptions.ext] {} + } + } + input_stream: "NORM_LANDMARKS:__stream_0" + )pb"))); +} + +TEST(DataRenderer, GetRenderScale) { + Graph graph; + Stream<std::pair<int, int>> image_size = + graph.In("IMAGE_SIZE").Cast<std::pair<int, int>>(); + Stream<NormalizedRect> roi = graph.In("ROI").Cast<NormalizedRect>(); + Stream<float> render_scale = GetRenderScale(image_size, roi, 0.0001, graph); + render_scale.SetName("render_scale"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + node { + calculator: "RectToRenderScaleCalculator" + input_stream: "IMAGE_SIZE:__stream_0" + input_stream: "NORM_RECT:__stream_1" + output_stream: "RENDER_SCALE:render_scale" + options { + [mediapipe.RectToRenderScaleCalculatorOptions.ext] { + multiplier: 0.0001 + } + } + } + input_stream: "IMAGE_SIZE:__stream_0" + input_stream: "ROI:__stream_1" + )pb"))); +} + +TEST(DataRenderer, RenderRect) { + Graph graph; + Stream<NormalizedRect> rect = graph.In("NORM_RECT").Cast<NormalizedRect>(); + Stream<RenderData> render_data = RenderRect(rect, {}, graph); + render_data.SetName("render_data"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + node { + calculator: "RectToRenderDataCalculator" + input_stream: "NORM_RECT:__stream_0" + output_stream: "RENDER_DATA:render_data" + options { + [mediapipe.RectToRenderDataCalculatorOptions.ext] {} + } + } + input_stream: "NORM_RECT:__stream_0" + )pb"))); +} + +} // namespace +} // namespace mediapipe::tasks::vision::utils diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc
b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc index 3f0425a69..690cd6e5c 100644 --- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc +++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -241,11 +242,12 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs( absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs( const core::ModelResources& model_resources) { const tflite::Model& model = *model_resources.GetTfLiteModel(); + // TODO: Investigate whether there are better solutions to support + // running inference with multiple subgraphs. if (model.subgraphs()->size() != 1) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Image tflite models are assumed to have a single subgraph.", - MediaPipeTasksStatus::kInvalidArgumentError); + ABSL_LOG(WARNING) + << "TFLite model has more than one subgraph. Using subgraph 0 as " + "the primary subgraph for inference."; } const auto* primary_subgraph = (*model.subgraphs())[0]; if (primary_subgraph->inputs()->size() != 1) { diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.h b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.h index f1ad2a807..d60b8a1da 100644 --- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.h +++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc b/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc index 8c7b7d595..a10d1281c 100644 --- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc +++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/flags/flag.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" @@ -179,7 +180,7 @@ TEST_F(ImageTensorSpecsTest, BuildInputImageTensorSpecsFromModelResources) { core::ModelResources::Create(kTestModelResourcesTag, std::move(model_file))); const tflite::Model* model = model_resources->GetTfLiteModel(); - CHECK(model != nullptr); + ABSL_CHECK(model != nullptr); absl::StatusOr<ImageTensorSpecs> input_specs_or = BuildInputImageTensorSpecs(*model_resources); MP_ASSERT_OK(input_specs_or); diff --git a/mediapipe/tasks/cc/vision/utils/image_utils.cc b/mediapipe/tasks/cc/vision/utils/image_utils.cc index 4dc169f12..596d4da25 100644 --- a/mediapipe/tasks/cc/vision/utils/image_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/image_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/image_utils.h b/mediapipe/tasks/cc/vision/utils/image_utils.h index 23833de20..9d67125df 100644 --- a/mediapipe/tasks/cc/vision/utils/image_utils.h +++ b/mediapipe/tasks/cc/vision/utils/image_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h b/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h index e1632e6f0..ac5627e62 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc index fe4e63824..6db046156 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index 4d1fac62f..59bc52773 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc index c30a5225b..55aef554d 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. +/* Copyright 2022 The MediaPipe Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml deleted file mode 100644 index 5c53dc269..000000000 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml +++ /dev/null @@ -1,37 +0,0 @@ diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD deleted file mode 100644 index 89c1edcb3..000000000 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -licenses(["notice"]) - -package(default_visibility = ["//visibility:private"]) - -android_binary( - name = "objectdetector", - srcs = glob(["**/*.java"]), - assets = [ - "//mediapipe/tasks/testdata/vision:test_models", - ], - assets_dir = "", - custom_package = "com.google.mediapipe.tasks.examples.objectdetector", - manifest = "AndroidManifest.xml", - manifest_values = { - "applicationId": "com.google.mediapipe.tasks.examples.objectdetector", - }, - multidex = "native", - resource_files = ["//mediapipe/tasks/examples/android:resource_files"], - deps = [ - "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:objectdetector", - "//third_party:androidx_appcompat", - "//third_party:androidx_constraint_layout", - "//third_party:opencv", - "@maven//:androidx_activity_activity", - "@maven//:androidx_concurrent_concurrent_futures", - "@maven//:androidx_exifinterface_exifinterface", - "@maven//:androidx_fragment_fragment", - "@maven//:com_google_guava_guava", - ], -) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java deleted file mode 100644 index 18c010a00..000000000 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package com.google.mediapipe.tasks.examples.objectdetector; - -import android.content.Intent; -import android.graphics.Bitmap; -import android.media.MediaMetadataRetriever; -import android.os.Bundle; -import android.provider.MediaStore; -import androidx.appcompat.app.AppCompatActivity; -import android.util.Log; -import android.view.View; -import android.widget.Button; -import android.widget.FrameLayout; -import androidx.activity.result.ActivityResultLauncher; -import androidx.activity.result.contract.ActivityResultContracts; -import androidx.exifinterface.media.ExifInterface; -// ContentResolver dependency -import com.google.mediapipe.framework.MediaPipeException; -import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.MPImage; -import com.google.mediapipe.tasks.core.BaseOptions; -import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; -import com.google.mediapipe.tasks.vision.core.RunningMode; -import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; -import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; -import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; -import java.io.IOException; -import java.io.InputStream; - -/** Main activity of MediaPipe Task Object Detector reference app. */ -public class MainActivity extends AppCompatActivity { - private static final String TAG = "MainActivity"; - private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; - - private ObjectDetector objectDetector; - - private enum InputSource { - UNKNOWN, - IMAGE, - VIDEO, - CAMERA, - } - - private InputSource inputSource = InputSource.UNKNOWN; - - // Image mode demo component. - private ActivityResultLauncher<Intent> imageGetter; - // Video mode demo component. - private ActivityResultLauncher<Intent> videoGetter; - private ObjectDetectionResultImageView imageView; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_main); - setupImageModeDemo(); - setupVideoModeDemo(); - // TODO: Adds live camera demo. - } - - /** Sets up the image mode demo. */ - private void setupImageModeDemo() { - imageView = new ObjectDetectionResultImageView(this); - // The Intent to access gallery and read images as bitmap.
- imageGetter = - registerForActivityResult( - new ActivityResultContracts.StartActivityForResult(), - result -> { - Intent resultIntent = result.getData(); - if (resultIntent != null) { - if (result.getResultCode() == RESULT_OK) { - Bitmap bitmap = null; - int rotation = 0; - try { - bitmap = - downscaleBitmap( - MediaStore.Images.Media.getBitmap( - this.getContentResolver(), resultIntent.getData())); - } catch (IOException e) { - Log.e(TAG, "Bitmap reading error:" + e); - } - try { - InputStream imageData = - this.getContentResolver().openInputStream(resultIntent.getData()); - rotation = getImageRotation(imageData); - } catch (IOException | MediaPipeException e) { - Log.e(TAG, "Bitmap rotation error:" + e); - } - if (bitmap != null) { - MPImage image = new BitmapImageBuilder(bitmap).build(); - ObjectDetectionResult detectionResult = - objectDetector.detect( - image, - ImageProcessingOptions.builder().setRotationDegrees(rotation).build()); - imageView.setData(image, detectionResult); - runOnUiThread(() -> imageView.update()); - } - } - } - }); - Button loadImageButton = findViewById(R.id.button_load_picture); - loadImageButton.setOnClickListener( - v -> { - if (inputSource != InputSource.IMAGE) { - createObjectDetector(RunningMode.IMAGE); - this.inputSource = InputSource.IMAGE; - updateLayout(); - } - // Reads images from gallery. - Intent pickImageIntent = new Intent(Intent.ACTION_PICK); - pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); - imageGetter.launch(pickImageIntent); - }); - } - - /** Sets up the video mode demo. */ - private void setupVideoModeDemo() { - imageView = new ObjectDetectionResultImageView(this); - // The Intent to access gallery and read a video file. - videoGetter = - registerForActivityResult( - new ActivityResultContracts.StartActivityForResult(), - result -> { - Intent resultIntent = result.getData(); - if (resultIntent != null) { - if (result.getResultCode() == RESULT_OK) { - MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever(); - metaRetriever.setDataSource(this, resultIntent.getData()); - long duration = - Long.parseLong( - metaRetriever.extractMetadata( - MediaMetadataRetriever.METADATA_KEY_DURATION)); - int numFrames = - Integer.parseInt( - metaRetriever.extractMetadata( - MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT)); - long frameIntervalMs = duration / numFrames; - for (int i = 0; i < numFrames; ++i) { - MPImage image = - new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); - ObjectDetectionResult detectionResult = - objectDetector.detectForVideo(image, frameIntervalMs * i); - // Currently only annotates the detection result on the first video frame and - // display it to verify the correctness. - // TODO: Annotates the detection result on every frame, save the - // annotated frames as a video file, and play back the video afterwards. - if (i == 0) { - imageView.setData(image, detectionResult); - runOnUiThread(() -> imageView.update()); - } - } - } - } - }); - Button loadVideoButton = findViewById(R.id.button_load_video); - loadVideoButton.setOnClickListener( - v -> { - createObjectDetector(RunningMode.VIDEO); - updateLayout(); - this.inputSource = InputSource.VIDEO; - - // Reads a video from gallery. 
- Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); - pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); - videoGetter.launch(pickVideoIntent); - }); - } - - private void createObjectDetector(RunningMode mode) { - if (objectDetector != null) { - objectDetector.close(); - } - // Initializes a new MediaPipe ObjectDetector instance - ObjectDetectorOptions options = - ObjectDetectorOptions.builder() - .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) - .setScoreThreshold(0.5f) - .setMaxResults(5) - .setRunningMode(mode) - .build(); - objectDetector = ObjectDetector.createFromOptions(this, options); - } - - private void updateLayout() { - // Updates the preview layout. - FrameLayout frameLayout = findViewById(R.id.preview_display_layout); - frameLayout.removeAllViewsInLayout(); - imageView.setImageDrawable(null); - frameLayout.addView(imageView); - imageView.setVisibility(View.VISIBLE); - } - - private Bitmap downscaleBitmap(Bitmap originalBitmap) { - double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight(); - int width = imageView.getWidth(); - int height = imageView.getHeight(); - if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) { - width = (int) (height * aspectRatio); - } else { - height = (int) (width / aspectRatio); - } - return Bitmap.createScaledBitmap(originalBitmap, width, height, false); - } - - private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException { - int orientation = - new ExifInterface(imageData) - .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); - switch (orientation) { - case ExifInterface.ORIENTATION_NORMAL: - return 0; - case ExifInterface.ORIENTATION_ROTATE_90: - return 90; - case ExifInterface.ORIENTATION_ROTATE_180: - return 180; - case ExifInterface.ORIENTATION_ROTATE_270: - return 270; - default: - // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip - // is supported. - throw new MediaPipeException( - MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(), - "Flipped images are not supported yet."); - } - } -} diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java deleted file mode 100644 index 283e48857..000000000 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package com.google.mediapipe.tasks.examples.objectdetector; - -import android.content.Context; -import android.graphics.Bitmap; -import android.graphics.Canvas; -import android.graphics.Color; -import android.graphics.Matrix; -import android.graphics.Paint; -import androidx.appcompat.widget.AppCompatImageView; -import com.google.mediapipe.framework.image.BitmapExtractor; -import com.google.mediapipe.framework.image.MPImage; -import com.google.mediapipe.tasks.components.containers.Detection; -import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; - -/** An ImageView implementation for displaying {@link ObjectDetectionResult}. */ -public class ObjectDetectionResultImageView extends AppCompatImageView { - private static final String TAG = "ObjectDetectionResultImageView"; - - private static final int BBOX_COLOR = Color.GREEN; - private static final int BBOX_THICKNESS = 5; // Pixels - private Bitmap latest; - - public ObjectDetectionResultImageView(Context context) { - super(context); - setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); - } - - /** - * Sets a {@link MPImage} and an {@link ObjectDetectionResult} to render. - * - * @param image a {@link MPImage} object for annotation. - * @param result an {@link ObjectDetectionResult} object that contains the detection result. - */ - public void setData(MPImage image, ObjectDetectionResult result) { - if (image == null || result == null) { - return; - } - latest = BitmapExtractor.extract(image); - Canvas canvas = new Canvas(latest); - canvas.drawBitmap(latest, new Matrix(), null); - for (int i = 0; i < result.detections().size(); ++i) { - drawDetectionOnCanvas(result.detections().get(i), canvas); - } - } - - /** Updates the image view with the latest {@link ObjectDetectionResult}. */ - public void update() { - postInvalidate(); - if (latest != null) { - setImageBitmap(latest); - } - } - - private void drawDetectionOnCanvas(Detection detection, Canvas canvas) { - // TODO: Draws the category and the score per bounding box. - // Draws bounding box. - Paint bboxPaint = new Paint(); - bboxPaint.setColor(BBOX_COLOR); - bboxPaint.setStyle(Paint.Style.STROKE); - bboxPaint.setStrokeWidth(BBOX_THICKNESS); - canvas.drawRect(detection.boundingBox(), bboxPaint); - } -} diff --git a/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml deleted file mode 100644 index c7bd21dbd..000000000 --- a/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml +++ /dev/null @@ -1,34 +0,0 @@ diff --git a/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml deleted file mode 100644 index 01f0af0ad..000000000 --- a/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml +++ /dev/null @@ -1,74 +0,0 @@ diff --git a/mediapipe/tasks/examples/android/res/layout/activity_main.xml b/mediapipe/tasks/examples/android/res/layout/activity_main.xml deleted file mode 100644 index 834e9a3e6..000000000 --- a/mediapipe/tasks/examples/android/res/layout/activity_main.xml +++ /dev/null @@ -1,40 +0,0 @@