Merge branch 'google:master' into language-detector-python

This commit is contained in:
Kinar R 2023-04-28 09:28:54 +05:30 committed by GitHub
commit 76c8251faf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
950 changed files with 25490 additions and 2085 deletions

34
.github/stale.yml vendored
View File

@ -1,34 +0,0 @@
# Copyright 2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
onlyLabels:
- stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you.
# Comment to post when removing the stale label. Set to `false` to disable
unmarkComment: false
closeComment: >
Closing as stale. Please reopen if you'd like to work on this further.

66
.github/workflows/stale.yaml vendored Normal file
View File

@ -0,0 +1,66 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This workflow alerts and then closes the stale issues/PRs after specific time
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: 'Close stale issues and PRs'
"on":
schedule:
- cron: "30 1 * * *"
permissions:
contents: read
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: 'actions/stale@v7'
with:
# Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale.
exempt-issue-labels: 'override-stale'
# Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale.
exempt-pr-labels: "override-stale"
# Limit the No. of API calls in one run default value is 30.
operations-per-run: 500
# Prevent to remove stale label when PRs or issues are updated.
remove-stale-when-updated: false
# comment on issue if not active for more then 7 days.
stale-issue-message: 'This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.'
# comment on PR if not active for more then 14 days.
stale-pr-message: 'This PR has been marked stale because it has no recent activity since 14 days. It will be closed if no further activity occurs. Thank you.'
# comment on issue if stale for more then 7 days.
close-issue-message: This issue was closed due to lack of activity after being marked stale for past 7 days.
# comment on PR if stale for more then 14 days.
close-pr-message: This PR was closed due to lack of activity after being marked stale for past 14 days.
# Number of days of inactivity before an Issue Request becomes stale
days-before-issue-stale: 7
# Number of days of inactivity before a stale Issue is closed
days-before-issue-close: 7
# reason for closed the issue default value is not_planned
close-issue-reason: completed
# Number of days of inactivity before a stale PR is closed
days-before-pr-close: 14
# Number of days of inactivity before an PR Request becomes stale
days-before-pr-stale: 14
# Check for label to stale or close the issue/PR
any-of-labels: 'stat:awaiting response'
# override stale to stalled for PR
stale-pr-label: 'stale'
# override stale to stalled for Issue
stale-issue-label: "stale"

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -433,9 +433,9 @@ absl::Status SpectrogramCalculator::ProcessVectorToOutput(
absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
CalculatorContext* cc) {
switch (output_type_) {
// These blocks deliberately ignore clang-format to preserve the
// "silhouette" of the different cases.
// clang-format off
// These blocks deliberately ignore clang-format to preserve the
// "silhouette" of the different cases.
// clang-format off
case SpectrogramCalculatorOptions::COMPLEX: {
return ProcessVectorToOutput(
input_stream,

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -18,6 +18,9 @@ package mediapipe;
import "mediapipe/framework/calculator.proto";
option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "LogicCalculatorOptionsProto";
message LogicCalculatorOptions {
extend CalculatorOptions {
optional LogicCalculatorOptions ext = 338731246;

View File

@ -467,6 +467,11 @@ class SideFallbackT : public Base {
// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
// OutputStreamShard. Like that class, this class will not be usually named in
// calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)).
//
// If not connected (!IsConnected()) SetNextTimestampBound is safe to call and
// does nothing.
// All the sub-classes that define Send should implement it to be safe to to
// call if not connected and do nothing in such case.
class OutputShardAccessBase {
public:
OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output)

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -60,7 +60,7 @@ std::string GetUnusedSidePacketName(
}
std::string candidate = input_side_packet_name_base;
int iter = 2;
while (mediapipe::ContainsKey(input_side_packets, candidate)) {
while (input_side_packets.contains(candidate)) {
candidate = absl::StrCat(input_side_packet_name_base, "_",
absl::StrFormat("%02d", iter));
++iter;

View File

@ -101,7 +101,7 @@ void TestSuccessTagMap(const std::vector<std::string>& tag_index_names,
EXPECT_EQ(tags.size(), tag_map->Mapping().size())
<< "Parameters: in " << tag_map->DebugString();
for (int i = 0; i < tags.size(); ++i) {
EXPECT_TRUE(mediapipe::ContainsKey(tag_map->Mapping(), tags[i]))
EXPECT_TRUE(tag_map->Mapping().contains(tags[i]))
<< "Parameters: Trying to find \"" << tags[i] << "\" in\n"
<< tag_map->DebugString();
}

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -14,8 +14,10 @@
package com.google.mediapipe.components;
import javax.annotation.Nullable;
/** Lightweight abstraction for an object that can produce audio data. */
public interface AudioDataProducer {
/** Set the consumer that receives the audio data from this producer. */
void setAudioConsumer(AudioDataConsumer consumer);
void setAudioConsumer(@Nullable AudioDataConsumer consumer);
}

View File

@ -71,7 +71,10 @@ android_library(
"AudioDataProducer.java",
],
visibility = ["//visibility:public"],
deps = ["@maven//:com_google_guava_guava"],
deps = [
"@maven//:com_google_code_findbugs_jsr305",
"@maven//:com_google_guava_guava",
],
)
# MicrophoneHelper that provides access to audio data from a microphone

View File

@ -231,7 +231,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer {
}
/** Returns the texture left, right, bottom, and top visible boundaries. */
protected float[] calculateTextureBoundary() {
public float[] calculateTextureBoundary() {
// TODO: compute scale from surfaceTexture size.
float scaleWidth = frameWidth > 0 ? (float) surfaceWidth / (float) frameWidth : 1.0f;
float scaleHeight = frameHeight > 0 ? (float) surfaceHeight / (float) frameHeight : 1.0f;

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2019-2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -67,11 +67,18 @@ py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],
srcs_version = "PY3",
deps = [
":file_util",
":model_util",
],
)
py_test(
name = "loss_functions_test",
srcs = ["loss_functions_test.py"],
tags = [
"requires-net:external",
],
deps = [":loss_functions"],
)

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,10 +13,21 @@
# limitations under the License.
"""Loss function utility library."""
from typing import Optional, Sequence
import abc
from typing import Mapping, Sequence
import dataclasses
from typing import Any, Optional
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import model_util
from official.modeling import tf_utils
_VGG_IMAGENET_PERCEPTUAL_MODEL_URL = 'https://storage.googleapis.com/mediapipe-assets/vgg_feature_extractor.tar.gz'
class FocalLoss(tf.keras.losses.Loss):
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
@ -45,7 +56,6 @@ class FocalLoss(tf.keras.losses.Loss):
```python
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
```
"""
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
@ -103,3 +113,252 @@ class FocalLoss(tf.keras.losses.Loss):
# By default, this function uses "sum_over_batch_size" reduction for the
# loss per batch.
return tf.reduce_sum(losses) / batch_size
@dataclasses.dataclass
class PerceptualLossWeight:
"""The weight for each perceptual loss.
Attributes:
l1: weight for L1 loss.
content: weight for content loss.
style: weight for style loss.
"""
l1: float = 1.0
content: float = 1.0
style: float = 1.0
class ImagePerceptualQualityLoss(tf.keras.losses.Loss):
"""Image perceptual quality loss.
It obtains a weighted loss between the VGGPerceptualLoss and L1 loss.
"""
def __init__(
self,
loss_weight: Optional[PerceptualLossWeight] = None,
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
):
"""Initializes ImagePerceptualQualityLoss."""
self._loss_weight = loss_weight
self._losses = {}
self._vgg_loss = VGGPerceptualLoss(self._loss_weight)
self._reduction = reduction
def _l1_loss(
self,
reduction: tf.keras.losses.Reduction = tf.keras.losses.Reduction.NONE,
) -> Any:
"""Calculates L1 loss."""
return tf.keras.losses.MeanAbsoluteError(reduction)
def __call__(
self,
img1: tf.Tensor,
img2: tf.Tensor,
) -> tf.Tensor:
"""Computes image perceptual quality loss."""
loss_value = []
if self._loss_weight is None:
self._loss_weight = PerceptualLossWeight()
if self._loss_weight.content > 0 or self._loss_weight.style > 0:
vgg_loss = self._vgg_loss(img1, img2)
vgg_loss_value = tf.math.add_n(vgg_loss.values())
loss_value.append(vgg_loss_value)
if self._loss_weight.l1 > 0:
l1_loss = self._l1_loss(reduction=self._reduction)(img1, img2)
l1_loss_value = tf_utils.safe_mean(l1_loss * self._loss_weight.l1)
loss_value.append(l1_loss_value)
total_loss = tf.math.add_n(loss_value)
return total_loss
class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
"""Base class for perceptual loss model."""
def __init__(
self,
feature_weight: Optional[Sequence[float]] = None,
loss_weight: Optional[PerceptualLossWeight] = None,
):
"""Instantiates perceptual loss.
Args:
feature_weight: The weight coeffcients of multiple model extracted
features used for calculating the perceptual loss.
loss_weight: The weight coefficients between `style_loss` and
`content_loss`.
"""
super().__init__()
self._loss_op = lambda x, y: tf.math.reduce_mean(tf.abs(x - y))
self._loss_style = tf.constant(0.0)
self._loss_content = tf.constant(0.0)
self._feature_weight = feature_weight
self._loss_weight = loss_weight
def __call__(
self,
img1: tf.Tensor,
img2: tf.Tensor,
) -> Mapping[str, tf.Tensor]:
"""Computes perceptual loss between two images.
Args:
img1: First batch of images. The pixel values should be normalized to [-1,
1].
img2: Second batch of images. The pixel values should be normalized to
[-1, 1].
Returns:
A mapping between loss name and loss tensors.
"""
x_features = self._compute_features(img1)
y_features = self._compute_features(img2)
if self._loss_weight is None:
self._loss_weight = PerceptualLossWeight()
# If the _feature_weight is not initialized, then initialize it as a list of
# all the element equals to 1.0.
if self._feature_weight is None:
self._feature_weight = [1.0] * len(x_features)
# If the length of _feature_weight smallert than the length of the feature,
# raise a ValueError. Otherwise, only use the first len(x_features) weight
# for computing the loss.
if len(self._feature_weight) < len(x_features):
raise ValueError(
f'Input feature weight length {len(self._feature_weight)} is smaller'
f' than feature length {len(x_features)}'
)
if self._loss_weight.style > 0.0:
self._loss_style = tf_utils.safe_mean(
self._loss_weight.style
* self._get_style_loss(x_feats=x_features, y_feats=y_features)
)
if self._loss_weight.content > 0.0:
self._loss_content = tf_utils.safe_mean(
self._loss_weight.content
* self._get_content_loss(x_feats=x_features, y_feats=y_features)
)
return {'style_loss': self._loss_style, 'content_loss': self._loss_content}
@abc.abstractmethod
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
"""Computes features from the given image tensor.
Args:
img: Image tensor.
Returns:
A list of multi-scale feature maps.
"""
def _get_content_loss(
self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
) -> tf.Tensor:
"""Gets weighted multi-scale content loss.
Args:
x_feats: Reconstructed face image.
y_feats: Target face image.
Returns:
A scalar tensor for the content loss.
"""
content_losses = []
for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
content_loss = self._loss_op(x_feat, y_feat) * coef
content_losses.append(content_loss)
return tf.math.reduce_sum(content_losses)
def _get_style_loss(
self, x_feats: Sequence[tf.Tensor], y_feats: Sequence[tf.Tensor]
) -> tf.Tensor:
"""Gets weighted multi-scale style loss.
Args:
x_feats: Reconstructed face image.
y_feats: Target face image.
Returns:
A scalar tensor for the style loss.
"""
style_losses = []
i = 0
for coef, x_feat, y_feat in zip(self._feature_weight, x_feats, y_feats):
x_feat_g = _compute_gram_matrix(x_feat)
y_feat_g = _compute_gram_matrix(y_feat)
style_loss = self._loss_op(x_feat_g, y_feat_g) * coef
style_losses.append(style_loss)
i = i + 1
return tf.math.reduce_sum(style_loss)
class VGGPerceptualLoss(PerceptualLoss):
"""Perceptual loss based on VGG19 pretrained on the ImageNet dataset.
Reference:
- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](
https://arxiv.org/abs/1603.08155) (ECCV 2016)
Perceptual loss measures high-level perceptual and semantic differences
between images.
"""
def __init__(
self,
loss_weight: Optional[PerceptualLossWeight] = None,
):
"""Initializes image quality loss essentials.
Args:
loss_weight: Loss weight coefficients.
"""
super().__init__(
feature_weight=np.array([0.1, 0.1, 1.0, 1.0, 1.0]),
loss_weight=loss_weight,
)
rgb_mean = tf.constant([0.485, 0.456, 0.406])
rgb_std = tf.constant([0.229, 0.224, 0.225])
self._rgb_mean = tf.reshape(rgb_mean, (1, 1, 1, 3))
self._rgb_std = tf.reshape(rgb_std, (1, 1, 1, 3))
model_path = file_util.DownloadedFiles(
'vgg_feature_extractor',
_VGG_IMAGENET_PERCEPTUAL_MODEL_URL,
is_folder=True,
)
self._vgg19 = model_util.load_keras_model(model_path.get_path())
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
"""Computes VGG19 features."""
img = (img + 1) / 2.0
norm_img = (img - self._rgb_mean) / self._rgb_std
# no grad, as it only serves as a frozen feature extractor.
return self._vgg19(norm_img)
def _compute_gram_matrix(feature: tf.Tensor) -> tf.Tensor:
"""Computes gram matrix for the feature map.
Args:
feature: [B, H, W, C] feature map.
Returns:
[B, C, C] gram matrix.
"""
h, w, c = feature.shape[1:].as_list()
feat_reshaped = tf.reshape(feature, shape=(-1, h * w, c))
feat_gram = tf.matmul(
tf.transpose(feat_reshaped, perm=[0, 2, 1]), feat_reshaped
)
return feat_gram / (c * h * w)

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +13,9 @@
# limitations under the License.
import math
from typing import Optional
import tempfile
from typing import Dict, Optional, Sequence
from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
@ -21,7 +23,7 @@ import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
class FocalLossTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(testcase_name='no_sample_weight', sample_weight=None),
@ -99,5 +101,228 @@ class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
self.assertNear(loss, expected_loss, 1e-4)
class MockPerceptualLoss(loss_functions.PerceptualLoss):
"""A mock class with implementation of abstract methods for testing."""
def __init__(
self,
use_mock_loss_op: bool = False,
feature_weight: Optional[Sequence[float]] = None,
loss_weight: Optional[loss_functions.PerceptualLossWeight] = None,
):
super().__init__(feature_weight=feature_weight, loss_weight=loss_weight)
if use_mock_loss_op:
self._loss_op = lambda x, y: tf.math.reduce_mean(x - y)
def _compute_features(self, img: tf.Tensor) -> Sequence[tf.Tensor]:
return [tf.random.normal(shape=(1, 8, 8, 3))] * 5
class PerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._img1 = tf.fill(dims=(8, 8), value=0.2)
self._img2 = tf.fill(dims=(8, 8), value=0.8)
def test_invalid_feature_weight_raise_value_error(self):
with self.assertRaisesRegex(
ValueError,
'Input feature weight length 2 is smaller than feature length 5',
):
MockPerceptualLoss(feature_weight=[1.0, 2.0])(
img1=self._img1, img2=self._img2
)
@parameterized.named_parameters(
dict(
testcase_name='default_loss_weight_and_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=None,
loss_values={
'style_loss': 0.032839,
'content_loss': 5.639870,
},
),
dict(
testcase_name='style_loss_weight_is_0_default_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(style=0),
loss_values={
'style_loss': 0,
'content_loss': 5.639870,
},
),
dict(
testcase_name='content_loss_weight_is_0_default_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(content=0),
loss_values={
'style_loss': 0.032839,
'content_loss': 0,
},
),
dict(
testcase_name='customized_loss_weight_default_loss_op',
use_mock_loss_op=False,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={'style_loss': 0.032839, 'content_loss': 11.279739},
),
dict(
testcase_name=(
'customized_feature_weight_and_loss_weight_default_loss_op'
),
use_mock_loss_op=False,
feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0],
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={'style_loss': 0.164193, 'content_loss': 33.839218},
),
dict(
testcase_name='no_loss_change_if_extra_feature_weight_provided',
use_mock_loss_op=False,
feature_weight=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={
'style_loss': 0.164193,
'content_loss': 33.839218,
},
),
dict(
testcase_name='customized_loss_weight_custom_loss_op',
use_mock_loss_op=True,
feature_weight=None,
loss_weight=loss_functions.PerceptualLossWeight(
style=1.0, content=2.0
),
loss_values={'style_loss': 0.000395, 'content_loss': -1.533469},
),
)
def test_weighted_perceptul_loss(
self,
use_mock_loss_op: bool,
feature_weight: Sequence[float],
loss_weight: loss_functions.PerceptualLossWeight,
loss_values: Dict[str, float],
):
perceptual_loss = MockPerceptualLoss(
use_mock_loss_op=use_mock_loss_op,
feature_weight=feature_weight,
loss_weight=loss_weight,
)
loss = perceptual_loss(img1=self._img1, img2=self._img2)
self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
self.assertNear(loss['style_loss'], loss_values['style_loss'], 1e-4)
self.assertNear(loss['content_loss'], loss_values['content_loss'], 1e-4)
class VGGPerceptualLossTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
@parameterized.named_parameters(
dict(
testcase_name='default_loss_weight',
loss_weight=None,
loss_values={
'style_loss': 5.8363257e-06,
'content_loss': 1.7016045,
},
),
dict(
testcase_name='customized_loss_weight',
loss_weight=loss_functions.PerceptualLossWeight(
style=10.0, content=20.0
),
loss_values={
'style_loss': 5.8363257e-05,
'content_loss': 34.03208,
},
),
)
def test_vgg_perceptual_loss(self, loss_weight, loss_values):
vgg_loss = loss_functions.VGGPerceptualLoss(loss_weight=loss_weight)
loss = vgg_loss(img1=self._img1, img2=self._img2)
self.assertEqual(list(loss.keys()), ['style_loss', 'content_loss'])
self.assertNear(
loss['style_loss'],
loss_values['style_loss'],
loss_values['style_loss'] / 1e5,
)
self.assertNear(
loss['content_loss'],
loss_values['content_loss'],
loss_values['content_loss'] / 1e5,
)
class ImagePerceptualQualityLossTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
self._img1 = tf.fill(dims=(1, 256, 256, 3), value=0.1)
self._img2 = tf.fill(dims=(1, 256, 256, 3), value=0.9)
@parameterized.named_parameters(
dict(
testcase_name='default_loss_weight',
loss_weight=None,
loss_value=2.501612,
),
dict(
testcase_name='customized_loss_weight_zero_l1',
loss_weight=loss_functions.PerceptualLossWeight(
l1=0.0, style=10.0, content=20.0
),
loss_value=34.032139,
),
dict(
testcase_name='customized_loss_weight_nonzero_l1',
loss_weight=loss_functions.PerceptualLossWeight(
l1=10.0, style=10.0, content=20.0
),
loss_value=42.032139,
),
)
def test_image_perceptual_quality_loss(self, loss_weight, loss_value):
image_quality_loss = loss_functions.ImagePerceptualQualityLoss(
loss_weight=loss_weight
)
loss = image_quality_loss(img1=self._img1, img2=self._img2)
self.assertNear(loss, loss_value, 1e-4)
if __name__ == '__main__':
tf.test.main()

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

Some files were not shown because too many files have changed in this diff Show More