Merge branch 'google:master' into face-stylizer-python-add-tests
This commit is contained in:
		
						commit
						27d8a1cb4f
					
				
							
								
								
									
										3
									
								
								.bazelrc
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								.bazelrc
									
									
									
									
									
								
							| 
						 | 
					@ -87,6 +87,9 @@ build:ios_fat --config=ios
 | 
				
			||||||
build:ios_fat --ios_multi_cpus=armv7,arm64
 | 
					build:ios_fat --ios_multi_cpus=armv7,arm64
 | 
				
			||||||
build:ios_fat --watchos_cpus=armv7k
 | 
					build:ios_fat --watchos_cpus=armv7k
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					build:ios_sim_fat --config=ios
 | 
				
			||||||
 | 
					build:ios_sim_fat --ios_multi_cpus=x86_64,sim_arm64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
build:darwin_x86_64 --apple_platform_type=macos
 | 
					build:darwin_x86_64 --apple_platform_type=macos
 | 
				
			||||||
build:darwin_x86_64 --macos_minimum_os=10.12
 | 
					build:darwin_x86_64 --macos_minimum_os=10.12
 | 
				
			||||||
build:darwin_x86_64 --cpu=darwin_x86_64
 | 
					build:darwin_x86_64 --cpu=darwin_x86_64
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										34
									
								
								.github/stale.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										34
									
								
								.github/stale.yml
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -1,34 +0,0 @@
 | 
				
			||||||
# Copyright 2021 The MediaPipe Authors.
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					 | 
				
			||||||
# You may obtain a copy of the License at
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
# Unless required by applicable law or agreed to in writing, software
 | 
					 | 
				
			||||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
					 | 
				
			||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
					 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					 | 
				
			||||||
# limitations under the License.
 | 
					 | 
				
			||||||
# ============================================================================
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
# This file was assembled from multiple pieces, whose use is documented
 | 
					 | 
				
			||||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
 | 
					 | 
				
			||||||
# for more information.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
 | 
					 | 
				
			||||||
daysUntilStale: 7
 | 
					 | 
				
			||||||
# Number of days of inactivity before a stale Issue or Pull Request is closed
 | 
					 | 
				
			||||||
daysUntilClose: 7
 | 
					 | 
				
			||||||
# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
 | 
					 | 
				
			||||||
onlyLabels:
 | 
					 | 
				
			||||||
 - stat:awaiting response
 | 
					 | 
				
			||||||
# Comment to post when marking as stale. Set to `false` to disable
 | 
					 | 
				
			||||||
markComment: >
 | 
					 | 
				
			||||||
  This issue has been automatically marked as stale because it has not had
 | 
					 | 
				
			||||||
  recent activity. It will be closed if no further activity occurs. Thank you.
 | 
					 | 
				
			||||||
# Comment to post when removing the stale label. Set to `false` to disable
 | 
					 | 
				
			||||||
unmarkComment: false
 | 
					 | 
				
			||||||
