Merge branch 'google:master' into image-classification-python-impl

This commit is contained in:
Kinar R 2022-10-05 17:14:25 +05:30 committed by GitHub
commit 58091aa2c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 2803 additions and 105 deletions

View File

@@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==2.2.0
RUN pip3 install tensorflow
RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python

View File

@@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText;
import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;

View File

@@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
glFinish();
gl_context_->Run([]() { glFinish(); });
} else if (valid_ & kValidAHardwareBuffer) {
CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
"completion function to be set";

View File

@@ -89,10 +89,6 @@ def mediapipe_aar(
calculators = calculators,
)
_mediapipe_proto(
name = name + "_proto",
)
native.genrule(
name = name + "_aar_manifest_generator",
outs = ["AndroidManifest.xml"],
@@ -115,19 +111,10 @@ EOF
"//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
] +
] + mediapipe_java_proto_srcs() +
select({
"//conditions:default": [],
"enable_stats_logging": [
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
],
"enable_stats_logging": mediapipe_logging_java_proto_srcs(),
}),
manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@@ -179,93 +166,6 @@ EOF
_aar_with_jni(name, name + "_android_lib")
def _mediapipe_proto(name):
"""Generates MediaPipe java proto libraries.
Args:
name: the name of the target.
"""
_proto_java_src_generator(
name = "mediapipe_log_extension_proto",
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "mediapipe_logging_enums_proto",
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "calculator_proto",
proto_src = "mediapipe/framework/calculator.proto",
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
srcs = ["//mediapipe/framework:protos_src"],
)
_proto_java_src_generator(
name = "landmark_proto",
proto_src = "mediapipe/framework/formats/landmark.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
srcs = ["//mediapipe/framework/formats:protos_src"],
)
_proto_java_src_generator(
name = "rasterization_proto",
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
)
_proto_java_src_generator(
name = "location_data_proto",
proto_src = "mediapipe/framework/formats/location_data.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "detection_proto",
proto_src = "mediapipe/framework/formats/detection.proto",
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "classification_proto",
proto_src = "mediapipe/framework/formats/classification.proto",
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
],
)
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
native.genrule(
name = name + "_proto_java_src_generator",
srcs = srcs + [
"@com_google_protobuf//:lite_well_known_protos",
],
outs = [java_lite_out],
cmd = "$(location @com_google_protobuf//:protoc) " +
"--proto_path=. --proto_path=$(GENDIR) " +
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
tools = [
"@com_google_protobuf//:protoc",
],
)
def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
"""Generates MediaPipe jni library.
@@ -345,3 +245,93 @@ cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
)
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
"""Extracts the generated MediaPipe java proto source code from the target.
Args:
target: The java proto lite target to be built and extracted.
src_out: The output java proto src code path.
name: The optional bazel target name.
Returns:
The output java proto src code path.
"""
if not name:
name = target.split(":")[-1] + "_proto_java_src_extractor"
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
native.genrule(
name = name + "_proto_java_src_extractor",
srcs = [target],
outs = [src_out],
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
src_out + " $$(dirname $(location " + src_out + "))",
)
return src_out
def mediapipe_java_proto_srcs(name = ""):
"""Extracts the generated MediaPipe framework java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:detection_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
))
return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""):
"""Extracts the generated logging-related MediaPipe java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe logging-related java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
))
return proto_src_list

View File

@@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])

View File

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "data_util",
srcs = ["data_util.py"],
srcs_version = "PY3",
)
py_test(
name = "data_util_test",
srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":data_util"],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
srcs_version = "PY3",
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dataset",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)
py_library(
name = "classification_dataset",
srcs = ["classification_dataset.py"],
srcs_version = "PY3",
deps = [":dataset"],
)
py_test(
name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":classification_dataset"],
)

View File

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,47 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
super().__init__(dataset, size)
self.index_to_label = index_to_label
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self.index_to_label)
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the data set into training and testing sets.
Args:
fraction: float, the fraction of the original data that goes into the first
returned sub-dataset.
Returns:
The two split sub-datasets.
"""
return self._split(fraction, self.index_to_label)
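A minimal usage sketch for the ClassificationDataset above, assuming a small in-memory tf.data.Dataset of (feature, label) pairs; the dataset contents and label names are illustrative only, not part of this change:

import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset

# Two examples with binary labels; index_to_label maps class index to name.
raw = tf.data.Dataset.from_tensor_slices(([[0.1, 0.2], [0.3, 0.4]], [0, 1]))
data = classification_dataset.ClassificationDataset(
    dataset=raw, size=2, index_to_label=['cat', 'dog'])
print(data.num_classes)                  # 2
train_data, test_data = data.split(0.5)  # one example in each sub-dataset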

View File

@@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
class ClassificationDataLoaderTest(tf.test.TestCase):
def test_split(self):
class MagicClassificationDataLoader(
classification_dataset.ClassificationDataset):
def __init__(self, dataset, size, index_to_label, value):
super(MagicClassificationDataLoader,
self).__init__(dataset, size, index_to_label)
self.value = value
def split(self, fraction):
return self._split(fraction, self.index_to_label, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_to_label = (False, True)
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
magic_value)
# Train/Test data split.
fraction = .25
train_data, test_data = data.split(fraction)
# `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataLoader)
self.assertIsInstance(test_data, MagicClassificationDataLoader)
# Make sure number of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
self.assertLen(train_data, fraction * len(ds))
self.assertLen(test_data, len(ds) - len(train_data))
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_to_label, index_to_label)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)
if __name__ == '__main__':
tf.test.main()

View File

@@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utility library."""
import cv2
import numpy as np
import tensorflow as tf
def load_image(path: str) -> np.ndarray:
"""Loads an image as an RGB numpy array.
Args:
path: input image file absolute path.
Returns:
An RGB image in numpy.ndarray.
"""
tf.compat.v1.logging.info('Loading RGB image %s', path)
# TODO: Replace the OpenCV image load and conversion library with the
# MediaPipe image utility library once it is ready.
image = cv2.imread(path)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
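A quick sketch of calling data_util.load_image, assuming a JPEG exists at the illustrative path below (the path is hypothetical):

import numpy as np
from mediapipe.model_maker.python.core.data import data_util

rgb = data_util.load_image('/tmp/example.jpg')  # hypothetical local path
# cv2.imread yields a uint8 BGR array; load_image returns it as RGB (H, W, 3).
assert rgb.dtype == np.uint8 and rgb.shape[-1] == 3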

View File

@@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import flags
import tensorflow as tf
from mediapipe.model_maker.python.core.data import data_util
_WORKSPACE = "mediapipe"
_TEST_DATA_DIR = os.path.join(
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
FLAGS = flags.FLAGS
class DataUtilTest(tf.test.TestCase):
def test_load_rgb_image(self):
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
image_data = data_util.load_image(image_path)
self.assertEqual(image_data.shape, (5184, 3456, 3))
if __name__ == '__main__':
tf.test.main()

View File

@@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common dataset for model training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from typing import Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
class Dataset(object):
"""A generic dataset class for loading model training and evaluation dataset.
For each ML task, such as image classification, text classification etc., a
subclass can be derived from this class to provide task-specific data loading
utilities.
"""
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
"""Initializes Dataset class.
To build a dataset from raw data, consider using the task-specific utilities,
e.g. from_folder().
Args:
tf_dataset: A tf.data.Dataset object that contains a potentially large set
of elements, where each element is a pair of (input_data, target). The
`input_data` means the raw input data, like an image, a text etc., while
the `target` means the ground truth of the raw input data, e.g. the
classification label of the image etc.
size: The size of the dataset. tf.data.Dataset doesn't support a function
to get the length directly since it's lazy-loaded and may be infinite.
"""
self._dataset = tf_dataset
self._size = size
@property
def size(self) -> Optional[int]:
"""Returns the size of the dataset.
Note that this function may return None because the exact size of the
dataset isn't a necessary parameter to create an instance of this class,
and tf.data.Dataset doesn't support a function to get the length directly
since it's lazy-loaded and may be infinite.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset will be precomputed,
and this function can return an int representing the size of the dataset.
"""
return self._size
def gen_tf_dataset(self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., bool]] = None,
drop_remainder: bool = False) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation.
Args:
batch_size: An integer, the returned dataset will be batched by this size.
is_training: A boolean, when True, the returned dataset will be optionally
shuffled and repeated as an endless dataset.
shuffle: A boolean, when True, the returned dataset will be shuffled to
create randomness during model training.
preprocess: A function taking three arguments in order, feature, label and
boolean is_training.
drop_remainder: Boolean, whether the final batch drops the remainder.
Returns:
A TF dataset ready to be consumed by Keras model.
"""
dataset = self._dataset
if preprocess:
preprocess = functools.partial(preprocess, is_training=is_training)
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
if is_training:
if shuffle:
# The shuffle buffer size should be bigger than the batch_size. Otherwise it
# only shuffles within the batch, which is equivalent to not shuffling at all.
buffer_size = 3 * batch_size
# But since we are doing shuffle before repeat, it doesn't make sense to
# shuffle more than total available entries.
# TODO: Investigate whether shuffling the dataset before or after repeat
# gives better performance.
# Shuffle after repeat will give a more randomized dataset and mix the
# epoch boundary: https://www.tensorflow.org/guide/data
if self._size:
buffer_size = min(self._size, buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# TODO: Consider converting dataset to distributed dataset
# here.
return dataset
def __len__(self):
"""Returns the number of element of the dataset."""
if self._size is not None:
return self._size
else:
return len(self._dataset)
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the data set into training and testing sets.
Args:
fraction: A float value that defines the fraction of the original data in
the first returned sub-dataset.
Returns:
The two split sub-datasets.
"""
return self._split(fraction)
def _split(self: _DatasetT, fraction: float,
*args) -> Tuple[_DatasetT, _DatasetT]:
"""Implementation for `split` method and returns sub-class instances.
Child DataLoader classes, if requires additional constructor arguments,
should implement their own `split` method by calling `_split` with all
arguments to the constructor.
Args:
fraction: A float value that defines the fraction of the original data in
the first returned sub-dataset.
*args: additional arguments passed to the sub-class constructor.
Returns:
The two split sub-datasets.
"""
assert (fraction > 0 and fraction < 1)
dataset = self._dataset
train_size = int(self._size * fraction)
trainset = self.__class__(dataset.take(train_size), train_size, *args)
test_size = self._size - train_size
testset = self.__class__(dataset.skip(train_size), test_size, *args)
return trainset, testset
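A short sketch of how the Dataset wrapper above is typically consumed, assuming a toy tf.data.Dataset of (feature, label) pairs and an illustrative preprocess function (both are made up for this example):

import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds

raw = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([8, 4]), tf.zeros([8], dtype=tf.int32)))
data = ds.Dataset(raw, size=8)

def preprocess(feature, label, is_training):
  # gen_tf_dataset partially applies is_training before mapping this function.
  del is_training  # unused in this toy example
  return tf.cast(feature, tf.float32), label

batched = data.gen_tf_dataset(
    batch_size=4, is_training=True, shuffle=True, preprocess=preprocess)
for features, labels in batched:
  print(features.shape, labels.shape)  # (4, 4) (4,)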

View File

@@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import test_util
class DatasetTest(tf.test.TestCase):
def test_split(self):
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, 4)
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 2)
self.assertIsInstance(train_data, ds.Dataset)
self.assertIsInstance(test_data, ds.Dataset)
for i, elem in enumerate(train_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
def test_len(self):
size = 4
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, size)
self.assertLen(data, size)
def test_gen_tf_dataset(self):
input_dim = 8
data = test_util.create_dataset(
data_size=2, input_shape=[input_dim], num_classes=2)
dataset = data.gen_tf_dataset()
self.assertLen(dataset, 2)
for (feature, label) in dataset:
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
dataset2 = data.gen_tf_dataset(batch_size=2)
self.assertLen(dataset2, 1)
for (feature, label) in dataset2:
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
self.assertEqual(dataset3.cardinality(), 1)
for (feature, label) in dataset3.take(10):
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
if __name__ == '__main__':
tf.test.main()

View File

@@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
package(
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
mediapipe_files(srcs = ["test.jpg"])
filegroup(
name = "testdata",
srcs = ["test.jpg"],
)

View File

@@ -0,0 +1,100 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.py"],
srcs_version = "PY3",
deps = [
":model_util",
"//mediapipe/model_maker/python/core/data:dataset",
],
)
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
srcs_version = "PY3",
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":image_preprocessing"],
)
py_library(
name = "model_util",
srcs = ["model_util.py"],
srcs_version = "PY3",
deps = [
":quantization",
"//mediapipe/model_maker/python/core/data:dataset",
],
)
py_test(
name = "model_util_test",
srcs = ["model_util_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":model_util",
":quantization",
":test_util",
],
)
py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],
srcs_version = "PY3",
)
py_test(
name = "loss_functions_test",
srcs = ["loss_functions_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":loss_functions"],
)
py_library(
name = "quantization",
srcs = ["quantization.py"],
srcs_version = "PY3",
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
)
py_test(
name = "quantization_test",
srcs = ["quantization_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":quantization",
":test_util",
],
)

View File

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,228 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
IMAGE_SIZE = 224
CROP_PADDING = 32
class Preprocessor(object):
"""Preprocessor for image classification."""
def __init__(self,
input_shape,
num_classes,
mean_rgb,
stddev_rgb,
use_augmentation=False):
self.input_shape = input_shape
self.num_classes = num_classes
self.mean_rgb = mean_rgb
self.stddev_rgb = stddev_rgb
self.use_augmentation = use_augmentation
def __call__(self, image, label, is_training=True):
if self.use_augmentation:
return self._preprocess_with_augmentation(image, label, is_training)
return self._preprocess_without_augmentation(image, label)
def _preprocess_with_augmentation(self, image, label, is_training):
"""Image preprocessing method with data augmentation."""
image_size = self.input_shape[0]
if is_training:
image = preprocess_for_train(image, image_size)
else:
image = preprocess_for_eval(image, image_size)
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
label = tf.one_hot(label, depth=self.num_classes)
return image, label
# TODO: Changes to preprocess to support batch input.
def _preprocess_without_augmentation(self, image, label):
"""Image preprocessing method without data augmentation."""
image = tf.cast(image, tf.float32)
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
image = tf.compat.v1.image.resize(image, self.input_shape)
label = tf.one_hot(label, depth=self.num_classes)
return image, label
def _distorted_bounding_box_crop(image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100):
"""Generates cropped_image using one of the bboxes randomly distorted.
See `tf.image.sample_distorted_bounding_box` for more documentation.
Args:
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
shape [height, width, channels].
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where
each coordinate is [0, 1) and the coordinates are arranged as `[ymin,
xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image.
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
of the image must contain at least this fraction of any bounding box
supplied.
aspect_ratio_range: An optional list of `float`s. The cropped area of the
image must have an aspect ratio = width / height within this range.
area_range: An optional list of `float`s. The cropped area of the image must
contain a fraction of the supplied image within this range.
max_attempts: An optional `int`. Number of attempts at generating a cropped
region of the image of the specified constraints. After `max_attempts`
failures, return the entire image.
Returns:
A cropped image `Tensor`
"""
with tf.name_scope('distorted_bounding_box_crop'):
shape = tf.shape(image)
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
shape,
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
max_attempts=max_attempts,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
image = tf.image.crop_to_bounding_box(image, offset_y, offset_x,
target_height, target_width)
return image
def _at_least_x_are_equal(a, b, x):
"""At least `x` of `a` and `b` `Tensors` are equal."""
match = tf.equal(a, b)
match = tf.cast(match, tf.int32)
return tf.greater_equal(tf.reduce_sum(match), x)
def _resize_image(image, image_size, method=None):
if method is not None:
tf.compat.v1.logging.info('Use customized resize method {}'.format(method))
return tf.compat.v1.image.resize([image], [image_size, image_size],
method)[0]
tf.compat.v1.logging.info('Use default resize_bicubic.')
return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0]
def _decode_and_random_crop(original_image, image_size, resize_method=None):
"""Makes a random crop of image_size."""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = _distorted_bounding_box_crop(
original_image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(3. / 4, 4. / 3.),
area_range=(0.08, 1.0),
max_attempts=10)
original_shape = tf.shape(original_image)
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
image = tf.cond(bad,
lambda: _decode_and_center_crop(original_image, image_size),
lambda: _resize_image(image, image_size, resize_method))
return image
def _decode_and_center_crop(image, image_size, resize_method=None):
"""Crops to center of image with padding then scales image_size."""
shape = tf.shape(image)
image_height = shape[0]
image_width = shape[1]
padded_center_crop_size = tf.cast(
((image_size / (image_size + CROP_PADDING)) *
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
padded_center_crop_size,
padded_center_crop_size)
image = _resize_image(image, image_size, resize_method)
return image
def _flip(image):
"""Random horizontal image flip."""
image = tf.image.random_flip_left_right(image)
return image
def preprocess_for_train(
image: tf.Tensor,
image_size: int = IMAGE_SIZE,
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
"""Preprocesses the given image for evaluation.
Args:
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
shape [height, width, channels].
image_size: image size.
resize_method: the resize method; if None, bicubic resizing is used.
Returns:
A preprocessed image `Tensor`.
"""
image = _decode_and_random_crop(image, image_size, resize_method)
image = _flip(image)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
return image
def preprocess_for_eval(
image: tf.Tensor,
image_size: int = IMAGE_SIZE,
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
"""Preprocesses the given image for evaluation.
Args:
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
shape [height, width, channels].
image_size: image size.
resize_method: if None, use bicubic.
Returns:
A preprocessed image `Tensor`.
"""
image = _decode_and_center_crop(image, image_size, resize_method)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
return image
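A hedged sketch of using the Preprocessor above in eager mode on a dummy image; the mean/stddev values mirror those in the accompanying unit test, and the image contents and class count are arbitrary:

import tensorflow as tf
from mediapipe.model_maker.python.core.utils import image_preprocessing

preprocessor = image_preprocessing.Preprocessor(
    input_shape=[224, 224],
    num_classes=5,
    mean_rgb=[0.0],      # same normalization constants as the unit test
    stddev_rgb=[255.0],
    use_augmentation=False)

image = tf.zeros([64, 64, 3], dtype=tf.uint8)  # dummy RGB input
out_image, one_hot = preprocessor(image, label=tf.constant(2))
print(out_image.shape, one_hot.shape)          # (224, 224, 3) (5,)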

View File

@@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import image_preprocessing
def _get_preprocessed_image(preprocessor, is_training=False):
image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3])
label_placeholder = tf.compat.v1.placeholder(tf.int32, [1])
image_tensor, _ = preprocessor(image_placeholder, label_placeholder,
is_training)
with tf.compat.v1.Session() as sess:
input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
image = sess.run(
image_tensor,
feed_dict={
image_placeholder: input_image,
label_placeholder: [0]
})
return image
class PreprocessorTest(tf.test.TestCase):
def test_preprocess_without_augmentation(self):
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
num_classes=2,
mean_rgb=[0.0],
stddev_rgb=[255.0],
use_augmentation=False)
actual_image = np.array([[[0., 0.00392157, 0.00784314],
[0.14117648, 0.14509805, 0.14901961]],
[[0.37647063, 0.3803922, 0.38431376],
[0.5176471, 0.52156866, 0.5254902]]])
image = _get_preprocessed_image(preprocessor)
self.assertTrue(np.allclose(image, actual_image, atol=1e-05))
def test_preprocess_with_augmentation(self):
image_preprocessing.CROP_PADDING = 1
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
num_classes=2,
mean_rgb=[0.0],
stddev_rgb=[255.0],
use_augmentation=True)
# Tests validation image.
actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216],
[0.26666668, 0.27058825, 0.27450982]],
[[0.42352945, 0.427451, 0.43137258],
[0.5176471, 0.52156866, 0.5254902]]])
image = _get_preprocessed_image(preprocessor, is_training=False)
self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05))
# Tests training image.
image1 = _get_preprocessed_image(preprocessor, is_training=True)
image2 = _get_preprocessed_image(preprocessor, is_training=True)
self.assertFalse(np.allclose(image1, image2, atol=1e-05))
self.assertEqual(image1.shape, (2, 2, 3))
self.assertEqual(image2.shape, (2, 2, 3))
if __name__ == '__main__':
tf.compat.v1.disable_eager_execution()
tf.test.main()

View File

@@ -0,0 +1,105 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss function utility library."""
from typing import Optional, Sequence
import tensorflow as tf
class FocalLoss(tf.keras.losses.Loss):
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
This class computes the focal loss between labels and prediction. Focal loss
is a weighted loss function that modulates the standard cross-entropy loss
based on how well the neural network performs on a specific example of a
class. The labels should be provided in a `one_hot` vector representation.
There should be `#classes` floating point values per prediction.
The loss is reduced across all samples using 'sum_over_batch_size' reduction
(see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction).
Example usage:
>>> y_true = [[0, 1, 0], [0, 0, 1]]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> gamma = 2
>>> focal_loss = FocalLoss(gamma)
>>> focal_loss(y_true, y_pred).numpy()
0.9326
>>> # Calling with 'sample_weight'.
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
0.6528
Usage with the `compile()` API:
```python
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
```
"""
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
"""Constructor.
Args:
gamma: Focal loss gamma, as described in class docs.
class_weight: A weight to apply to the loss, one for each class. The
weight is applied for each input where the ground truth label matches.
"""
super(tf.keras.losses.Loss, self).__init__()
# Used for clipping min/max values of probability values in y_pred to avoid
# NaNs and Infs in computation.
self._epsilon = 1e-7
# This is a tunable "focusing parameter"; should be >= 0.
# When gamma = 0, the loss returned is the standard categorical
# cross-entropy loss.
self._gamma = gamma
self._class_weight = class_weight
# tf.keras.losses.Loss class implementation requires a Reduction specified
# in self.reduction. To use this reduction, we should use tensorflow's
# compute_weighted_loss function however it is only compatible with v1 of
# Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long
# So even though it is specified here, we don't use self.reduction in the
# loss function call.
self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
def __call__(self,
y_true: tf.Tensor,
y_pred: tf.Tensor,
sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor:
if self._class_weight:
class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32)
label = tf.argmax(y_true, axis=1)
loss_weight = tf.gather(class_weight, label)
else:
loss_weight = tf.ones(tf.shape(y_true)[0])
y_true = tf.cast(y_true, y_pred.dtype)
y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon)
batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype)
if sample_weight is None:
sample_weight = tf.constant(1.0)
weight_shape = sample_weight.shape
weight_rank = weight_shape.ndims
y_pred_rank = y_pred.shape.ndims
if y_pred_rank - weight_rank == 1:
sample_weight = tf.expand_dims(sample_weight, [-1])
elif weight_rank != 0:
raise ValueError(f'Unexpected sample_weights, should be either a scalar '
f'or a vector of batch_size: {batch_size.numpy()}')
ce = -tf.math.log(y_pred)
modulating_factor = tf.math.pow(1 - y_pred, self._gamma)
losses = y_true * modulating_factor * ce * sample_weight
losses = losses * loss_weight[:, tf.newaxis]
# By default, this function uses "sum_over_batch_size" reduction for the
# loss per batch.
return tf.reduce_sum(losses) / batch_size
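A small complementary sketch exercising the class_weight path of FocalLoss, which the docstring example above does not cover; all values here are arbitrary:

import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions

loss_fn = loss_functions.FocalLoss(gamma=2, class_weight=[1.0, 3.0])
y_true = tf.constant([[1, 0], [0, 1]])
y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]])
# Examples whose ground-truth class is 1 are weighted 3x in the reduced loss.
print(loss_fn(y_true, y_pred).numpy())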

View File

@@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(testcase_name='no_sample_weight', sample_weight=None),
dict(
testcase_name='with_sample_weight',
sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])))
def test_focal_loss_gamma_0_is_cross_entropy(
self, sample_weight: Optional[tf.Tensor]):
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
0]])
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
tf_cce = tf.keras.losses.CategoricalCrossentropy(
from_logits=False,
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
focal_loss = loss_functions.FocalLoss(gamma=0)
self.assertAllClose(
tf_cce(y_true, y_pred, sample_weight=sample_weight),
focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4)
def test_focal_loss_with_sample_weight(self):
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
0]])
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
focal_loss = loss_functions.FocalLoss(gamma=0)
sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])
self.assertGreater(
focal_loss(y_true=y_true, y_pred=y_pred),
focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight))
@parameterized.named_parameters(
dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])),
dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])),
dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])),
dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])),
dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])),
)
def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor):
y_true = tf.constant([[1, 0]])
focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0)
loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred)
focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5)
loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred)
focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1)
loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred)
focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2)
loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred)
focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5)
loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred)
self.assertGreater(loss_gamma_0, loss_gamma_0p5)
self.assertGreater(loss_gamma_0p5, loss_gamma_1)
self.assertGreater(loss_gamma_1, loss_gamma_2)
self.assertGreater(loss_gamma_2, loss_gamma_5)
@parameterized.named_parameters(
dict(testcase_name='index_0', true_class=0),
dict(testcase_name='index_1', true_class=1),
dict(testcase_name='index_2', true_class=2),
)
def test_focal_loss_class_weight_is_applied(self, true_class: int):
class_weight = [1.0, 3.0, 10.0]
y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0
y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :]
expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class]
loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
loss = loss_fn(y_true, y_pred)
self.assertNear(loss, expected_loss, 1e-4)
if __name__ == '__main__':
tf.test.main()

View File

@@ -0,0 +1,241 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for keras models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
batch_size: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None) -> int:
"""Gets the estimated training steps per epoch.
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
2. Else if we can get the length of training data successfully, returns
`train_data_length // batch_size`.
Args:
steps_per_epoch: int, training steps per epoch.
batch_size: int, batch size.
train_data: training data.
Returns:
Estimated training steps per epoch.
Raises:
ValueError: if neither steps_per_epoch nor train_data is set.
"""
if steps_per_epoch is not None:
# steps_per_epoch is set by users manually.
return steps_per_epoch
else:
if train_data is None:
raise ValueError('Input train_data cannot be None.')
# Gets the steps by the length of the training data.
return len(train_data) // batch_size
def export_tflite(
model: tf.keras.Model,
tflite_filepath: str,
quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)):
"""Converts the model to tflite format and saves it.
Args:
model: model to be converted to tflite.
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization.
supported_ops: A list of supported ops in the converted TFLite file.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, 'saved_model')
model.save(save_path, include_optimizer=False, save_format='tf')
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
if quantization_config:
converter = quantization_config.set_converter_with_quantization(converter)
converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert()
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
f.write(tflite_model)
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applies a warmup schedule on a given learning rate decay schedule."""
def __init__(self,
initial_learning_rate: float,
decay_schedule_fn: Callable[[Any], Any],
warmup_steps: int,
name: Optional[str] = None):
"""Initializes a new instance of the `WarmUp` class.
Args:
initial_learning_rate: learning rate after the warmup.
decay_schedule_fn: A function that maps a step to a learning rate. Applied
for step values larger than 'warmup_steps'.
warmup_steps: Number of steps to do warmup for.
name: TF namescope under which to perform the learning rate calculation.
"""
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
with tf.name_scope(self.name or 'WarmUp') as name:
# Implements linear warmup. i.e., if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
return tf.cond(
global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name)
def get_config(self) -> Dict[Text, Any]:
return {
'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps,
'name': self.name
}
class LiteRunner(object):
"""A runner to do inference with the TFLite model."""
def __init__(self, tflite_filepath: str):
"""Initializes Lite runner with tflite model file.
Args:
tflite_filepath: File path to the TFLite model.
"""
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
tflite_model = f.read()
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()
def run(
self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]]
) -> Union[List[tf.Tensor], tf.Tensor]:
"""Runs inference with the TFLite model.
Args:
      input_tensors: List / Dict of the input tensors of the TFLite model. If
        it is a list, the order should match the inputs of the Keras model. A
        single tensor is also accepted if the model has only one input.
Returns:
List of the output tensors for multi-output models, otherwise just
the output tensor. The order should be the same as the keras model.
"""
    if not isinstance(input_tensors, (list, dict)):
      input_tensors = [input_tensors]
interpreter = self.interpreter
# Reshape inputs
for i, input_detail in enumerate(self.input_details):
input_tensor = _get_input_tensor(
input_tensors=input_tensors,
input_details=self.input_details,
index=i)
interpreter.resize_tensor_input(
input_index=input_detail['index'], tensor_size=input_tensor.shape)
interpreter.allocate_tensors()
# Feed input to the interpreter
for i, input_detail in enumerate(self.input_details):
input_tensor = _get_input_tensor(
input_tensors=input_tensors,
input_details=self.input_details,
index=i)
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
# Quantize the input
scale, zero_point = input_detail['quantization']
input_tensor = input_tensor / scale + zero_point
input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
interpreter.set_tensor(input_detail['index'], input_tensor)
interpreter.invoke()
output_tensors = []
for output_detail in self.output_details:
output_tensor = interpreter.get_tensor(output_detail['index'])
if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
# Dequantize the output
scale, zero_point = output_detail['quantization']
output_tensor = output_tensor.astype(np.float32)
output_tensor = (output_tensor - zero_point) * scale
output_tensors.append(output_tensor)
if len(output_tensors) == 1:
return output_tensors[0]
return output_tensors
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
"""Returns a `LiteRunner` from file path to TFLite model."""
lite_runner = LiteRunner(tflite_filepath)
return lite_runner
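# Example (a sketch; the file path and input shape are assumptions): run the
# exported TFLite model on one random batch. For a single-input model the
# tensor can be passed directly instead of a list or dict.
#
#   lite_runner = get_lite_runner('/tmp/model_float.tflite')
#   batch = np.random.uniform(size=(1, 4)).astype(np.float32)
#   lite_output = lite_runner.run(batch)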
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
                                                                  tf.Tensor]],
                      input_details: List[Dict[str, Any]],
                      index: int) -> tf.Tensor:
  """Returns the tensor in `input_tensors` that maps to `input_details[index]`."""
  if isinstance(input_tensors, dict):
    # Gets the mapped input tensor by matching the tensor name.
    input_detail = input_details[index]
    for input_tensor_name, input_tensor in input_tensors.items():
      if input_tensor_name in input_detail['name']:
        return input_tensor
    raise ValueError('Input tensors don\'t contain a tensor that maps to the '
                     'input detail %s' % str(input_detail))
  else:
    return input_tensors[index]

View File

@ -0,0 +1,137 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='input_only_steps_per_epoch',
steps_per_epoch=1000,
batch_size=None,
train_data=None,
expected_steps_per_epoch=1000),
dict(
testcase_name='input_steps_per_epoch_and_batch_size',
steps_per_epoch=1000,
batch_size=32,
train_data=None,
expected_steps_per_epoch=1000),
dict(
testcase_name='input_steps_per_epoch_batch_size_and_train_data',
steps_per_epoch=1000,
batch_size=32,
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]]),
expected_steps_per_epoch=1000),
dict(
testcase_name='input_batch_size_and_train_data',
steps_per_epoch=None,
batch_size=2,
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]]),
expected_steps_per_epoch=2))
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
expected_steps_per_epoch):
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=steps_per_epoch,
batch_size=batch_size,
train_data=train_data)
self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch)
def test_get_steps_per_epoch_raise_value_error(self):
with self.assertRaises(ValueError):
model_util.get_steps_per_epoch(
steps_per_epoch=None, batch_size=16, train_data=None)
def test_warmup(self):
init_lr = 0.1
warmup_steps = 1000
num_decay_steps = 100
learning_rate_fn = tf.keras.experimental.CosineDecay(
initial_learning_rate=init_lr, decay_steps=num_decay_steps)
warmup_object = model_util.WarmUp(
initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=1000,
name='test')
self.assertEqual(
warmup_object.get_config(), {
'initial_learning_rate': init_lr,
'decay_schedule_fn': learning_rate_fn,
'warmup_steps': warmup_steps,
'name': 'test'
})
def test_export_tflite(self):
input_dim = 4
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.export_tflite(model, tflite_file)
self._test_tflite(model, tflite_file, input_dim)
@parameterized.named_parameters(
dict(
testcase_name='dynamic_quantize',
config=quantization.QuantizationConfig.for_dynamic(),
model_size=1288),
dict(
testcase_name='int8_quantize',
config=quantization.QuantizationConfig.for_int8(
representative_data=test_util.create_dataset(
data_size=10, input_shape=[16], num_classes=3)),
model_size=1832),
dict(
testcase_name='float16_quantize',
config=quantization.QuantizationConfig.for_float16(),
model_size=1468))
def test_export_tflite_quantized(self, config, model_size):
input_dim = 16
num_classes = 2
max_input_value = 5
model = test_util.build_model([input_dim], num_classes)
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
model_util.export_tflite(model, tflite_file, config)
self._test_tflite(
model, tflite_file, input_dim, max_input_value, atol=1e-00)
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
def _test_tflite(self,
keras_model: tf.keras.Model,
tflite_model_file: str,
input_dim: int,
max_input_value: int = 1000,
atol: float = 1e-04):
np.random.seed(0)
random_input = np.random.uniform(
low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32)
self.assertTrue(
test_util.is_same_output(
tflite_model_file, keras_model, random_input, atol=atol))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,213 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Libraries for post-training quantization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, List, Optional, Union
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
DEFAULT_QUANTIZATION_STEPS = 500
def _get_representative_dataset_generator(dataset: tf.data.Dataset,
num_steps: int) -> Callable[[], Any]:
"""Gets a representative dataset generator for post-training quantization.
The generator is to provide a small dataset to calibrate or estimate the
  range, i.e., (min, max) of all floating-point arrays in the model for
quantization. Usually, this is a small subset of a few hundred samples
randomly chosen, in no particular order, from the training or evaluation
dataset. See tf.lite.RepresentativeDataset for more details.
Args:
dataset: Input dataset for extracting representative sub dataset.
num_steps: The number of quantization steps which also reflects the size of
the representative dataset.
Returns:
A representative dataset generator.
"""
def representative_dataset_gen():
"""Generates representative dataset for quantization."""
for data, _ in dataset.take(num_steps):
yield [data]
return representative_dataset_gen
class QuantizationConfig(object):
"""Configuration for post-training quantization.
Refer to
https://www.tensorflow.org/lite/performance/post_training_quantization
for different post-training quantization options.
"""
def __init__(
self,
optimizations: Optional[Union[tf.lite.Optimize,
List[tf.lite.Optimize]]] = None,
representative_data: Optional[ds.Dataset] = None,
quantization_steps: Optional[int] = None,
inference_input_type: Optional[tf.dtypes.DType] = None,
inference_output_type: Optional[tf.dtypes.DType] = None,
supported_ops: Optional[Union[tf.lite.OpsSet,
List[tf.lite.OpsSet]]] = None,
supported_types: Optional[Union[tf.dtypes.DType,
List[tf.dtypes.DType]]] = None,
experimental_new_quantizer: bool = False,
):
"""Constructs QuantizationConfig.
Args:
optimizations: A list of optimizations to apply when converting the model.
If not set, use `[Optimize.DEFAULT]` by default.
representative_data: A representative ds.Dataset for post-training
quantization.
      quantization_steps: Number of post-training quantization calibration
        steps to run (defaults to DEFAULT_QUANTIZATION_STEPS).
inference_input_type: Target data type of real-number input arrays. Allows
        for a different type for input arrays. Defaults to None. If set, must
        be one of `{tf.float32, tf.uint8, tf.int8}`.
inference_output_type: Target data type of real-number output arrays.
Allows for a different type for output arrays. Defaults to None. If set,
must be `{tf.float32, tf.uint8, tf.int8}`.
      supported_ops: Set of OpsSet options supported by the device. Used to
        set converter.target_spec.supported_ops.
supported_types: List of types for constant values on the target device.
Supported values are types exported by lite.constants. Frequently, an
optimization choice is driven by the most compact (i.e. smallest) type
in this list (default [constants.FLOAT]).
experimental_new_quantizer: Whether to enable experimental new quantizer.
Raises:
ValueError: if inference_input_type or inference_output_type are set but
not in {tf.float32, tf.uint8, tf.int8}.
"""
if inference_input_type is not None and inference_input_type not in {
tf.float32, tf.uint8, tf.int8
}:
raise ValueError('Unsupported inference_input_type %s' %
inference_input_type)
if inference_output_type is not None and inference_output_type not in {
tf.float32, tf.uint8, tf.int8
}:
raise ValueError('Unsupported inference_output_type %s' %
inference_output_type)
if optimizations is None:
optimizations = [tf.lite.Optimize.DEFAULT]
if not isinstance(optimizations, list):
optimizations = [optimizations]
self.optimizations = optimizations
self.representative_data = representative_data
if self.representative_data is not None and quantization_steps is None:
quantization_steps = DEFAULT_QUANTIZATION_STEPS
self.quantization_steps = quantization_steps
self.inference_input_type = inference_input_type
self.inference_output_type = inference_output_type
if supported_ops is not None and not isinstance(supported_ops, list):
supported_ops = [supported_ops]
self.supported_ops = supported_ops
if supported_types is not None and not isinstance(supported_types, list):
supported_types = [supported_types]
self.supported_types = supported_types
self.experimental_new_quantizer = experimental_new_quantizer
@classmethod
def for_dynamic(cls) -> 'QuantizationConfig':
"""Creates configuration for dynamic range quantization."""
return QuantizationConfig()
@classmethod
def for_int8(
cls,
representative_data: ds.Dataset,
quantization_steps: int = DEFAULT_QUANTIZATION_STEPS,
inference_input_type: tf.dtypes.DType = tf.uint8,
inference_output_type: tf.dtypes.DType = tf.uint8,
supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
) -> 'QuantizationConfig':
"""Creates configuration for full integer quantization.
Args:
representative_data: Representative data used for post-training
quantization.
quantization_steps: Number of post-training quantization calibration steps
to run.
inference_input_type: Target data type of real-number input arrays.
inference_output_type: Target data type of real-number output arrays.
supported_ops: Set of `tf.lite.OpsSet` options, where each option
represents a set of operators supported by the target device.
Returns:
QuantizationConfig.
"""
return QuantizationConfig(
representative_data=representative_data,
quantization_steps=quantization_steps,
inference_input_type=inference_input_type,
inference_output_type=inference_output_type,
supported_ops=supported_ops)
@classmethod
def for_float16(cls) -> 'QuantizationConfig':
"""Creates configuration for float16 quantization."""
return QuantizationConfig(supported_types=[tf.float16])
def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter,
**kwargs: Any) -> tf.lite.TFLiteConverter:
"""Sets input TFLite converter with quantization configurations.
Args:
converter: input tf.lite.TFLiteConverter.
**kwargs: arguments used by ds.Dataset.gen_tf_dataset.
Returns:
tf.lite.TFLiteConverter with quantization configurations.
"""
converter.optimizations = self.optimizations
if self.representative_data is not None:
tf_ds = self.representative_data.gen_tf_dataset(
batch_size=1, is_training=False, **kwargs)
converter.representative_dataset = tf.lite.RepresentativeDataset(
_get_representative_dataset_generator(tf_ds, self.quantization_steps))
if self.inference_input_type:
converter.inference_input_type = self.inference_input_type
if self.inference_output_type:
converter.inference_output_type = self.inference_output_type
if self.supported_ops:
converter.target_spec.supported_ops = self.supported_ops
if self.supported_types:
converter.target_spec.supported_types = self.supported_types
if self.experimental_new_quantizer is not None:
converter.experimental_new_quantizer = self.experimental_new_quantizer
return converter
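# Example (a minimal sketch; the dataset, model path, and variable names are
# assumptions): build a full-integer quantization config from a representative
# ds.Dataset and apply it to a TFLite converter.
#
#   config = QuantizationConfig.for_int8(representative_data=my_dataset)
#   converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
#   converter = config.set_converter_with_quantization(converter)
#   tflite_bytes = converter.convert()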

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
def test_create_dynamic_quantization_config(self):
config = quantization.QuantizationConfig.for_dynamic()
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertIsNone(config.representative_data)
self.assertIsNone(config.inference_input_type)
self.assertIsNone(config.inference_output_type)
self.assertIsNone(config.supported_ops)
self.assertIsNone(config.supported_types)
self.assertFalse(config.experimental_new_quantizer)
def test_create_int8_quantization_config(self):
representative_data = test_util.create_dataset(
data_size=10, input_shape=[4], num_classes=3)
config = quantization.QuantizationConfig.for_int8(
representative_data=representative_data)
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertEqual(config.inference_input_type, tf.uint8)
self.assertEqual(config.inference_output_type, tf.uint8)
self.assertEqual(config.supported_ops,
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
self.assertFalse(config.experimental_new_quantizer)
def test_set_converter_with_quantization_from_int8_config(self):
representative_data = test_util.create_dataset(
data_size=10, input_shape=[4], num_classes=3)
config = quantization.QuantizationConfig.for_int8(
representative_data=representative_data)
model = test_util.build_model(input_shape=[4], num_classes=3)
saved_model_dir = self.get_temp_dir()
model.save(saved_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter = config.set_converter_with_quantization(converter=converter)
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertEqual(config.inference_input_type, tf.uint8)
self.assertEqual(config.inference_output_type, tf.uint8)
self.assertEqual(config.supported_ops,
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)
def test_create_float16_quantization_config(self):
config = quantization.QuantizationConfig.for_float16()
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
self.assertIsNone(config.representative_data)
self.assertIsNone(config.inference_input_type)
self.assertIsNone(config.inference_output_type)
self.assertIsNone(config.supported_ops)
self.assertEqual(config.supported_types, [tf.float16])
self.assertFalse(config.experimental_new_quantizer)
def test_set_converter_with_quantization_from_float16_config(self):
config = quantization.QuantizationConfig.for_float16()
model = test_util.build_model(input_shape=[4], num_classes=3)
saved_model_dir = self.get_temp_dir()
model.save(saved_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter = config.set_converter_with_quantization(converter=converter)
self.assertEqual(config.supported_types, [tf.float16])
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# The input and output are expected to be set to float32 by default.
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)
@parameterized.named_parameters(
dict(
testcase_name='invalid_inference_input_type',
inference_input_type=tf.uint8,
inference_output_type=tf.int64),
dict(
testcase_name='invalid_inference_output_type',
inference_input_type=tf.int64,
inference_output_type=tf.float32))
def test_create_quantization_config_failure(self, inference_input_type,
inference_output_type):
with self.assertRaises(ValueError):
_ = quantization.QuantizationConfig(
inference_input_type=inference_input_type,
inference_output_type=inference_output_type)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,76 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test utilities for model maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import List, Union
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import model_util
def create_dataset(data_size: int,
input_shape: List[int],
num_classes: int,
max_input_value: int = 1000) -> ds.Dataset:
"""Creates and returns a simple `Dataset` object for test."""
features = tf.random.uniform(
shape=[data_size] + input_shape,
minval=0,
maxval=max_input_value,
dtype=tf.float32)
labels = tf.random.uniform(
shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = ds.Dataset(tf_dataset, data_size)
return dataset
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
"""Builds a simple Keras model for test."""
inputs = tf.keras.layers.Input(shape=input_shape)
if len(input_shape) == 3: # Image inputs.
outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs)
elif len(input_shape) == 1: # Text inputs.
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
else:
raise ValueError("Model inputs should be 2D tensor or 4D tensor.")
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
def is_same_output(tflite_file: str,
keras_model: tf.keras.Model,
input_tensors: Union[List[tf.Tensor], tf.Tensor],
atol: float = 1e-04) -> bool:
"""Returns if the output of TFLite model and keras model are identical."""
# Gets output from lite model.
lite_runner = model_util.get_lite_runner(tflite_file)
lite_output = lite_runner.run(input_tensors)
# Gets output from keras model.
keras_output = keras_model.predict_on_batch(input_tensors)
return np.allclose(lite_output, keras_output, atol=atol)
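# Example (a sketch of how these helpers compose in a test; the temporary path
# is an assumption): build a tiny model, export it with
# model_util.export_tflite, and check that the TFLite and Keras outputs match.
#
#   model = build_model(input_shape=[4], num_classes=2)
#   model_util.export_tflite(model, '/tmp/test_model.tflite')
#   x = np.random.uniform(size=(1, 4)).astype(np.float32)
#   assert is_same_output('/tmp/test_model.tflite', model, x)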

View File

@ -0,0 +1,4 @@
absl-py
numpy
opencv-contrib-python
tensorflow

View File

@ -20,6 +20,16 @@ package(default_visibility = [
licenses(["notice"])
mediapipe_proto_library(
name = "gesture_embedder_graph_options_proto",
srcs = ["gesture_embedder_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
mediapipe_proto_library(
name = "hand_gesture_recognizer_graph_options_proto",
srcs = ["hand_gesture_recognizer_graph_options.proto"],

View File

@ -0,0 +1,30 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message GestureEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureEmbedderGraphOptions ext = 478825422;
}
  // Base options for configuring the gesture embedder subgraph, such as
  // specifying the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
}

View File

@ -0,0 +1,21 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
filegroup(
name = "resource_files",
srcs = glob(["res/**"]),
visibility = ["//mediapipe/tasks/examples/android:__subpackages__"],
)

View File

@ -0,0 +1,37 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.examples.objectdetector">
<uses-sdk
android:minSdkVersion="28"
android:targetSdkVersion="30" />
<!-- For loading images from gallery -->
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<!-- For logging solution events -->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="MediaPipe Tasks Object Detector"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/AppTheme"
android:exported="false">
<activity android:name=".MainActivity"
android:screenOrientation="portrait"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@ -0,0 +1,48 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
android_binary(
name = "objectdetector",
srcs = glob(["**/*.java"]),
assets = [
"//mediapipe/tasks/testdata/vision:test_models",
],
assets_dir = "",
custom_package = "com.google.mediapipe.tasks.examples.objectdetector",
manifest = "AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.tasks.examples.objectdetector",
},
multidex = "native",
resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector",
"//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout",
"//third_party:opencv",
"@maven//:androidx_activity_activity",
"@maven//:androidx_concurrent_concurrent_futures",
"@maven//:androidx_exifinterface_exifinterface",
"@maven//:androidx_fragment_fragment",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,236 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.examples.objectdetector;
import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.MediaMetadataRetriever;
import android.os.Bundle;
import android.provider.MediaStore;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.FrameLayout;
import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import androidx.exifinterface.media.ExifInterface;
// ContentResolver dependency
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.io.IOException;
import java.io.InputStream;
/** Main activity of MediaPipe Task Object Detector reference app. */
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity";
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
private ObjectDetector objectDetector;
private enum InputSource {
UNKNOWN,
IMAGE,
VIDEO,
CAMERA,
}
private InputSource inputSource = InputSource.UNKNOWN;
// Image mode demo component.
private ActivityResultLauncher<Intent> imageGetter;
// Video mode demo component.
private ActivityResultLauncher<Intent> videoGetter;
private ObjectDetectionResultImageView imageView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
setupImageModeDemo();
setupVideoModeDemo();
// TODO: Adds live camera demo.
}
/** Sets up the image mode demo. */
private void setupImageModeDemo() {
imageView = new ObjectDetectionResultImageView(this);
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
downscaleBitmap(
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData);
} catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e);
}
if (bitmap != null) {
Image image = new BitmapImageBuilder(bitmap).build();
ObjectDetectionResult detectionResult = objectDetector.detect(image);
imageView.setData(image, detectionResult);
runOnUiThread(() -> imageView.update());
}
}
}
});
Button loadImageButton = findViewById(R.id.button_load_picture);
loadImageButton.setOnClickListener(
v -> {
if (inputSource != InputSource.IMAGE) {
createObjectDetector(RunningMode.IMAGE);
this.inputSource = InputSource.IMAGE;
updateLayout();
}
// Reads images from gallery.
Intent pickImageIntent = new Intent(Intent.ACTION_PICK);
pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*");
imageGetter.launch(pickImageIntent);
});
}
/** Sets up the video mode demo. */
private void setupVideoModeDemo() {
imageView = new ObjectDetectionResultImageView(this);
// The Intent to access gallery and read a video file.
videoGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever();
metaRetriever.setDataSource(this, resultIntent.getData());
long duration =
Long.parseLong(
metaRetriever.extractMetadata(
MediaMetadataRetriever.METADATA_KEY_DURATION));
int numFrames =
Integer.parseInt(
metaRetriever.extractMetadata(
MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT));
long frameIntervalMs = duration / numFrames;
for (int i = 0; i < numFrames; ++i) {
Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
ObjectDetectionResult detectionResult =
objectDetector.detectForVideo(image, frameIntervalMs * i);
                      // Currently the detection result is only annotated on the first video
                      // frame and displayed to verify correctness.
// TODO: Annotates the detection result on every frame, save the
// annotated frames as a video file, and play back the video afterwards.
if (i == 0) {
imageView.setData(image, detectionResult);
runOnUiThread(() -> imageView.update());
}
}
}
}
});
Button loadVideoButton = findViewById(R.id.button_load_video);
loadVideoButton.setOnClickListener(
v -> {
createObjectDetector(RunningMode.VIDEO);
updateLayout();
this.inputSource = InputSource.VIDEO;
// Reads a video from gallery.
Intent pickVideoIntent = new Intent(Intent.ACTION_PICK);
pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*");
videoGetter.launch(pickVideoIntent);
});
}
private void createObjectDetector(RunningMode mode) {
if (objectDetector != null) {
objectDetector.close();
}
// Initializes a new MediaPipe ObjectDetector instance
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setScoreThreshold(0.5f)
.setMaxResults(5)
.setRunningMode(mode)
.build();
objectDetector = ObjectDetector.createFromOptions(this, options);
}
private void updateLayout() {
// Updates the preview layout.
FrameLayout frameLayout = findViewById(R.id.preview_display_layout);
frameLayout.removeAllViewsInLayout();
imageView.setImageDrawable(null);
frameLayout.addView(imageView);
imageView.setVisibility(View.VISIBLE);
}
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
}

View File

@ -0,0 +1,77 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.examples.objectdetector;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import androidx.appcompat.widget.AppCompatImageView;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
/** An ImageView implementation for displaying {@link ObjectDetectionResult}. */
public class ObjectDetectionResultImageView extends AppCompatImageView {
private static final String TAG = "ObjectDetectionResultImageView";
private static final int BBOX_COLOR = Color.GREEN;
private static final int BBOX_THICKNESS = 5; // Pixels
private Bitmap latest;
public ObjectDetectionResultImageView(Context context) {
super(context);
setScaleType(AppCompatImageView.ScaleType.FIT_CENTER);
}
/**
* Sets an {@link Image} and an {@link ObjectDetectionResult} to render.
*
* @param image an {@link Image} object for annotation.
* @param result an {@link ObjectDetectionResult} object that contains the detection result.
*/
public void setData(Image image, ObjectDetectionResult result) {
if (image == null || result == null) {
return;
}
latest = BitmapExtractor.extract(image);
Canvas canvas = new Canvas(latest);
canvas.drawBitmap(latest, new Matrix(), null);
for (int i = 0; i < result.detections().size(); ++i) {
drawDetectionOnCanvas(result.detections().get(i), canvas);
}
}
/** Updates the image view with the latest {@link ObjectDetectionResult}. */
public void update() {
postInvalidate();
if (latest != null) {
setImageBitmap(latest);
}
}
private void drawDetectionOnCanvas(Detection detection, Canvas canvas) {
// TODO: Draws the category and the score per bounding box.
// Draws bounding box.
Paint bboxPaint = new Paint();
bboxPaint.setColor(BBOX_COLOR);
bboxPaint.setStyle(Paint.Style.STROKE);
bboxPaint.setStrokeWidth(BBOX_THICKNESS);
canvas.drawRect(detection.boundingBox(), bboxPaint);
}
}

View File

@ -0,0 +1,34 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportHeight="108"
android:viewportWidth="108">
<path
android:fillType="evenOdd"
android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
android:strokeColor="#00000000"
android:strokeWidth="1">
<aapt:attr name="android:fillColor">
<gradient
android:endX="78.5885"
android:endY="90.9159"
android:startX="48.7653"
android:startY="61.0927"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
android:strokeColor="#00000000"
android:strokeWidth="1" />
</vector>

View File

@ -0,0 +1,74 @@
<?xml version="1.0" encoding="utf-8"?>
<vector
android:height="108dp"
android:width="108dp"
android:viewportHeight="108"
android:viewportWidth="108"
xmlns:android="http://schemas.android.com/apk/res/android">
<path android:fillColor="#26A69A"
android:pathData="M0,0h108v108h-108z"/>
<path android:fillColor="#00000000" android:pathData="M9,0L9,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,0L19,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,0L29,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,0L39,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,0L49,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,0L59,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,0L69,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,0L79,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M89,0L89,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M99,0L99,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,9L108,9"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,19L108,19"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,29L108,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,39L108,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,49L108,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,59L108,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,69L108,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,79L108,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,89L108,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,99L108,99"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,29L89,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,39L89,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,49L89,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,59L89,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,69L89,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,79L89,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,19L29,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,19L39,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,19L49,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,19L59,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,19L69,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,19L79,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
</vector>

View File

@ -0,0 +1,40 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical">
<LinearLayout
android:id="@+id/buttons"
android:layout_width="match_parent"
android:layout_height="wrap_content"
style="?android:attr/buttonBarStyle" android:gravity="center"
android:orientation="horizontal">
<Button
android:id="@+id/button_load_picture"
android:layout_width="wrap_content"
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
android:text="@string/load_picture" />
<Button
android:id="@+id/button_load_video"
android:layout_width="wrap_content"
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
android:text="@string/load_video" />
<Button
android:id="@+id/button_start_camera"
android:layout_width="wrap_content"
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
android:text="@string/start_camera" />
</LinearLayout>
<FrameLayout
android:id="@+id/preview_display_layout"
android:layout_width="match_parent"
android:layout_height="match_parent">
<TextView
android:id="@+id/no_view"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center"
android:text="@string/instruction" />
</FrameLayout>
</LinearLayout>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

Binary file not shown. (added, 1.3 KiB)

Binary file not shown. (added, 2.2 KiB)

Binary file not shown. (added, 3.2 KiB)

Binary file not shown. (added, 959 B)

Binary file not shown. (added, 900 B)

Binary file not shown. (added, 1.9 KiB)

Binary file not shown. (added, 1.9 KiB)

Binary file not shown. (added, 1.8 KiB)

Binary file not shown. (added, 4.5 KiB)

Binary file not shown. (added, 3.5 KiB)

Binary file not shown. (added, 5.5 KiB)

Binary file not shown. (added, 7.6 KiB)

Binary file not shown. (added, 4.9 KiB)

Binary file not shown. (added, 8.1 KiB)

Binary file not shown. (added, 11 KiB)

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@ -0,0 +1,6 @@
<resources>
<string name="load_picture" translatable="false">Load Picture</string>
<string name="load_video" translatable="false">Load Video</string>
<string name="start_camera" translatable="false">Start Camera</string>
<string name="instruction" translatable="false">Please press any button above to start</string>
</resources>

View File

@ -0,0 +1,11 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@ -550,6 +550,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1.tflite?generation=1661875947436302"],
)
http_file(
name = "com_google_mediapipe_test_jpg",
sha256 = "798a12a466933842528d8438f553320eebe5137f02650f12dd68706a2f94fb4f",
urls = ["https://storage.googleapis.com/mediapipe-assets/test.jpg?generation=1664672140191116"],
)
http_file(
name = "com_google_mediapipe_test_model_add_op_tflite",
sha256 = "298300ca8a9193b80ada1dca39d36f20bffeebde09e85385049b3bfe7be2272f",