closeComment: >
 | 
					 | 
				
			||||||
  Closing as stale. Please reopen if you'd like to work on this further.
 | 
					 | 
				
			||||||
							
								
								
									
										66
									
								
								.github/workflows/stale.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								.github/workflows/stale.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,66 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This workflow alerts and then closes the stale issues/PRs after specific time
 | 
				
			||||||
 | 
					# You can adjust the behavior by modifying this file.
 | 
				
			||||||
 | 
					# For more information, see:
 | 
				
			||||||
 | 
					# https://github.com/actions/stale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					name: 'Close stale issues and PRs'
 | 
				
			||||||
 | 
					"on":
 | 
				
			||||||
 | 
					  schedule:
 | 
				
			||||||
 | 
					  - cron: "30 1 * * *"
 | 
				
			||||||
 | 
					permissions:
 | 
				
			||||||
 | 
					  contents: read
 | 
				
			||||||
 | 
					  issues: write
 | 
				
			||||||
 | 
					  pull-requests: write
 | 
				
			||||||
 | 
					jobs:
 | 
				
			||||||
 | 
					  stale:
 | 
				
			||||||
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
 | 
					    steps:
 | 
				
			||||||
 | 
					    - uses: 'actions/stale@v7'
 | 
				
			||||||
 | 
					      with:
 | 
				
			||||||
 | 
					        # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale.
 | 
				
			||||||
 | 
					        exempt-issue-labels: 'override-stale'
 | 
				
			||||||
 | 
					        # Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale.
 | 
				
			||||||
 | 
					        exempt-pr-labels: "override-stale"
 | 
				
			||||||
 | 
					        # Limit the No. of API calls in one run default value is 30.
 | 
				
			||||||
 | 
					        operations-per-run: 500
 | 
				
			||||||
 | 
					        # Prevent to remove stale label when PRs or issues are updated.
 | 
				
			||||||
 | 
					        remove-stale-when-updated: false
 | 
				
			||||||
 | 
					        # comment on issue if not active for more then 7 days.
 | 
				
			||||||
 | 
					        stale-issue-message: 'This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.'
 | 
				
			||||||
 | 
					        # comment on PR if not active for more then 14 days.
 | 
				
			||||||
 | 
					        stale-pr-message: 'This PR has been marked stale because it has no recent activity since 14 days. It will be closed if no further activity occurs. Thank you.'
 | 
				
			||||||
 | 
					        # comment on issue if stale for more then 7 days.
 | 
				
			||||||
 | 
					        close-issue-message: This issue was closed due to lack of activity after being marked stale for past 7 days.
 | 
				
			||||||
 | 
					        # comment on PR if stale for more then 14 days.
 | 
				
			||||||
 | 
					        close-pr-message: This PR was closed due to lack of activity after being marked stale for past 14 days.
 | 
				
			||||||
 | 
					        # Number of days of inactivity before an Issue Request becomes stale
 | 
				
			||||||
 | 
					        days-before-issue-stale: 7
 | 
				
			||||||
 | 
					        # Number of days of inactivity before a stale Issue is closed
 | 
				
			||||||
 | 
					        days-before-issue-close: 7
 | 
				
			||||||
 | 
					        # reason for closed the issue default value is not_planned
 | 
				
			||||||
 | 
					        close-issue-reason: completed
 | 
				
			||||||
 | 
					        # Number of days of inactivity before a stale PR is closed
 | 
				
			||||||
 | 
					        days-before-pr-close: 14
 | 
				
			||||||
 | 
					        # Number of days of inactivity before an PR Request becomes stale
 | 
				
			||||||
 | 
					        days-before-pr-stale: 14
 | 
				
			||||||
 | 
					        # Check for label to stale or close the issue/PR
 | 
				
			||||||
 | 
					        any-of-labels: 'stat:awaiting response'
 | 
				
			||||||
 | 
					        # override stale to stalled for PR
 | 
				
			||||||
 | 
					        stale-pr-label: 'stale'
 | 
				
			||||||
 | 
					        # override stale to stalled for Issue
 | 
				
			||||||
 | 
					        stale-issue-label: "stale"
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -433,9 +433,9 @@ absl::Status SpectrogramCalculator::ProcessVectorToOutput(
 | 
				
			||||||
absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
 | 
					absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
 | 
				
			||||||
                                                  CalculatorContext* cc) {
 | 
					                                                  CalculatorContext* cc) {
 | 
				
			||||||
  switch (output_type_) {
 | 
					  switch (output_type_) {
 | 
				
			||||||
    // These blocks deliberately ignore clang-format to preserve the
 | 
					      // These blocks deliberately ignore clang-format to preserve the
 | 
				
			||||||
    // "silhouette" of the different cases.
 | 
					      // "silhouette" of the different cases.
 | 
				
			||||||
    // clang-format off
 | 
					      // clang-format off
 | 
				
			||||||
    case SpectrogramCalculatorOptions::COMPLEX: {
 | 
					    case SpectrogramCalculatorOptions::COMPLEX: {
 | 
				
			||||||
      return ProcessVectorToOutput(
 | 
					      return ProcessVectorToOutput(
 | 
				
			||||||
          input_stream,
 | 
					          input_stream,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -400,6 +400,16 @@ cc_library(
 | 
				
			||||||
# compile your binary with the flag TENSORFLOW_PROTOS=lite.
 | 
					# compile your binary with the flag TENSORFLOW_PROTOS=lite.
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "tensorflow_inference_calculator_no_envelope_loader",
 | 
					    name = "tensorflow_inference_calculator_no_envelope_loader",
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":tensorflow_inference_calculator_for_boq",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    alwayslink = 1,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed
 | 
				
			||||||
 | 
					# Boq conformance test. Weigh your use case to see if this will work for you.
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "tensorflow_inference_calculator_for_boq",
 | 
				
			||||||
    srcs = ["tensorflow_inference_calculator.cc"],
 | 
					    srcs = ["tensorflow_inference_calculator.cc"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":tensorflow_inference_calculator_cc_proto",
 | 
					        ":tensorflow_inference_calculator_cc_proto",
 | 
				
			||||||
| 
						 | 
					@ -585,6 +595,24 @@ cc_library(
 | 
				
			||||||
# See yaqs/1092546221614039040
 | 
					# See yaqs/1092546221614039040
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "tensorflow_session_from_saved_model_generator_no_envelope_loader",
 | 
					    name = "tensorflow_session_from_saved_model_generator_no_envelope_loader",
 | 
				
			||||||
 | 
					    defines = select({
 | 
				
			||||||
 | 
					        "//mediapipe:android": ["__ANDROID__"],
 | 
				
			||||||
 | 
					        "//conditions:default": [],
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":tensorflow_session_from_saved_model_generator_for_boq",
 | 
				
			||||||
 | 
					    ] + select({
 | 
				
			||||||
 | 
					        "//conditions:default": [
 | 
				
			||||||
 | 
					            "//learning/brain/frameworks/uptc/public:uptc_session_no_envelope_loader",
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					    alwayslink = 1,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Same library as tensorflow_session_from_saved_model_generator without uptc_session,
 | 
				
			||||||
 | 
					# envelop_loader and remote_session dependencies.
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "tensorflow_session_from_saved_model_generator_for_boq",
 | 
				
			||||||
    srcs = ["tensorflow_session_from_saved_model_generator.cc"],
 | 
					    srcs = ["tensorflow_session_from_saved_model_generator.cc"],
 | 
				
			||||||
    defines = select({
 | 
					    defines = select({
 | 
				
			||||||
        "//mediapipe:android": ["__ANDROID__"],
 | 
					        "//mediapipe:android": ["__ANDROID__"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,9 @@ package mediapipe;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "mediapipe/framework/calculator.proto";
 | 
					import "mediapipe/framework/calculator.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					option java_package = "com.google.mediapipe.calculator.proto";
 | 
				
			||||||
 | 
					option java_outer_classname = "LogicCalculatorOptionsProto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message LogicCalculatorOptions {
 | 
					message LogicCalculatorOptions {
 | 
				
			||||||
  extend CalculatorOptions {
 | 
					  extend CalculatorOptions {
 | 
				
			||||||
    optional LogicCalculatorOptions ext = 338731246;
 | 
					    optional LogicCalculatorOptions ext = 338731246;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -223,6 +223,16 @@ class SourceImpl {
 | 
				
			||||||
    return !(*this == other);
 | 
					    return !(*this == other);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Src& SetName(const char* name) {
 | 
				
			||||||
 | 
					    base_->name_ = std::string(name);
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Src& SetName(absl::string_view name) {
 | 
				
			||||||
 | 
					    base_->name_ = std::string(name);
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Src& SetName(std::string name) {
 | 
					  Src& SetName(std::string name) {
 | 
				
			||||||
    base_->name_ = std::move(name);
 | 
					    base_->name_ = std::move(name);
 | 
				
			||||||
    return *this;
 | 
					    return *this;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -467,6 +467,11 @@ class SideFallbackT : public Base {
 | 
				
			||||||
// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
 | 
					// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
 | 
				
			||||||
// OutputStreamShard. Like that class, this class will not be usually named in
 | 
					// OutputStreamShard. Like that class, this class will not be usually named in
 | 
				
			||||||
// calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)).
 | 
					// calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)).
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// If not connected (!IsConnected()) SetNextTimestampBound is safe to call and
 | 
				
			||||||
 | 
					// does nothing.
 | 
				
			||||||
 | 
					// All the sub-classes that define Send should implement it to be safe to to
 | 
				
			||||||
 | 
					// call if not connected and do nothing in such case.
 | 
				
			||||||
class OutputShardAccessBase {
 | 
					class OutputShardAccessBase {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output)
 | 
					  OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										53
									
								
								mediapipe/framework/tool/ios.bzl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								mediapipe/framework/tool/ios.bzl
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,53 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""MediaPipe Task Library Helper Rules for iOS"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MPP_TASK_MINIMUM_OS_VERSION = "11.0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# When the static framework is built with bazel, the all header files are moved
 | 
				
			||||||
 | 
					# to the "Headers" directory with no header path prefixes. This auxiliary rule
 | 
				
			||||||
 | 
					# is used for stripping the path prefix to the C/iOS API header files included by
 | 
				
			||||||
 | 
					# other C/iOS API header files.
 | 
				
			||||||
 | 
					# In case of C header files, includes start with a keyword of "#include'.
 | 
				
			||||||
 | 
					# Imports in iOS header files start with a keyword of '#import'.
 | 
				
			||||||
 | 
					def strip_api_include_path_prefix(name, hdr_labels, prefix = ""):
 | 
				
			||||||
 | 
					    """Create modified header files with the import path stripped out.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      name: The name to be used as a prefix to the generated genrules.
 | 
				
			||||||
 | 
					      hdr_labels: List of header labels to strip out the include path. Each
 | 
				
			||||||
 | 
					          label must end with a colon followed by the header file name.
 | 
				
			||||||
 | 
					      prefix: Optional prefix path to prepend to the header inclusion path.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    for hdr_label in hdr_labels:
 | 
				
			||||||
 | 
					        hdr_filename = hdr_label.split(":")[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # The last path component of iOS header files is sources/some_file.h
 | 
				
			||||||
 | 
					        # Hence it wiill contain a '/'. So the string can be split at '/' to get
 | 
				
			||||||
 | 
					        # the header file name.
 | 
				
			||||||
 | 
					        if "/" in hdr_filename:
 | 
				
			||||||
 | 
					            hdr_filename = hdr_filename.split("/")[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hdr_basename = hdr_filename.split(".")[0]
 | 
				
			||||||
 | 
					        native.genrule(
 | 
				
			||||||
 | 
					            name = "{}_{}".format(name, hdr_basename),
 | 
				
			||||||
 | 
					            srcs = [hdr_label],
 | 
				
			||||||
 | 
					            outs = [hdr_filename],
 | 
				
			||||||
 | 
					            cmd = """
 | 
				
			||||||
 | 
					            sed 's|#\\([a-z]*\\) ".*/\\([^/]\\{{1,\\}}\\.h\\)"|#\\1 "{}\\2"|'\
 | 
				
			||||||
 | 
					            "$(location {})"\
 | 
				
			||||||
 | 
					            > "$@"
 | 
				
			||||||
 | 
					            """.format(prefix, hdr_label),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
| 
						 | 
					@ -60,7 +60,7 @@ std::string GetUnusedSidePacketName(
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  std::string candidate = input_side_packet_name_base;
 | 
					  std::string candidate = input_side_packet_name_base;
 | 
				
			||||||
  int iter = 2;
 | 
					  int iter = 2;
 | 
				
			||||||
  while (mediapipe::ContainsKey(input_side_packets, candidate)) {
 | 
					  while (input_side_packets.contains(candidate)) {
 | 
				
			||||||
    candidate = absl::StrCat(input_side_packet_name_base, "_",
 | 
					    candidate = absl::StrCat(input_side_packet_name_base, "_",
 | 
				
			||||||
                             absl::StrFormat("%02d", iter));
 | 
					                             absl::StrFormat("%02d", iter));
 | 
				
			||||||
    ++iter;
 | 
					    ++iter;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -101,7 +101,7 @@ void TestSuccessTagMap(const std::vector<std::string>& tag_index_names,
 | 
				
			||||||
  EXPECT_EQ(tags.size(), tag_map->Mapping().size())
 | 
					  EXPECT_EQ(tags.size(), tag_map->Mapping().size())
 | 
				
			||||||
      << "Parameters: in " << tag_map->DebugString();
 | 
					      << "Parameters: in " << tag_map->DebugString();
 | 
				
			||||||
  for (int i = 0; i < tags.size(); ++i) {
 | 
					  for (int i = 0; i < tags.size(); ++i) {
 | 
				
			||||||
    EXPECT_TRUE(mediapipe::ContainsKey(tag_map->Mapping(), tags[i]))
 | 
					    EXPECT_TRUE(tag_map->Mapping().contains(tags[i]))
 | 
				
			||||||
        << "Parameters: Trying to find \"" << tags[i] << "\" in\n"
 | 
					        << "Parameters: Trying to find \"" << tags[i] << "\" in\n"
 | 
				
			||||||
        << tag_map->DebugString();
 | 
					        << tag_map->DebugString();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,8 +14,10 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
package com.google.mediapipe.components;
 | 
					package com.google.mediapipe.components;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import javax.annotation.Nullable;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Lightweight abstraction for an object that can produce audio data. */
 | 
					/** Lightweight abstraction for an object that can produce audio data. */
 | 
				
			||||||
public interface AudioDataProducer {
 | 
					public interface AudioDataProducer {
 | 
				
			||||||
  /** Set the consumer that receives the audio data from this producer. */
 | 
					  /** Set the consumer that receives the audio data from this producer. */
 | 
				
			||||||
  void setAudioConsumer(AudioDataConsumer consumer);
 | 
					  void setAudioConsumer(@Nullable AudioDataConsumer consumer);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -71,7 +71,10 @@ android_library(
 | 
				
			||||||
        "AudioDataProducer.java",
 | 
					        "AudioDataProducer.java",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    visibility = ["//visibility:public"],
 | 
					    visibility = ["//visibility:public"],
 | 
				
			||||||
    deps = ["@maven//:com_google_guava_guava"],
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "@maven//:com_google_code_findbugs_jsr305",
 | 
				
			||||||
 | 
					        "@maven//:com_google_guava_guava",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# MicrophoneHelper that provides access to audio data from a microphone
 | 
					# MicrophoneHelper that provides access to audio data from a microphone
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -231,7 +231,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Returns the texture left, right, bottom, and top visible boundaries. */
 | 
					  /** Returns the texture left, right, bottom, and top visible boundaries. */
 | 
				
			||||||
  protected float[] calculateTextureBoundary() {
 | 
					  public float[] calculateTextureBoundary() {
 | 
				
			||||||
    // TODO: compute scale from surfaceTexture size.
 | 
					    // TODO: compute scale from surfaceTexture size.
 | 
				
			||||||
    float scaleWidth = frameWidth > 0 ? (float) surfaceWidth / (float) frameWidth : 1.0f;
 | 
					    float scaleWidth = frameWidth > 0 ? (float) surfaceWidth / (float) frameWidth : 1.0f;
 | 
				
			||||||
    float scaleHeight = frameHeight > 0 ? (float) surfaceHeight / (float) frameHeight : 1.0f;
 | 
					    float scaleHeight = frameHeight > 0 ? (float) surfaceHeight / (float) frameHeight : 1.0f;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					/* Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
you may not use this file except in compliance with the License.
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2019-2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
					# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
					# Licensed under the Apache License, Version 2.0 (the 'License');
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					@ -67,11 +67,18 @@ py_library(
 | 
				
			||||||
    name = "loss_functions",
 | 
					    name = "loss_functions",
 | 
				
			||||||
    srcs = ["loss_functions.py"],
 | 
					    srcs = ["loss_functions.py"],
 | 
				
			||||||
    srcs_version = "PY3",
 | 
					    srcs_version = "PY3",
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":file_util",
 | 
				
			||||||
 | 
					        ":model_util",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_test(
 | 
					py_test(
 | 
				
			||||||
    name = "loss_functions_test",
 | 
					    name = "loss_functions_test",
 | 
				
			||||||
    srcs = ["loss_functions_test.py"],
 | 
					    srcs = ["loss_functions_test.py"],
 | 
				
			||||||
 | 
					    tags = [
 | 
				
			||||||
 | 
					        "requires-net:external",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
    deps = [":loss_functions"],
 | 
					    deps = [":loss_functions"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					@ -13,10 +13,21 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
"""Loss function utility library."""
 | 
					"""Loss function utility library."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Optional, Sequence
 | 
					import abc
 | 
				
			||||||
 | 
					from typing import Mapping, Sequence
 | 
				
			||||||
 | 
					import dataclasses
 | 
				
			||||||
 | 
					from typing import Any, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import file_util
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
 | 
					from official.modeling import tf_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FocalLoss(tf.keras.losses.Loss):
 | 
					class FocalLoss(tf.keras.losses.Loss):
 | 
				
			||||||
  """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
 | 
					  """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
 | 
				
			||||||
| 
						 | 
					@ -45,7 +56,6 @@ class FocalLoss(tf.keras.losses.Loss):
 | 
				
			||||||
  ```python
 | 
					  ```python
 | 
				
			||||||
  model.compile(optimizer='sgd', loss=FocalLoss(gamma))
 | 
					  model.compile(optimizer='sgd', loss=FocalLoss(gamma))
 | 
				
			||||||
  ```
 | 
					  ```
 | 
				
			||||||
 | 
					 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
 | 
					  def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
 | 
				
			||||||
| 
						 | 
					@ -103,3 +113,252 @@ class FocalLoss(tf.keras.losses.Loss):
 | 
				
			||||||
    # By default, this function uses "sum_over_batch_size" reduction for the
 | 
					    # By default, this function uses "sum_over_batch_size" reduction for the
 | 
				
			||||||
    # loss per batch.
 | 
					    # loss per batch.
 | 
				
			||||||
    return tf.reduce_sum(losses) / batch_size
 | 
					    return tf.reduce_sum(losses) / batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class PerceptualLossWeight:
 | 
				
			||||||
 | 
					  """The weight for each perceptual loss.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Attributes:
 | 
				
			||||||
 | 
					    l1: weight for L1 loss.
 | 
				
			||||||
 | 
					    content: weight for content loss.
 | 
				
			||||||
 | 
					    style: weight for style loss.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  l1: float = 1.0
 | 
				
			||||||
 | 
					  content: float = 1.0
 | 
				
			||||||
 | 
					  style: float = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
 | 
				
			||||||
 | 
					  """Image perceptual quality loss.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  It obtains a weighted loss between the VGGPerceptualLoss and L1 loss.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      loss_weight: Optional[PerceptualLossWeight] = None,
 | 
				
			||||||
 | 
					      reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    """Initializes ImagePerceptualQualityLoss."""
 | 
				
			||||||
 | 
					    self._loss_weight = loss_weight
 | 
				
			||||||
 | 
					    self._losses = {}
 | 
				
			||||||
 | 
					    self._vgg_loss = VGGPerceptualLoss(self._loss_weight)
 | 
				
			||||||
 | 
					    self._reduction = reduction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _l1_loss(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
 | 
				
			||||||
 | 
					  ) -> Any:
 | 
				
			||||||
 | 
					    """Calculates L1 loss."""
 | 
				
			||||||
 | 
					    return tf.keras.losses.MeanAbsoluteError(reduction)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __call__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      img1: tf.Tensor,
 | 
				
			||||||
 | 
					      img2: tf.Tensor,
 | 
				
			||||||
 | 
					  ) -> tf.Tensor:
 | 
				
			||||||
 | 
					    """Computes image perceptual quality loss."""
 | 
				
			||||||
 | 
					    loss_value = []
 | 
				
			||||||
 | 
					    if self._loss_weight is None:
 | 
				
			||||||
 | 
					      self._loss_weight = PerceptualLossWeight()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self._loss_weight.content > 0 or self._loss_weight.style > 0:
 | 
				
			||||||
 | 
					      vgg_loss = self._vgg_loss(img1, img2)
 | 
				
			||||||
 | 
					      vgg_loss_value = tf.math.add_n(vgg_loss.values())
 | 
				
			||||||
 | 
					      loss_value.append(vgg_loss_value)
 | 
				
			||||||
 | 
					    if self._loss_weight.l1 > 0:
 | 
				
			||||||
 | 
					      l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2)
 | 
				
			||||||
 | 
					      l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1)
 | 
				
			||||||
 | 
					      loss_value.append(l1_loss_value)
 | 
				
			||||||
 | 
					    total_loss = tf.math.add_n(loss_value)
 | 
				
			||||||
 | 
					    return total_loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
 | 
				
			||||||
 | 
					  """Base class for perceptual loss model."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      feature_weight: Optional[Sequence[float]] = None,
 | 
				
			||||||
 | 
					      loss_weight: Optional[PerceptualLossWeight] = None,
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    """Instantiates perceptual loss.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      feature_weight: The weight coeffcients of multiple model extracted
 | 
				
			||||||
 | 
					        features used for calculating the perceptual loss.
 | 
				
			||||||
 | 
					      loss_weight: The weight coefficients between `style_loss` and
 | 
				
			||||||
 | 
					        `content_loss`.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    super().__init__()
 | 
				
			||||||
 | 
					    self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y))
 | 
				
			||||||
 | 
					    self._loss_style = tf.constant(0.0)
 | 
				
			||||||
 | 
					    self._loss_content = tf.constant(0.0)
 | 
				
			||||||
 | 
					    self._feature_weight = feature_weight
 | 
				
			||||||
 | 
					    self._loss_weight = loss_weight
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __call__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      img1: tf.Tensor,
 | 
				
			||||||
 | 
					      img2: tf.Tensor,
 | 
				
			||||||
 | 
					  ) -> Mapping[str, tf.Tensor]:
 | 
				
			||||||
 | 
					    """Computes perceptual loss between two images.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      img1: First batch of images. The pixel values should be normalized to [-1,
 | 
				
			||||||
 | 
					        1].
 | 
				
			||||||
 | 
					      img2: Second batch of images. The pixel values should be normalized to
 | 
				
			||||||
 | 
					        [-1, 1].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A mapping between loss name and loss tensors.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    x_features = self._compute_features(img1)
 | 
				
			||||||
 | 
					    y_features = self._compute_features(img2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self._loss_weight is None:
 | 
				
			||||||
 | 
					      self._loss_weight = PerceptualLossWeight()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # If the _feature_weight is not initialized, then initialize it as a list of
 | 
				
			||||||
 | 
					    # all the element equals to 1.0.
 | 
				
			||||||
 | 
					    if self._feature_weight is None:
 | 
				
			||||||
 | 
					      self._feature_weight = [1.0] * len(x_features)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # If the length of _feature_weight smallert than the length of the feature,
 | 
				
			||||||
 | 
					    # raise a ValueError. Otherwise, only use the first len(x_features) weight
 | 
				
			||||||
 | 
					    # for computing the loss.
 | 
				
			||||||
 | 
					    if len(self._feature_weight) < len(x_features):
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          f'Input feature weight length {len(self._feature_weight)} is smaller'
 | 
				
			||||||
 | 
					          f' than feature length {len(x_features)}'
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self._loss_weight.style > 0.0:
 | 
				
			||||||
 | 
					      self._loss_style = tf_utils.safe_mean(
 | 
				
			||||||
 | 
					          self._loss_weight.style
 | 
				
			||||||
 | 
					          * self._get_style_loss(x_feats=x_features, y_feats=y_features)
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    if self._loss_weight.content > 0.0:
 | 
				
			||||||
 | 
					      self._loss_content = tf_utils.safe_mean(
 | 
				
			||||||
 | 
					          self._loss_weight.content
 | 
				
			||||||
 | 
					          * self._get_content_loss(x_feats=x_features, y_feats=y_features)
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return {'style_loss': self._loss_style, 'content_loss': self._loss_content}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @abc.abstractmethod
 | 
				
			||||||
 | 
					  def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
 | 
				
			||||||
 | 
					    """Computes features from the given image tensor.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      img: Image tensor.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A list of multi-scale feature maps.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _get_content_loss(
 | 
				
			||||||
 | 
					      self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
 | 
				
			||||||
 | 
					  ) -> tf.Tensor:
 | 
				
			||||||
 | 
					    """Gets weighted multi-scale content loss.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      x_feats: Reconstructed face image.
 | 
				
			||||||
 | 
					      y_feats: Target face image.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A scalar tensor for the content loss.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    content_losses = []
 | 
				
			||||||
 | 
					    for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
 | 
				
			||||||
 | 
					      content_loss = self._loss_op(x_feat, y_feat) * coef
 | 
				
			||||||
 | 
					      content_losses.append(content_loss)
 | 
				
			||||||
 | 
					    return tf.math.reduce_sum(content_losses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _get_style_loss(
 | 
				
			||||||
 | 
					      self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
 | 
				
			||||||
 | 
					  ) -> tf.Tensor:
 | 
				
			||||||
 | 
					    """Gets weighted multi-scale style loss.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      x_feats: Reconstructed face image.
 | 
				
			||||||
 | 
					      y_feats: Target face image.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A scalar tensor for the style loss.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    style_losses = []
 | 
				
			||||||
 | 
					    i = 0
 | 
				
			||||||
 | 
					    for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
 | 
				
			||||||
 | 
					      x_feat_g = _compute_gram_matrix(x_feat)
 | 
				
			||||||
 | 
					      y_feat_g = _compute_gram_matrix(y_feat)
 | 
				
			||||||
 | 
					      style_loss = self._loss_op(x_feat_g, y_feat_g) * coef
 | 
				
			||||||
 | 
					      style_losses.append(style_loss)
 | 
				
			||||||
 | 
					      i = i + 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return tf.math.reduce_sum(style_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VGGPerceptualLoss(PerceptualLoss):
 | 
				
			||||||
 | 
					  """Perceptual loss based on VGG19 pretrained on the ImageNet dataset.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Reference:
 | 
				
			||||||
 | 
					  - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](
 | 
				
			||||||
 | 
					      https://arxiv.org/abs/1603.08155) (ECCV 2016)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Perceptual loss measures high-level perceptual and semantic differences
 | 
				
			||||||
 | 
					  between images.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      loss_weight: Optional[PerceptualLossWeight] = None,
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    """Initializes image quality loss essentials.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      loss_weight: Loss weight coefficients.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    super().__init__(
 | 
				
			||||||
 | 
					        feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]),
 | 
				
			||||||
 | 
					        loss_weight=loss_weight,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rgb_mean = tf.constant([0.485, 0.456, 0.406])
 | 
				
			||||||
 | 
					    rgb_std = tf.constant([0.229, 0.224, 0.225])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._rgb_mean = tf.reshape(rgb_mean, (1, 1, 1, 3))
 | 
				
			||||||
 | 
					    self._rgb_std = tf.reshape(rgb_std, (1, 1, 1, 3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_path = file_util.DownloadedFiles(
 | 
				
			||||||
 | 
					        'vgg_feature_extractor',
 | 
				
			||||||
 | 
					        _VGG_IMAGENET_PERCEPTUAL_MODEL_URL,
 | 
				
			||||||
 | 
					        is_folder=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self._vgg19 = model_util.load_keras_model(model_path.get_path())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
 | 
				
			||||||
 | 
					    """Computes VGG19 features."""
 | 
				
			||||||
 | 
					    img = (img + 1) / 2.0
 | 
				
			||||||
 | 
					    norm_img = (img - self._rgb_mean) / self._rgb_std
 | 
				
			||||||
 | 
					    # no grad, as it only serves as a frozen feature extractor.
 | 
				
			||||||
 | 
					    return self._vgg19(norm_img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor:
 | 
				
			||||||
 | 
					  """Computes gram matrix for the feature map.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Args:
 | 
				
			||||||
 | 
					    feature: [B, H, W, C] feature map.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Returns:
 | 
				
			||||||
 | 
					    [B, C, C] gram matrix.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  h, w, c = feature.shape[1:].as_list()
 | 
				
			||||||
 | 
					  feat_reshaped = tf.reshape(feature, shape=(-1, h * w, c))
 | 
				
			||||||
 | 
					  feat_gram = tf.matmul(
 | 
				
			||||||
 | 
					      tf.transpose(feat_reshaped, perm=[0, 2, 1]), feat_reshaped
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  return feat_gram / (c * h * w)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,9 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
from typing import Optional
 | 
					import tempfile
 | 
				
			||||||
 | 
					from typing import Dict, Optional, Sequence
 | 
				
			||||||
 | 
					from unittest import mock as unittest_mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from absl.testing import parameterized
 | 
					from absl.testing import parameterized
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
| 
						 | 
					@ -21,7 +23,7 @@ import tensorflow as tf
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import loss_functions
 | 
					from mediapipe.model_maker.python.core.utils import loss_functions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
 | 
					class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @parameterized.named_parameters(
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
      dict(testcase_name='no_sample_weight', sample_weight=None),
 | 
					      dict(testcase_name='no_sample_weight', sample_weight=None),
 | 
				
			||||||
| 
						 | 
					@ -99,5 +101,228 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
    self.assertNear(loss, expected_loss, 1e-4)
 | 
					    self.assertNear(loss, expected_loss, 1e-4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockPerceptualLoss(loss_functions.PerceptualLoss):
 | 
				
			||||||
 | 
					  """A mock class with implementation of abstract methods for testing."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      use_mock_loss_op: bool = False,
 | 
				
			||||||
 | 
					      feature_weight: Optional[Sequence[float]] = None,
 | 
				
			||||||
 | 
					      loss_weight: Optional[loss_functions.PerceptualLossWeight] = None,
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    super().__init__(feature_weight=feature_weight, loss_weight=loss_weight)
 | 
				
			||||||
 | 
					    if use_mock_loss_op:
 | 
				
			||||||
 | 
					      self._loss_op = lambda x, y: tf.math.reduce_mean(x - y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
 | 
				
			||||||
 | 
					    return [tf.random.normal(shape=(1, 8, 8, 3))] * 5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					    self._img1 = tf.fill(dims=(8, 8), value=0.2)
 | 
				
			||||||
 | 
					    self._img2 = tf.fill(dims=(8, 8), value=0.8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_invalid_feature_weight_raise_value_error(self):
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					        ValueError,
 | 
				
			||||||
 | 
					        'Input feature weight length 2 is smaller than feature length 5',
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					      MockPerceptualLoss(feature_weight=[1.0, 2.0])(
 | 
				
			||||||
 | 
					          img1=self._img1, img2=self._img2
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='default_loss_weight_and_loss_op',
 | 
				
			||||||
 | 
					          use_mock_loss_op=False,
 | 
				
			||||||
 | 
					          feature_weight=None,
 | 
				
			||||||
 | 
					          loss_weight=None,
 | 
				
			||||||
 | 
					          loss_values={
 | 
				
			||||||
 | 
					              'style_loss': 0.032839,
 | 
				
			||||||
 | 
					              'content_loss': 5.639870,
 | 
				
			||||||
 | 
					          },
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='style_loss_weight_is_0_default_loss_op',
 | 
				
			||||||
 | 
					          use_mock_loss_op=False,
 | 
				
			||||||
 | 
					          feature_weight=None,
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(style=0),
 | 
				
			||||||
 | 
					          loss_values={
 | 
				
			||||||
 | 
					              'style_loss': 0,
 | 
				
			||||||
 | 
					              'content_loss': 5.639870,
 | 
				
			||||||
 | 
					          },
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='content_loss_weight_is_0_default_loss_op',
 | 
				
			||||||
 | 
					          use_mock_loss_op=False,
 | 
				
			||||||
 | 
					          feature_weight=None,
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(content=0),
 | 
				
			||||||
 | 
					          loss_values={
 | 
				
			||||||
 | 
					              'style_loss': 0.032839,
 | 
				
			||||||
 | 
					              'content_loss': 0,
 | 
				
			||||||
 | 
					          },
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='customized_loss_weight_default_loss_op',
 | 
				
			||||||
 | 
					          use_mock_loss_op=False,
 | 
				
			||||||
 | 
					          feature_weight=None,
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              style=1.0, content=2.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_values={'style_loss': 0.032839, 'content_loss': 11.279739},
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name=(
 | 
				
			||||||
 | 
					              'customized_feature_weight_and_loss_weight_default_loss_op'
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          use_mock_loss_op=False,
 | 
				
			||||||
 | 
					          feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0],
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              style=1.0, content=2.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_values={'style_loss': 0.164193, 'content_loss': 33.839218},
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='no_loss_change_if_extra_feature_weight_provided',
 | 
				
			||||||
 | 
					          use_mock_loss_op=False,
 | 
				
			||||||
 | 
					          feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              style=1.0, content=2.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_values={
 | 
				
			||||||
 | 
					              'style_loss': 0.164193,
 | 
				
			||||||
 | 
					              'content_loss': 33.839218,
 | 
				
			||||||
 | 
					          },
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='customized_loss_weight_custom_loss_op',
 | 
				
			||||||
 | 
					          use_mock_loss_op=True,
 | 
				
			||||||
 | 
					          feature_weight=None,
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              style=1.0, content=2.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_values={'style_loss': 0.000395, 'content_loss': -1.533469},
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  def test_weighted_perceptul_loss(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      use_mock_loss_op: bool,
 | 
				
			||||||
 | 
					      feature_weight: Sequence[float],
 | 
				
			||||||
 | 
					      loss_weight: loss_functions.PerceptualLossWeight,
 | 
				
			||||||
 | 
					      loss_values: Dict[str, float],
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    perceptual_loss = MockPerceptualLoss(
 | 
				
			||||||
 | 
					        use_mock_loss_op=use_mock_loss_op,
 | 
				
			||||||
 | 
					        feature_weight=feature_weight,
 | 
				
			||||||
 | 
					        loss_weight=loss_weight,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    loss = perceptual_loss(img1=self._img1, img2=self._img2)
 | 
				
			||||||
 | 
					    self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
 | 
				
			||||||
 | 
					    self.assertNear(loss['style_loss'], loss_values['style_loss'], 1e-4)
 | 
				
			||||||
 | 
					    self.assertNear(loss['content_loss'], loss_values['content_loss'], 1e-4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					    # Mock tempfile.gettempdir() to be unique for each test to avoid race
 | 
				
			||||||
 | 
					    # condition when downloading model since these tests may run in parallel.
 | 
				
			||||||
 | 
					    mock_gettempdir = unittest_mock.patch.object(
 | 
				
			||||||
 | 
					        tempfile,
 | 
				
			||||||
 | 
					        'gettempdir',
 | 
				
			||||||
 | 
					        return_value=self.create_tempdir(),
 | 
				
			||||||
 | 
					        autospec=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.mock_gettempdir = mock_gettempdir.start()
 | 
				
			||||||
 | 
					    self.addCleanup(mock_gettempdir.stop)
 | 
				
			||||||
 | 
					    self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
 | 
				
			||||||
 | 
					    self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='default_loss_weight',
 | 
				
			||||||
 | 
					          loss_weight=None,
 | 
				
			||||||
 | 
					          loss_values={
 | 
				
			||||||
 | 
					              'style_loss': 5.8363257e-06,
 | 
				
			||||||
 | 
					              'content_loss': 1.7016045,
 | 
				
			||||||
 | 
					          },
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='customized_loss_weight',
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              style=10.0, content=20.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_values={
 | 
				
			||||||
 | 
					              'style_loss': 5.8363257e-05,
 | 
				
			||||||
 | 
					              'content_loss': 34.03208,
 | 
				
			||||||
 | 
					          },
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  def test_vgg_perceptual_loss(self, loss_weight, loss_values):
 | 
				
			||||||
 | 
					    vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight)
 | 
				
			||||||
 | 
					    loss = vgg_loss(img1=self._img1, img2=self._img2)
 | 
				
			||||||
 | 
					    self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
 | 
				
			||||||
 | 
					    self.assertNear(
 | 
				
			||||||
 | 
					        loss['style_loss'],
 | 
				
			||||||
 | 
					        loss_values['style_loss'],
 | 
				
			||||||
 | 
					        loss_values['style_loss'] / 1e5,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.assertNear(
 | 
				
			||||||
 | 
					        loss['content_loss'],
 | 
				
			||||||
 | 
					        loss_values['content_loss'],
 | 
				
			||||||
 | 
					        loss_values['content_loss'] / 1e5,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					    # Mock tempfile.gettempdir() to be unique for each test to avoid race
 | 
				
			||||||
 | 
					    # condition when downloading model since these tests may run in parallel.
 | 
				
			||||||
 | 
					    mock_gettempdir = unittest_mock.patch.object(
 | 
				
			||||||
 | 
					        tempfile,
 | 
				
			||||||
 | 
					        'gettempdir',
 | 
				
			||||||
 | 
					        return_value=self.create_tempdir(),
 | 
				
			||||||
 | 
					        autospec=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    self.mock_gettempdir = mock_gettempdir.start()
 | 
				
			||||||
 | 
					    self.addCleanup(mock_gettempdir.stop)
 | 
				
			||||||
 | 
					    self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
 | 
				
			||||||
 | 
					    self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.named_parameters(
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='default_loss_weight',
 | 
				
			||||||
 | 
					          loss_weight=None,
 | 
				
			||||||
 | 
					          loss_value=2.501612,
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='customized_loss_weight_zero_l1',
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              l1=0.0, style=10.0, content=20.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_value=34.032139,
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					      dict(
 | 
				
			||||||
 | 
					          testcase_name='customized_loss_weight_nonzero_l1',
 | 
				
			||||||
 | 
					          loss_weight=loss_functions.PerceptualLossWeight(
 | 
				
			||||||
 | 
					              l1=10.0, style=10.0, content=20.0
 | 
				
			||||||
 | 
					          ),
 | 
				
			||||||
 | 
					          loss_value=42.032139,
 | 
				
			||||||
 | 
					      ),
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					  def test_image_perceptual_quality_loss(self, loss_weight, loss_value):
 | 
				
			||||||
 | 
					    image_quality_loss = loss_functions.ImagePerceptualQualityLoss(
 | 
				
			||||||
 | 
					        loss_weight=loss_weight
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    loss = image_quality_loss(img1=self._img1, img2=self._img2)
 | 
				
			||||||
 | 
					    self.assertNear(loss, loss_value, 1e-4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
  tf.test.main()
 | 
					  tf.test.main()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					@ -16,7 +16,8 @@ from __future__ import absolute_import
 | 
				
			||||||
from __future__ import division
 | 
					from __future__ import division
 | 
				
			||||||
from __future__ import print_function
 | 
					from __future__ import print_function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import List, Union
 | 
					from typing import Sequence
 | 
				
			||||||
 | 
					from typing import Dict, List, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Dependency imports
 | 
					# Dependency imports
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -94,6 +95,17 @@ def is_same_output(tflite_model: bytearray,
 | 
				
			||||||
  return np.allclose(lite_output, keras_output, atol=atol)
 | 
					  return np.allclose(lite_output, keras_output, atol=atol)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def run_tflite(
 | 
				
			||||||
 | 
					    tflite_filename: str,
 | 
				
			||||||
 | 
					    input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
 | 
				
			||||||
 | 
					) -> Union[Sequence[tf.Tensor], tf.Tensor]:
 | 
				
			||||||
 | 
					  """Runs TFLite model inference."""
 | 
				
			||||||
 | 
					  with tf.io.gfile.GFile(tflite_filename, "rb") as f:
 | 
				
			||||||
 | 
					    tflite_model = f.read()
 | 
				
			||||||
 | 
					  lite_runner = model_util.get_lite_runner(tflite_model)
 | 
				
			||||||
 | 
					  return lite_runner.run(input_tensors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_tflite(keras_model: tf.keras.Model,
 | 
					def test_tflite(keras_model: tf.keras.Model,
 | 
				
			||||||
                tflite_model: bytearray,
 | 
					                tflite_model: bytearray,
 | 
				
			||||||
                size: Union[int, List[int]],
 | 
					                size: Union[int, List[int]],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					# Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
# you may not use this file except in compliance with the License.
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
		Reference in New Issue
	
	Block a user