Merge branch 'image-classification-python-impl' of https://github.com/kinaryml/mediapipe into image-classification-python-impl

This commit is contained in:
kinaryml 2022-10-05 04:44:37 -07:00
commit ef31e3a482
89 changed files with 3668 additions and 357 deletions

View File

@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==2.2.0
RUN pip3 install tensorflow
RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python

View File

@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```
## Graph Options
It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.
In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:
```
node {
calculator: "FlowLimiterCalculator"
input_stream: "image"
output_stream: "throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
max_in_flight: 1
}
}
}
node {
calculator: "FaceDetectionSubgraph"
input_stream: "IMAGE:throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {
tensor_width: 192
tensor_height: 192
}
}
}
```
In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:
```
graph_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}
node: {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:multi_backend_image"
node_options: {
[type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
keep_aspect_ratio: true
border_mode: BORDER_ZERO
}
}
option_value: "output_tensor_width:options/tensor_width"
option_value: "output_tensor_height:options/tensor_height"
}
node {
calculator: "InferenceCalculator"
node_options: {
[type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
}
option_value: "delegate:options/delegate"
option_value: "model_path:options/model_path"
}
```
In this example, the `FaceDetectionSubgraph` accepts graph option protobuf
`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field
values in the calculator options `ImageToTensorCalculatorOptions` and some field
values in the subgraph options `InferenceCalculatorOptions`. The field values
are defined using the `option_value:` syntax.
In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.
The syntax of `option_value:` is similar to the syntax of `input_stream:`. The
syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option
field and the RHS identifies a graph option field. More specifically, the LHS
and RHS each consists of a series of protobuf field names identifying nested
protobuf messages and fields separated by '/'. This is known as the "ProtoPath"
syntax. Nested messages that are referenced in the LHS or RHS must already be
defined in the enclosing protobuf in order to be traversed using
`option_value:`.
## Cycles
<!-- TODO: add discussion of PreviousLoopbackCalculator -->

View File

@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText;
import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;

View File

@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
glFinish();
gl_context_->Run([]() { glFinish(); });
} else if (valid_ & kValidAHardwareBuffer) {
CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
"completion function to be set";

View File

@ -89,10 +89,6 @@ def mediapipe_aar(
calculators = calculators,
)
_mediapipe_proto(
name = name + "_proto",
)
native.genrule(
name = name + "_aar_manifest_generator",
outs = ["AndroidManifest.xml"],
@ -115,19 +111,10 @@ EOF
"//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
] +
] + mediapipe_java_proto_srcs() +
select({
"//conditions:default": [],
"enable_stats_logging": [
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
],
"enable_stats_logging": mediapipe_logging_java_proto_srcs(),
}),
manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@ -179,93 +166,6 @@ EOF
_aar_with_jni(name, name + "_android_lib")
def _mediapipe_proto(name):
"""Generates MediaPipe java proto libraries.
Args:
name: the name of the target.
"""
_proto_java_src_generator(
name = "mediapipe_log_extension_proto",
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "mediapipe_logging_enums_proto",
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "calculator_proto",
proto_src = "mediapipe/framework/calculator.proto",
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
srcs = ["//mediapipe/framework:protos_src"],
)
_proto_java_src_generator(
name = "landmark_proto",
proto_src = "mediapipe/framework/formats/landmark.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
srcs = ["//mediapipe/framework/formats:protos_src"],
)
_proto_java_src_generator(
name = "rasterization_proto",
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
)
_proto_java_src_generator(
name = "location_data_proto",
proto_src = "mediapipe/framework/formats/location_data.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "detection_proto",
proto_src = "mediapipe/framework/formats/detection.proto",
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "classification_proto",
proto_src = "mediapipe/framework/formats/classification.proto",
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
],
)
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
native.genrule(
name = name + "_proto_java_src_generator",
srcs = srcs + [
"@com_google_protobuf//:lite_well_known_protos",
],
outs = [java_lite_out],
cmd = "$(location @com_google_protobuf//:protoc) " +
"--proto_path=. --proto_path=$(GENDIR) " +
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
tools = [
"@com_google_protobuf//:protoc",
],
)
def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
"""Generates MediaPipe jni library.
@ -345,3 +245,93 @@ cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
)
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
"""Extracts the generated MediaPipe java proto source code from the target.
Args:
target: The java proto lite target to be built and extracted.
src_out: The output java proto src code path.
name: The optional bazel target name.
Returns:
The output java proto src code path.
"""
if not name:
name = target.split(":")[-1] + "_proto_java_src_extractor"
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
native.genrule(
name = name + "_proto_java_src_extractor",
srcs = [target],
outs = [src_out],
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
src_out + " $$(dirname $(location " + src_out + "))",
)
return src_out
def mediapipe_java_proto_srcs(name = ""):
"""Extracts the generated MediaPipe framework java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extrated MediaPipe java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:detection_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
))
return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""):
"""Extracts the generated logging-related MediaPipe java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extrated MediaPipe logging-related java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
))
return proto_src_list

View File

@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "data_util",
srcs = ["data_util.py"],
srcs_version = "PY3",
)
py_test(
name = "data_util_test",
srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":data_util"],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
srcs_version = "PY3",
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dataset",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)
py_library(
name = "classification_dataset",
srcs = ["classification_dataset.py"],
srcs_version = "PY3",
deps = [":dataset"],
)
py_test(
name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":classification_dataset"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,47 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
  """DataLoader for classification models.

  Extends the generic Dataset with an index-to-label mapping so the number of
  classes can be derived and propagated through train/test splits.
  """

  def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
    super().__init__(dataset, size)
    # Mapping from class index to label; its length defines num_classes.
    self.index_to_label = index_to_label

  @property
  def num_classes(self: ds._DatasetT) -> int:
    """Returns the number of classification classes."""
    return len(self.index_to_label)

  def split(self: ds._DatasetT,
            fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
    """Splits the dataset into two sub-datasets with the given fraction.

    Primarily used for splitting the data set into training and testing sets.

    Args:
      fraction: float, the fraction of the original data that goes into the
        first returned sub-dataset.

    Returns:
      The two split sub-datasets.
    """
    # Forward index_to_label so both halves keep the label mapping.
    return self._split(fraction, self.index_to_label)

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
class ClassificationDataLoaderTest(tf.test.TestCase):
  """Verifies the `_split` contract of ClassificationDataset subclasses."""

  def test_split(self):

    class MagicClassificationDataLoader(
        classification_dataset.ClassificationDataset):
      """Subclass with an extra constructor arg forwarded through `split`."""

      def __init__(self, dataset, size, index_to_label, value):
        super(MagicClassificationDataLoader,
              self).__init__(dataset, size, index_to_label)
        self.value = value

      def split(self, fraction):
        return self._split(fraction, self.index_to_label, self.value)

    # Some dummy inputs.
    magic_value = 42
    num_classes = 2
    index_to_label = (False, True)

    # Create a data loader from sample data.
    sample_ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                    [1, 0]])
    data = MagicClassificationDataLoader(sample_ds, len(sample_ds),
                                         index_to_label, magic_value)

    # Train/Test data split.
    fraction = .25
    train_data, test_data = data.split(fraction)

    # `split` should return instances of the child DataLoader.
    self.assertIsInstance(train_data, MagicClassificationDataLoader)
    self.assertIsInstance(test_data, MagicClassificationDataLoader)

    # Make sure the number of entries is right.
    self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
    self.assertLen(train_data, fraction * len(sample_ds))
    self.assertLen(test_data, len(sample_ds) - len(train_data))

    # Make sure the attributes propagated correctly.
    self.assertEqual(train_data.num_classes, num_classes)
    self.assertEqual(test_data.index_to_label, index_to_label)
    self.assertEqual(train_data.value, magic_value)
    self.assertEqual(test_data.value, magic_value)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utility library."""
import cv2
import numpy as np
import tensorflow as tf
def load_image(path: str) -> np.ndarray:
  """Loads an image as an RGB numpy array.

  Args:
    path: input image file absolute path.

  Returns:
    An RGB image in numpy.ndarray.
  """
  tf.compat.v1.logging.info('Loading RGB image %s', path)
  # TODO Replace the OpenCV image load and conversion library by
  # MediaPipe image utility library once it is ready.
  # OpenCV decodes to BGR channel order, so convert before returning.
  bgr_image = cv2.imread(path)
  return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

View File

@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import flags
import tensorflow as tf
from mediapipe.model_maker.python.core.data import data_util
_WORKSPACE = "mediapipe"
_TEST_DATA_DIR = os.path.join(
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
FLAGS = flags.FLAGS
class DataUtilTest(tf.test.TestCase):
  """Checks that data_util loads test images with the expected shape."""

  def test_load_rgb_image(self):
    jpg_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
    loaded = data_util.load_image(jpg_path)
    # test.jpg is a 5184x3456 RGB image; assumes 3 channels after conversion.
    self.assertEqual(loaded.shape, (5184, 3456, 3))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common dataset for model training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from typing import Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
class Dataset(object):
  """A generic dataset class for loading model training and evaluation dataset.

  For each ML task, such as image classification, text classification etc., a
  subclass can be derived from this class to provide task-specific data loading
  utilities.
  """

  def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
    """Initializes Dataset class.

    To build a dataset from raw data, consider using the task specific
    utilities, e.g. from_folder().

    Args:
      tf_dataset: A tf.data.Dataset object that contains a potentially large
        set of elements, where each element is a pair of (input_data, target).
        The `input_data` means the raw input data, like an image, a text etc.,
        while the `target` means the ground truth of the raw input data, e.g.
        the classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = tf_dataset
    self._size = size

  @property
  def size(self) -> Optional[int]:
    """Returns the size of the dataset.

    Note that this function may return None because the exact size of the
    dataset isn't a necessary parameter to create an instance of this class,
    and tf.data.Dataset doesn't support a function to get the length directly
    since it's lazy-loaded and may be infinite.

    In most cases, however, when an instance of this class is created by helper
    functions like 'from_folder', the size of the dataset will be preprocessed,
    and this function can return an int representing the size of the dataset.
    """
    return self._size

  def gen_tf_dataset(self,
                     batch_size: int = 1,
                     is_training: bool = False,
                     shuffle: bool = False,
                     preprocess: Optional[Callable[..., bool]] = None,
                     drop_remainder: bool = False) -> tf.data.Dataset:
    """Generates a batched tf.data.Dataset for training/evaluation.

    Args:
      batch_size: An integer, the returned dataset will be batched by this
        size.
      is_training: A boolean, when True, the returned dataset will be
        optionally shuffled and repeated as an endless dataset.
      shuffle: A boolean, when True, the returned dataset will be shuffled to
        create randomness during model training.
      preprocess: A function taking three arguments in order, feature, label
        and boolean is_training.
      drop_remainder: boolean, whether the final batch drops remainder.

    Returns:
      A TF dataset ready to be consumed by Keras model.
    """
    ds = self._dataset
    if preprocess:
      # Bind is_training so the map function only receives (feature, label).
      mapper = functools.partial(preprocess, is_training=is_training)
      ds = ds.map(mapper, num_parallel_calls=tf.data.AUTOTUNE)

    if is_training and shuffle:
      # Shuffle size should be bigger than the batch_size. Otherwise it's only
      # shuffling within the batch, which equals to not having shuffle.
      buffer_size = 3 * batch_size
      # But since we are doing shuffle before repeat, it doesn't make sense to
      # shuffle more than total available entries.
      # TODO: Investigate if shuffling before / after repeat
      # dataset can get a better performance?
      # Shuffle after repeat will give a more randomized dataset and mix the
      # epoch boundary: https://www.tensorflow.org/guide/data
      if self._size:
        buffer_size = min(self._size, buffer_size)
      ds = ds.shuffle(buffer_size=buffer_size)

    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    # TODO: Consider converting dataset to distributed dataset
    # here.
    return ds

  def __len__(self):
    """Returns the number of elements of the dataset."""
    # Fall back to tf.data's own length when no explicit size was recorded
    # (raises for infinite/unknown-cardinality datasets, as len(tf_dataset)
    # would).
    return self._size if self._size is not None else len(self._dataset)

  def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
    """Splits the dataset into two sub-datasets with the given fraction.

    Primarily used for splitting the data set into training and testing sets.

    Args:
      fraction: A float value defines the fraction of the first returned
        sub-dataset in the original data.

    Returns:
      The two split sub-datasets.
    """
    return self._split(fraction)

  def _split(self: _DatasetT, fraction: float,
             *args) -> Tuple[_DatasetT, _DatasetT]:
    """Implementation for `split` method and returns sub-class instances.

    Child DataLoader classes, if they require additional constructor
    arguments, should implement their own `split` method by calling `_split`
    with all arguments to the constructor.

    Args:
      fraction: A float value defines the fraction of the first returned
        sub-dataset in the original data.
      *args: additional arguments passed to the sub-class constructor.

    Returns:
      The two split sub-datasets.
    """
    assert 0 < fraction < 1

    train_size = int(self._size * fraction)
    test_size = self._size - train_size
    # take/skip are lazy views over the same underlying tf.data.Dataset.
    trainset = self.__class__(self._dataset.take(train_size), train_size,
                              *args)
    testset = self.__class__(self._dataset.skip(train_size), test_size, *args)
    return trainset, testset

View File

@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import test_util
class DatasetTest(tf.test.TestCase):
  """Tests for the model_maker Dataset wrapper."""

  def test_split(self):
    # Two "train-like" rows ([i, 1]) followed by two "test-like" rows
    # ([i, 0]) make a 50/50 split easy to verify element by element.
    raw_dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                      [1, 0]])
    data = ds.Dataset(raw_dataset, 4)
    train_data, test_data = data.split(0.5)

    self.assertLen(train_data, 2)
    self.assertIsInstance(train_data, ds.Dataset)
    self.assertIsInstance(test_data, ds.Dataset)
    for index, elem in enumerate(train_data.gen_tf_dataset()):
      self.assertTrue((elem.numpy() == np.array([index, 1])).all())

    self.assertLen(test_data, 2)
    for index, elem in enumerate(test_data.gen_tf_dataset()):
      self.assertTrue((elem.numpy() == np.array([index, 0])).all())

  def test_len(self):
    size = 4
    raw_dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                      [1, 0]])
    self.assertLen(ds.Dataset(raw_dataset, size), size)

  def test_gen_tf_dataset(self):
    input_dim = 8
    data = test_util.create_dataset(
        data_size=2, input_shape=[input_dim], num_classes=2)

    # Default call: batch size 1, one element per example.
    dataset = data.gen_tf_dataset()
    self.assertLen(dataset, 2)
    for feature, label in dataset:
      self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
      self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())

    # Batching by 2 halves the number of yielded elements.
    dataset2 = data.gen_tf_dataset(batch_size=2)
    self.assertLen(dataset2, 1)
    for feature, label in dataset2:
      self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
      self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())

    # Training mode with shuffling keeps the same cardinality and shapes.
    dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
    self.assertEqual(dataset3.cardinality(), 1)
    for feature, label in dataset3.take(10):
      self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
      self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build targets for the test image asset used by the
# //mediapipe/model_maker/python/core/data tests.

load(
    "//mediapipe/framework/tool:mediapipe_files.bzl",
    "mediapipe_files",
)

package(
    default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

# Declares test.jpg as an externally provided asset (fetched by the
# mediapipe_files macro; see mediapipe_files.bzl for the source location).
mediapipe_files(srcs = ["test.jpg"])

# Exposes the test image to dependent test targets.
filegroup(
    name = "testdata",
    srcs = ["test.jpg"],
)

View File

@ -0,0 +1,100 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Build targets for the model_maker core utility libraries (model/TFLite
# helpers, image preprocessing, loss functions, quantization) and their tests.

licenses(["notice"])

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

# Shared helpers for building models/datasets in tests.
py_library(
    name = "test_util",
    testonly = 1,
    srcs = ["test_util.py"],
    srcs_version = "PY3",
    deps = [
        ":model_util",
        "//mediapipe/model_maker/python/core/data:dataset",
    ],
)

# ImageNet-style preprocessing (crop/flip/normalize).
py_library(
    name = "image_preprocessing",
    srcs = ["image_preprocessing.py"],
    srcs_version = "PY3",
)

py_test(
    name = "image_preprocessing_test",
    srcs = ["image_preprocessing_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":image_preprocessing"],
)

# Keras/TFLite model utilities (export, warmup schedule, lite runner).
py_library(
    name = "model_util",
    srcs = ["model_util.py"],
    srcs_version = "PY3",
    deps = [
        ":quantization",
        "//mediapipe/model_maker/python/core/data:dataset",
    ],
)

py_test(
    name = "model_util_test",
    srcs = ["model_util_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":model_util",
        ":quantization",
        ":test_util",
    ],
)

# Custom loss functions (e.g. focal loss).
py_library(
    name = "loss_functions",
    srcs = ["loss_functions.py"],
    srcs_version = "PY3",
)

py_test(
    name = "loss_functions_test",
    srcs = ["loss_functions_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":loss_functions"],
)

# Post-training quantization configuration for TFLite export.
py_library(
    name = "quantization",
    srcs = ["quantization.py"],
    srcs_version = "PY3",
    deps = ["//mediapipe/model_maker/python/core/data:dataset"],
)

py_test(
    name = "quantization_test",
    srcs = ["quantization_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":quantization",
        ":test_util",
    ],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,228 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
IMAGE_SIZE = 224
CROP_PADDING = 32
class Preprocessor(object):
  """Preprocessor for image classification."""

  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    """Initializes the preprocessor.

    Args:
      input_shape: target (height, width) for the non-augmented resize path;
        the augmented path uses only input_shape[0], so a square shape is
        assumed there — TODO confirm with callers.
      num_classes: number of label classes, used for one-hot encoding.
      mean_rgb: per-channel mean subtracted from the image.
      stddev_rgb: per-channel stddev the image is divided by.
      use_augmentation: whether to apply random crop/flip augmentation for
        training examples.
    """
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation

  def __call__(self, image, label, is_training=True):
    """Preprocesses one (image, label) pair; augments only when configured."""
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)

  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      # Random distorted crop + random flip.
      image = preprocess_for_train(image, image_size)
    else:
      # Deterministic padded center crop.
      image = preprocess_for_eval(image, image_size)

    # Normalize: subtract per-channel mean, divide by per-channel stddev.
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    label = tf.one_hot(label, depth=self.num_classes)
    return image, label

  # TODO: Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)
    # Normalize first, then resize (no cropping) to the target input shape.
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
def _distorted_bounding_box_crop(image,
                                 bbox,
                                 min_object_covered=0.1,
                                 aspect_ratio_range=(0.75, 1.33),
                                 area_range=(0.05, 1.0),
                                 max_attempts=100):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
      shape [height, width, channels].
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where
      each coordinate is [0, 1) and the coordinates are arranged as `[ymin,
      xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
      of the image must contain at least this fraction of any bounding box
      supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image must
      contain a fraction of the supplied image within in this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
      region of the image of the specified constraints. After `max_attempts`
      failures, return the entire image.

  Returns:
    A cropped image `Tensor`
  """
  with tf.name_scope('distorted_bounding_box_crop'):
    shape = tf.shape(image)
    # Draw a random crop window satisfying the aspect-ratio/area constraints;
    # falls back to the whole image when no window is found in `max_attempts`
    # tries (use_image_if_no_bounding_boxes=True).
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    image = tf.image.crop_to_bounding_box(image, offset_y, offset_x,
                                          target_height, target_width)

    return image
def _at_least_x_are_equal(a, b, x):
  """Returns a boolean tensor: at least `x` elements of `a` and `b` match."""
  matches = tf.cast(tf.equal(a, b), tf.int32)
  return tf.greater_equal(tf.reduce_sum(matches), x)
def _resize_image(image, image_size, method=None):
  """Resizes `image` to a square of side `image_size` (bicubic by default)."""
  if method is None:
    # No method supplied: keep the historical bicubic behavior.
    tf.compat.v1.logging.info('Use default resize_bicubic.')
    return tf.compat.v1.image.resize_bicubic([image],
                                             [image_size, image_size])[0]
  tf.compat.v1.logging.info('Use customized resize method {}'.format(method))
  return tf.compat.v1.image.resize([image], [image_size, image_size],
                                   method)[0]
def _decode_and_random_crop(original_image, image_size, resize_method=None):
  """Makes a random crop of image_size."""
  # Whole-image bbox: let the sampler pick any window inside the image.
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = _distorted_bounding_box_crop(
      original_image,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10)
  original_shape = tf.shape(original_image)
  # If the sampled crop has (at least 3 of 3) the same dims as the original,
  # the random crop effectively failed; fall back to a center crop.
  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)

  image = tf.cond(bad,
                  lambda: _decode_and_center_crop(original_image, image_size),
                  lambda: _resize_image(image, image_size, resize_method))

  return image
def _decode_and_center_crop(image, image_size, resize_method=None):
  """Crops to center of image with padding then scales image_size."""
  shape = tf.shape(image)
  image_height = shape[0]
  image_width = shape[1]

  # Side of the centered square crop: scaled so that resizing the crop to
  # `image_size` leaves an effective margin of CROP_PADDING pixels.
  padded_center_crop_size = tf.cast(
      ((image_size / (image_size + CROP_PADDING)) *
       tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)

  # Center the crop window (the +1 rounds the offset up on odd remainders).
  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                        padded_center_crop_size,
                                        padded_center_crop_size)
  image = _resize_image(image, image_size, resize_method)
  return image
def _flip(image):
  """Randomly flips the image horizontally (50/50)."""
  return tf.image.random_flip_left_right(image)
def preprocess_for_train(
    image: tf.Tensor,
    image_size: int = IMAGE_SIZE,
    resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
  """Preprocesses the given image for training.

  Applies a random distorted crop followed by a random horizontal flip.
  (The original docstring said "for evaluation" — a copy-paste error.)

  Args:
    image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
      shape [height, width, channels].
    image_size: image size.
    resize_method: resize method. If none, use bicubic.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_random_crop(image, image_size, resize_method)
  image = _flip(image)
  # Pin the static shape lost by the dynamic crop ops.
  image = tf.reshape(image, [image_size, image_size, 3])

  image = tf.image.convert_image_dtype(image, dtype=tf.float32)

  return image
def preprocess_for_eval(
    image: tf.Tensor,
    image_size: int = IMAGE_SIZE,
    resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
  """Preprocesses the given image for evaluation.

  Deterministic: a padded center crop with no random flip.

  Args:
    image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
      shape [height, width, channels].
    image_size: image size.
    resize_method: if None, use bicubic.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_center_crop(image, image_size, resize_method)
  # Pin the static shape lost by the dynamic crop ops.
  image = tf.reshape(image, [image_size, image_size, 3])
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  return image

View File

@ -0,0 +1,85 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import image_preprocessing
def _get_preprocessed_image(preprocessor, is_training=False):
  """Runs `preprocessor` on a fixed 24x24x3 ramp image; returns the ndarray.

  Uses TF1 placeholders and a Session (eager execution is disabled in the
  __main__ guard below), feeding pixel values 0..1727 and label 0.
  """
  image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3])
  label_placeholder = tf.compat.v1.placeholder(tf.int32, [1])
  image_tensor, _ = preprocessor(image_placeholder, label_placeholder,
                                 is_training)

  with tf.compat.v1.Session() as sess:
    input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
    image = sess.run(
        image_tensor,
        feed_dict={
            image_placeholder: input_image,
            label_placeholder: [0]
        })
    return image
class PreprocessorTest(tf.test.TestCase):
  """Tests for image_preprocessing.Preprocessor (TF1 graph mode)."""

  def test_preprocess_without_augmentation(self):
    preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
                                                    num_classes=2,
                                                    mean_rgb=[0.0],
                                                    stddev_rgb=[255.0],
                                                    use_augmentation=False)
    # Golden output for the deterministic normalize-then-resize path
    # (pixel ramp scaled by 1/255, resized to 2x2).
    actual_image = np.array([[[0., 0.00392157, 0.00784314],
                              [0.14117648, 0.14509805, 0.14901961]],
                             [[0.37647063, 0.3803922, 0.38431376],
                              [0.5176471, 0.52156866, 0.5254902]]])

    image = _get_preprocessed_image(preprocessor)
    self.assertTrue(np.allclose(image, actual_image, atol=1e-05))

  def test_preprocess_with_augmentation(self):
    # NOTE: mutates the module-level constant so the padded center crop is
    # meaningful on a tiny 2x2 target; this leaks into later tests.
    image_preprocessing.CROP_PADDING = 1
    preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
                                                    num_classes=2,
                                                    mean_rgb=[0.0],
                                                    stddev_rgb=[255.0],
                                                    use_augmentation=True)
    # Tests validation image (deterministic center-crop path).
    actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216],
                                   [0.26666668, 0.27058825, 0.27450982]],
                                  [[0.42352945, 0.427451, 0.43137258],
                                   [0.5176471, 0.52156866, 0.5254902]]])

    image = _get_preprocessed_image(preprocessor, is_training=False)
    self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05))

    # Tests training image: two random-augmented draws should differ but
    # keep the requested output shape.
    image1 = _get_preprocessed_image(preprocessor, is_training=True)
    image2 = _get_preprocessed_image(preprocessor, is_training=True)
    self.assertFalse(np.allclose(image1, image2, atol=1e-05))
    self.assertEqual(image1.shape, (2, 2, 3))
    self.assertEqual(image2.shape, (2, 2, 3))


if __name__ == '__main__':
  # The tests above build TF1 placeholders/Sessions.
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

View File

@ -0,0 +1,105 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss function utility library."""
from typing import Optional, Sequence
import tensorflow as tf
class FocalLoss(tf.keras.losses.Loss):
  """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).

  This class computes the focal loss between labels and prediction. Focal loss
  is a weighted loss function that modulates the standard cross-entropy loss
  based on how well the neural network performs on a specific example of a
  class. The labels should be provided in a `one_hot` vector representation.
  There should be `#classes` floating point values per prediction.
  The loss is reduced across all samples using 'sum_over_batch_size' reduction
  (see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction).

  Example usage:
  >>> y_true = [[0, 1, 0], [0, 0, 1]]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> gamma = 2
  >>> focal_loss = FocalLoss(gamma)
  >>> focal_loss(y_true, y_pred).numpy()
  0.9326

  >>> # Calling with 'sample_weight'.
  >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
  0.6528

  Usage with the `compile()` API:
  ```python
  model.compile(optimizer='sgd', loss=FocalLoss(gamma))
  ```
  """

  def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
    """Constructor.

    Args:
      gamma: Focal loss gamma, as described in class docs.
      class_weight: A weight to apply to the loss, one for each class. The
        weight is applied for each input where the ground truth label matches.
    """
    # NOTE(review): this calls the *grandparent* initializer, deliberately(?)
    # skipping tf.keras.losses.Loss.__init__; `self.reduction` is assigned
    # manually below instead. Confirm intent before changing to
    # `super().__init__()`.
    super(tf.keras.losses.Loss, self).__init__()
    # Used for clipping min/max values of probability values in y_pred to avoid
    # NaNs and Infs in computation.
    self._epsilon = 1e-7
    # This is a tunable "focusing parameter"; should be >= 0.
    # When gamma = 0, the loss returned is the standard categorical
    # cross-entropy loss.
    self._gamma = gamma
    self._class_weight = class_weight
    # tf.keras.losses.Loss class implementation requires a Reduction specified
    # in self.reduction. To use this reduction, we should use tensorflow's
    # compute_weighted_loss function however it is only compatible with v1 of
    # Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long
    # So even though it is specified here, we don't use self.reduction in the
    # loss function call.
    self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE

  def __call__(self,
               y_true: tf.Tensor,
               y_pred: tf.Tensor,
               sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor:
    """Computes the focal loss, averaged over the batch.

    Args:
      y_true: one-hot ground-truth labels, shape [batch, num_classes].
      y_pred: predicted class probabilities, shape [batch, num_classes].
      sample_weight: optional per-example weight — a scalar or a [batch]
        vector.

    Returns:
      A scalar loss tensor (sum over samples divided by batch size).

    Raises:
      ValueError: if `sample_weight` is neither a scalar nor a vector whose
        rank is one less than `y_pred`'s.
    """
    if self._class_weight:
      class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32)
      # Look up the per-example weight from the true class index.
      label = tf.argmax(y_true, axis=1)
      loss_weight = tf.gather(class_weight, label)
    else:
      loss_weight = tf.ones(tf.shape(y_true)[0])
    y_true = tf.cast(y_true, y_pred.dtype)
    # Clip probabilities to avoid log(0) producing Infs/NaNs.
    y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon)
    batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype)
    if sample_weight is None:
      sample_weight = tf.constant(1.0)
    weight_shape = sample_weight.shape
    weight_rank = weight_shape.ndims
    y_pred_rank = y_pred.shape.ndims
    if y_pred_rank - weight_rank == 1:
      sample_weight = tf.expand_dims(sample_weight, [-1])
    elif weight_rank != 0:
      # Bug fix: the two message fragments previously concatenated without a
      # separating space ("...scalaror a vector...").
      raise ValueError(f'Unexpected sample_weights, should be either a scalar '
                       f'or a vector of batch_size:{batch_size.numpy()}')
    ce = -tf.math.log(y_pred)
    # (1 - p)^gamma down-weights already well-classified examples.
    modulating_factor = tf.math.pow(1 - y_pred, self._gamma)
    losses = y_true * modulating_factor * ce * sample_weight
    losses = losses * loss_weight[:, tf.newaxis]
    # By default, this function uses "sum_over_batch_size" reduction for the
    # loss per batch.
    return tf.reduce_sum(losses) / batch_size

View File

@ -0,0 +1,103 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import loss_functions
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for loss_functions.FocalLoss."""

  @parameterized.named_parameters(
      dict(testcase_name='no_sample_weight', sample_weight=None),
      dict(
          testcase_name='with_sample_weight',
          sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])))
  def test_focal_loss_gamma_0_is_cross_entropy(
      self, sample_weight: Optional[tf.Tensor]):
    y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
                                                                       0]])
    y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
                          [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])

    # With gamma=0 the modulating factor is 1, so focal loss must match
    # categorical cross-entropy under the same reduction.
    tf_cce = tf.keras.losses.CategoricalCrossentropy(
        from_logits=False,
        reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
    focal_loss = loss_functions.FocalLoss(gamma=0)
    self.assertAllClose(
        tf_cce(y_true, y_pred, sample_weight=sample_weight),
        focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4)

  def test_focal_loss_with_sample_weight(self):
    y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
                                                                       0]])
    y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
                          [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])

    focal_loss = loss_functions.FocalLoss(gamma=0)

    # All weights are < 1, so the weighted loss must be strictly smaller.
    sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])

    self.assertGreater(
        focal_loss(y_true=y_true, y_pred=y_pred),
        focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight))

  @parameterized.named_parameters(
      dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])),
      dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])),
      dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])),
      dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])),
      dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])),
  )
  def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor):
    y_true = tf.constant([[1, 0]])

    # Larger gamma down-weights the loss more aggressively, so the loss
    # must be monotonically decreasing in gamma for a fixed prediction.
    focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0)
    loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred)
    focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5)
    loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred)
    focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1)
    loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred)
    focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2)
    loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred)
    focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5)
    loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred)

    self.assertGreater(loss_gamma_0, loss_gamma_0p5)
    self.assertGreater(loss_gamma_0p5, loss_gamma_1)
    self.assertGreater(loss_gamma_1, loss_gamma_2)
    self.assertGreater(loss_gamma_2, loss_gamma_5)

  @parameterized.named_parameters(
      dict(testcase_name='index_0', true_class=0),
      dict(testcase_name='index_1', true_class=1),
      dict(testcase_name='index_2', true_class=2),
  )
  def test_focal_loss_class_weight_is_applied(self, true_class: int):
    # Uniform predictions make the unweighted loss -log(1/3) exactly, so the
    # weighted loss should be that value scaled by the class weight.
    class_weight = [1.0, 3.0, 10.0]
    y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0
    y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :]
    expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class]

    loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
    loss = loss_fn(y_true, y_pred)
    self.assertNear(loss, expected_loss, 1e-4)


if __name__ == '__main__':
  tf.test.main()

View File

@ -0,0 +1,241 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for keras models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import quantization
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
                        batch_size: Optional[int] = None,
                        train_data: Optional[dataset.Dataset] = None) -> int:
  """Estimates the number of training steps in one epoch.

  Prefers an explicit `steps_per_epoch`; otherwise derives the value from
  the training-data length and the batch size.

  Args:
    steps_per_epoch: explicit number of steps per epoch, if known.
    batch_size: batch size; used when deriving from `train_data`.
    train_data: the training dataset (must support `len()`).

  Returns:
    Estimated training steps per epoch.

  Raises:
    ValueError: if `steps_per_epoch` and `train_data` are both None.
  """
  # An explicit user-provided value always wins.
  if steps_per_epoch is not None:
    return steps_per_epoch
  if train_data is None:
    raise ValueError('Input train_data cannot be None.')
  # Derive from the data length; drops the last partial batch.
  return len(train_data) // batch_size
def export_tflite(
    model: tf.keras.Model,
    tflite_filepath: str,
    quantization_config: Optional[quantization.QuantizationConfig] = None,
    supported_ops: Tuple[tf.lite.OpsSet,
                         ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the model to tflite format and saves it.

  Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    supported_ops: A list of supported ops in the converted TFLite file.

  Raises:
    ValueError: if `tflite_filepath` is None.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")

  with tempfile.TemporaryDirectory() as temp_dir:
    # Round-trip through a temporary SavedModel; the optimizer is dropped
    # since it is not needed for inference.
    save_path = os.path.join(temp_dir, 'saved_model')
    model.save(save_path, include_optimizer=False, save_format='tf')
    converter = tf.lite.TFLiteConverter.from_saved_model(save_path)

    if quantization_config:
      converter = quantization_config.set_converter_with_quantization(converter)

    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()

  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate: float,
               decay_schedule_fn: Callable[[Any], Any],
               warmup_steps: int,
               name: Optional[str] = None):
    """Initializes a new instance of the `WarmUp` class.

    Args:
      initial_learning_rate: learning rate after the warmup.
      decay_schedule_fn: A function maps step to learning rate. Will be applied
        for values of step larger than 'warmup_steps'.
      warmup_steps: Number of steps to do warmup for.
      name: TF namescope under which to perform the learning rate calculation.
    """
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
    """Returns the learning rate for the given global `step`."""
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements linear warmup. i.e., if global_step < warmup_steps, the
      # learning rate will be `global_step/num_warmup_steps * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
      # After warmup, defer to the wrapped decay schedule.
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self) -> Dict[Text, Any]:
    """Returns the schedule's config for Keras serialization."""
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'name': self.name
    }
class LiteRunner(object):
  """A runner to do inference with the TFLite model."""

  def __init__(self, tflite_filepath: str):
    """Initializes Lite runner with tflite model file.

    Args:
      tflite_filepath: File path to the TFLite model.
    """
    with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
      tflite_model = f.read()
    self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
    self.interpreter.allocate_tensors()
    # Lists of per-input / per-output detail dicts (index, name, dtype,
    # quantization, ...) as returned by the TFLite interpreter.
    self.input_details = self.interpreter.get_input_details()
    self.output_details = self.interpreter.get_output_details()

  def run(
      self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]]
  ) -> Union[List[tf.Tensor], tf.Tensor]:
    """Runs inference with the TFLite model.

    Args:
      input_tensors: List / Dict of the input tensors of the TFLite model. The
        order should be the same as the keras model if it's a list. It also
        accepts tensor directly if the model has only 1 input.

    Returns:
      List of the output tensors for multi-output models, otherwise just
        the output tensor. The order should be the same as the keras model.
    """
    if not isinstance(input_tensors, list) and not isinstance(
        input_tensors, dict):
      # Single-input convenience: wrap the bare tensor in a list.
      input_tensors = [input_tensors]

    interpreter = self.interpreter

    # Reshape inputs so dynamic batch/shape inputs match the given tensors.
    for i, input_detail in enumerate(self.input_details):
      input_tensor = _get_input_tensor(
          input_tensors=input_tensors,
          input_details=self.input_details,
          index=i)
      interpreter.resize_tensor_input(
          input_index=input_detail['index'], tensor_size=input_tensor.shape)
    interpreter.allocate_tensors()

    # Feed input to the interpreter
    for i, input_detail in enumerate(self.input_details):
      input_tensor = _get_input_tensor(
          input_tensors=input_tensors,
          input_details=self.input_details,
          index=i)
      if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
        # Quantize the input: map float values into the integer domain using
        # this input's (scale, zero_point) parameters.
        scale, zero_point = input_detail['quantization']
        input_tensor = input_tensor / scale + zero_point
        input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
      interpreter.set_tensor(input_detail['index'], input_tensor)

    interpreter.invoke()

    output_tensors = []
    for output_detail in self.output_details:
      output_tensor = interpreter.get_tensor(output_detail['index'])
      if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
        # Dequantize the output back into float space.
        scale, zero_point = output_detail['quantization']
        output_tensor = output_tensor.astype(np.float32)
        output_tensor = (output_tensor - zero_point) * scale
      output_tensors.append(output_tensor)

    if len(output_tensors) == 1:
      return output_tensors[0]
    return output_tensors
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
  """Creates a `LiteRunner` for the TFLite model at `tflite_filepath`."""
  return LiteRunner(tflite_filepath)
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
tf.Tensor]],
input_details: Dict[str, Any], index: int) -> tf.Tensor:
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
if isinstance(input_tensors, dict):
# Gets the mapped input tensor.
input_detail = input_details
for input_tensor_name, input_tensor in input_tensors.items():
if input_tensor_name in input_detail['name']:
return input_tensor
raise ValueError('Input tensors don\'t contains a tensor that mapped the '
'input detail %s' % str(input_detail))
else:
return input_tensors[index]

View File

@ -0,0 +1,137 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for mediapipe.model_maker.python.core.utils.model_util."""

  @parameterized.named_parameters(
      # An explicit `steps_per_epoch` always wins; otherwise it is derived
      # from the train-data size and the batch size.
      dict(
          testcase_name='input_only_steps_per_epoch',
          steps_per_epoch=1000,
          batch_size=None,
          train_data=None,
          expected_steps_per_epoch=1000),
      dict(
          testcase_name='input_steps_per_epoch_and_batch_size',
          steps_per_epoch=1000,
          batch_size=32,
          train_data=None,
          expected_steps_per_epoch=1000),
      dict(
          testcase_name='input_steps_per_epoch_batch_size_and_train_data',
          steps_per_epoch=1000,
          batch_size=32,
          train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                         [1, 0]]),
          expected_steps_per_epoch=1000),
      # 4 samples / batch size 2 -> 2 steps per epoch.
      dict(
          testcase_name='input_batch_size_and_train_data',
          steps_per_epoch=None,
          batch_size=2,
          train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                         [1, 0]]),
          expected_steps_per_epoch=2))
  def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
                               expected_steps_per_epoch):
    """Verifies steps-per-epoch resolution for several input combinations."""
    estimated_steps_per_epoch = model_util.get_steps_per_epoch(
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        train_data=train_data)
    self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch)

  def test_get_steps_per_epoch_raise_value_error(self):
    """steps_per_epoch cannot be derived when no train data is provided."""
    with self.assertRaises(ValueError):
      model_util.get_steps_per_epoch(
          steps_per_epoch=None, batch_size=16, train_data=None)

  def test_warmup(self):
    """Checks that WarmUp round-trips its construction args via get_config()."""
    init_lr = 0.1
    warmup_steps = 1000
    num_decay_steps = 100
    learning_rate_fn = tf.keras.experimental.CosineDecay(
        initial_learning_rate=init_lr, decay_steps=num_decay_steps)
    warmup_object = model_util.WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=learning_rate_fn,
        warmup_steps=1000,
        name='test')
    self.assertEqual(
        warmup_object.get_config(), {
            'initial_learning_rate': init_lr,
            'decay_schedule_fn': learning_rate_fn,
            'warmup_steps': warmup_steps,
            'name': 'test'
        })

  def test_export_tflite(self):
    """Exports an unquantized model and compares it against the Keras model."""
    input_dim = 4
    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
    model_util.export_tflite(model, tflite_file)
    self._test_tflite(model, tflite_file, input_dim)

  @parameterized.named_parameters(
      # `model_size` is the expected exported file size in bytes; the size
      # check below allows a 300-byte drift across converter versions.
      dict(
          testcase_name='dynamic_quantize',
          config=quantization.QuantizationConfig.for_dynamic(),
          model_size=1288),
      dict(
          testcase_name='int8_quantize',
          config=quantization.QuantizationConfig.for_int8(
              representative_data=test_util.create_dataset(
                  data_size=10, input_shape=[16], num_classes=3)),
          model_size=1832),
      dict(
          testcase_name='float16_quantize',
          config=quantization.QuantizationConfig.for_float16(),
          model_size=1468))
  def test_export_tflite_quantized(self, config, model_size):
    """Exports with each quantization config and checks output and file size."""
    input_dim = 16
    num_classes = 2
    max_input_value = 5
    model = test_util.build_model([input_dim], num_classes)
    tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
    model_util.export_tflite(model, tflite_file, config)
    # Loose tolerance (atol=1e-00) because quantization changes the numerics.
    self._test_tflite(
        model, tflite_file, input_dim, max_input_value, atol=1e-00)
    self.assertNear(os.path.getsize(tflite_file), model_size, 300)

  def _test_tflite(self,
                   keras_model: tf.keras.Model,
                   tflite_model_file: str,
                   input_dim: int,
                   max_input_value: int = 1000,
                   atol: float = 1e-04):
    """Asserts the TFLite model matches the Keras model on a random input."""
    np.random.seed(0)  # Deterministic input for a reproducible comparison.
    random_input = np.random.uniform(
        low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32)
    self.assertTrue(
        test_util.is_same_output(
            tflite_model_file, keras_model, random_input, atol=atol))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,213 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Libraries for post-training quantization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, List, Optional, Union
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
DEFAULT_QUANTIZATION_STEPS = 500
def _get_representative_dataset_generator(dataset: tf.data.Dataset,
num_steps: int) -> Callable[[], Any]:
"""Gets a representative dataset generator for post-training quantization.
The generator is to provide a small dataset to calibrate or estimate the
range, i.e, (min, max) of all floating-point arrays in the model for
quantization. Usually, this is a small subset of a few hundred samples
randomly chosen, in no particular order, from the training or evaluation
dataset. See tf.lite.RepresentativeDataset for more details.
Args:
dataset: Input dataset for extracting representative sub dataset.
num_steps: The number of quantization steps which also reflects the size of
the representative dataset.
Returns:
A representative dataset generator.
"""
def representative_dataset_gen():
"""Generates representative dataset for quantization."""
for data, _ in dataset.take(num_steps):
yield [data]
return representative_dataset_gen
class QuantizationConfig(object):
  """Configuration for post-training quantization.

  Refer to
  https://www.tensorflow.org/lite/performance/post_training_quantization
  for different post-training quantization options.
  """

  def __init__(
      self,
      optimizations: Optional[Union[tf.lite.Optimize,
                                    List[tf.lite.Optimize]]] = None,
      representative_data: Optional[ds.Dataset] = None,
      quantization_steps: Optional[int] = None,
      inference_input_type: Optional[tf.dtypes.DType] = None,
      inference_output_type: Optional[tf.dtypes.DType] = None,
      supported_ops: Optional[Union[tf.lite.OpsSet,
                                    List[tf.lite.OpsSet]]] = None,
      supported_types: Optional[Union[tf.dtypes.DType,
                                      List[tf.dtypes.DType]]] = None,
      experimental_new_quantizer: bool = False,
  ):
    """Constructs QuantizationConfig.

    Args:
      optimizations: A list of optimizations to apply when converting the model.
        If not set, use `[Optimize.DEFAULT]` by default.
      representative_data: A representative ds.Dataset for post-training
        quantization.
      quantization_steps: Number of post-training quantization calibration steps
        to run (default to DEFAULT_QUANTIZATION_STEPS).
      inference_input_type: Target data type of real-number input arrays. Allows
        for a different type for input arrays. Defaults to None. If set, must
        be `{tf.float32, tf.uint8, tf.int8}`.
      inference_output_type: Target data type of real-number output arrays.
        Allows for a different type for output arrays. Defaults to None. If set,
        must be `{tf.float32, tf.uint8, tf.int8}`.
      supported_ops: Set of OpsSet options supported by the device. Used to Set
        converter.target_spec.supported_ops.
      supported_types: List of types for constant values on the target device.
        Supported values are types exported by lite.constants. Frequently, an
        optimization choice is driven by the most compact (i.e. smallest) type
        in this list (default [constants.FLOAT]).
      experimental_new_quantizer: Whether to enable experimental new quantizer.

    Raises:
      ValueError: if inference_input_type or inference_output_type are set but
        not in {tf.float32, tf.uint8, tf.int8}.
    """
    if inference_input_type is not None and inference_input_type not in {
        tf.float32, tf.uint8, tf.int8
    }:
      raise ValueError('Unsupported inference_input_type %s' %
                       inference_input_type)
    if inference_output_type is not None and inference_output_type not in {
        tf.float32, tf.uint8, tf.int8
    }:
      raise ValueError('Unsupported inference_output_type %s' %
                       inference_output_type)

    # Scalars are normalized to one-element lists so downstream code can
    # assume list-valued attributes.
    if optimizations is None:
      optimizations = [tf.lite.Optimize.DEFAULT]
    if not isinstance(optimizations, list):
      optimizations = [optimizations]
    self.optimizations = optimizations

    self.representative_data = representative_data
    # A representative dataset without an explicit step count defaults to
    # DEFAULT_QUANTIZATION_STEPS calibration steps.
    if self.representative_data is not None and quantization_steps is None:
      quantization_steps = DEFAULT_QUANTIZATION_STEPS
    self.quantization_steps = quantization_steps

    self.inference_input_type = inference_input_type
    self.inference_output_type = inference_output_type

    if supported_ops is not None and not isinstance(supported_ops, list):
      supported_ops = [supported_ops]
    self.supported_ops = supported_ops

    if supported_types is not None and not isinstance(supported_types, list):
      supported_types = [supported_types]
    self.supported_types = supported_types

    self.experimental_new_quantizer = experimental_new_quantizer

  @classmethod
  def for_dynamic(cls) -> 'QuantizationConfig':
    """Creates configuration for dynamic range quantization."""
    return cls()

  @classmethod
  def for_int8(
      cls,
      representative_data: ds.Dataset,
      quantization_steps: int = DEFAULT_QUANTIZATION_STEPS,
      inference_input_type: tf.dtypes.DType = tf.uint8,
      inference_output_type: tf.dtypes.DType = tf.uint8,
      supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
  ) -> 'QuantizationConfig':
    """Creates configuration for full integer quantization.

    Args:
      representative_data: Representative data used for post-training
        quantization.
      quantization_steps: Number of post-training quantization calibration steps
        to run.
      inference_input_type: Target data type of real-number input arrays.
      inference_output_type: Target data type of real-number output arrays.
      supported_ops: Set of `tf.lite.OpsSet` options, where each option
        represents a set of operators supported by the target device.

    Returns:
      QuantizationConfig.
    """
    return cls(
        representative_data=representative_data,
        quantization_steps=quantization_steps,
        inference_input_type=inference_input_type,
        inference_output_type=inference_output_type,
        supported_ops=supported_ops)

  @classmethod
  def for_float16(cls) -> 'QuantizationConfig':
    """Creates configuration for float16 quantization."""
    return cls(supported_types=[tf.float16])

  def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter,
                                      **kwargs: Any) -> tf.lite.TFLiteConverter:
    """Sets input TFLite converter with quantization configurations.

    Args:
      converter: input tf.lite.TFLiteConverter.
      **kwargs: arguments used by ds.Dataset.gen_tf_dataset.

    Returns:
      tf.lite.TFLiteConverter with quantization configurations.
    """
    converter.optimizations = self.optimizations
    if self.representative_data is not None:
      tf_ds = self.representative_data.gen_tf_dataset(
          batch_size=1, is_training=False, **kwargs)
      converter.representative_dataset = tf.lite.RepresentativeDataset(
          _get_representative_dataset_generator(tf_ds, self.quantization_steps))
    if self.inference_input_type:
      converter.inference_input_type = self.inference_input_type
    if self.inference_output_type:
      converter.inference_output_type = self.inference_output_type
    if self.supported_ops:
      converter.target_spec.supported_ops = self.supported_ops
    if self.supported_types:
      converter.target_spec.supported_types = self.supported_types
    # `experimental_new_quantizer` is always a bool (never None), so the
    # previous `is not None` guard was vacuous; assign it unconditionally.
    converter.experimental_new_quantizer = self.experimental_new_quantizer
    return converter

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for mediapipe.model_maker.python.core.utils.quantization."""

  def test_create_dynamic_quantization_config(self):
    """Dynamic-range config sets only the default optimizations."""
    config = quantization.QuantizationConfig.for_dynamic()
    self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
    self.assertIsNone(config.representative_data)
    self.assertIsNone(config.inference_input_type)
    self.assertIsNone(config.inference_output_type)
    self.assertIsNone(config.supported_ops)
    self.assertIsNone(config.supported_types)
    self.assertFalse(config.experimental_new_quantizer)

  def test_create_int8_quantization_config(self):
    """Full-integer config defaults to uint8 I/O and INT8 builtin ops."""
    representative_data = test_util.create_dataset(
        data_size=10, input_shape=[4], num_classes=3)
    config = quantization.QuantizationConfig.for_int8(
        representative_data=representative_data)
    self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
    self.assertEqual(config.inference_input_type, tf.uint8)
    self.assertEqual(config.inference_output_type, tf.uint8)
    self.assertEqual(config.supported_ops,
                     [tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
    self.assertFalse(config.experimental_new_quantizer)

  def test_set_converter_with_quantization_from_int8_config(self):
    """An int8-configured converter produces a model with uint8 I/O."""
    representative_data = test_util.create_dataset(
        data_size=10, input_shape=[4], num_classes=3)
    config = quantization.QuantizationConfig.for_int8(
        representative_data=representative_data)
    model = test_util.build_model(input_shape=[4], num_classes=3)
    saved_model_dir = self.get_temp_dir()
    model.save(saved_model_dir)
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter = config.set_converter_with_quantization(converter=converter)
    self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
    self.assertEqual(config.inference_input_type, tf.uint8)
    self.assertEqual(config.inference_output_type, tf.uint8)
    self.assertEqual(config.supported_ops,
                     [tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
    # Convert end-to-end to verify the converter settings take effect.
    tflite_model = converter.convert()
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
    self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)

  def test_create_float16_quantization_config(self):
    """Float16 config only restricts the supported constant types."""
    config = quantization.QuantizationConfig.for_float16()
    self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
    self.assertIsNone(config.representative_data)
    self.assertIsNone(config.inference_input_type)
    self.assertIsNone(config.inference_output_type)
    self.assertIsNone(config.supported_ops)
    self.assertEqual(config.supported_types, [tf.float16])
    self.assertFalse(config.experimental_new_quantizer)

  def test_set_converter_with_quantization_from_float16_config(self):
    """A float16-configured converter keeps float32 model input/output."""
    config = quantization.QuantizationConfig.for_float16()
    model = test_util.build_model(input_shape=[4], num_classes=3)
    saved_model_dir = self.get_temp_dir()
    model.save(saved_model_dir)
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter = config.set_converter_with_quantization(converter=converter)
    self.assertEqual(config.supported_types, [tf.float16])
    tflite_model = converter.convert()
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    # The input and output are expected to be set to float32 by default.
    self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
    self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)

  @parameterized.named_parameters(
      # Only tf.float32, tf.uint8 and tf.int8 are legal inference types;
      # tf.int64 must be rejected in either position.
      dict(
          testcase_name='invalid_inference_input_type',
          inference_input_type=tf.uint8,
          inference_output_type=tf.int64),
      dict(
          testcase_name='invalid_inference_output_type',
          inference_input_type=tf.int64,
          inference_output_type=tf.float32))
  def test_create_quantization_config_failure(self, inference_input_type,
                                              inference_output_type):
    """Rejects inference types outside {float32, uint8, int8}."""
    with self.assertRaises(ValueError):
      _ = quantization.QuantizationConfig(
          inference_input_type=inference_input_type,
          inference_output_type=inference_output_type)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,76 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test utilities for model maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import List, Union
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import model_util
def create_dataset(data_size: int,
                   input_shape: List[int],
                   num_classes: int,
                   max_input_value: int = 1000) -> ds.Dataset:
  """Creates and returns a simple `Dataset` object for test."""
  # Random float features in [0, max_input_value) with one row per sample,
  # paired with random integer labels in [0, num_classes).
  features = tf.random.uniform(
      shape=[data_size] + input_shape,
      minval=0,
      maxval=max_input_value,
      dtype=tf.float32)
  labels = tf.random.uniform(
      shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
  return ds.Dataset(
      tf.data.Dataset.from_tensor_slices((features, labels)), data_size)
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
  """Builds a simple Keras model for test."""
  inputs = tf.keras.layers.Input(shape=input_shape)
  rank = len(input_shape)
  if rank == 3:
    # Image inputs: pool the spatial dimensions before the classifier head.
    pooled = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(pooled)
  elif rank == 1:
    # Text / feature-vector inputs: classify directly.
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
  else:
    raise ValueError("Model inputs should be 2D tensor or 4D tensor.")
  return tf.keras.Model(inputs=inputs, outputs=outputs)
def is_same_output(tflite_file: str,
                   keras_model: tf.keras.Model,
                   input_tensors: Union[List[tf.Tensor], tf.Tensor],
                   atol: float = 1e-04) -> bool:
  """Returns if the output of TFLite model and keras model are identical."""
  # Run the same batch through the TFLite interpreter and the Keras model,
  # then compare element-wise within the given absolute tolerance.
  lite_output = model_util.get_lite_runner(tflite_file).run(input_tensors)
  keras_output = keras_model.predict_on_batch(input_tensors)
  return np.allclose(lite_output, keras_output, atol=atol)

View File

@ -0,0 +1,4 @@
absl-py
numpy
opencv-contrib-python
tensorflow

View File

@ -507,8 +507,11 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
}
};
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT
// clang-format on
} // namespace processors
} // namespace components

View File

@ -41,8 +41,8 @@ cc_test(
)
cc_library(
name = "hand_gesture_recognizer_subgraph",
srcs = ["hand_gesture_recognizer_subgraph.cc"],
name = "hand_gesture_recognizer_graph",
srcs = ["hand_gesture_recognizer_graph.cc"],
deps = [
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/tensor:tensor_converter_calculator",
@ -62,11 +62,11 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:handedness_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:landmarks_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:hand_gesture_recognizer_subgraph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",

View File

@ -12,11 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/app/xeno:__subpackages__",
"//mediapipe/tasks:internal",
])
mediapipe_proto_library(
name = "landmarks_to_matrix_calculator_proto",
srcs = ["landmarks_to_matrix_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
cc_library(
name = "handedness_to_matrix_calculator",
srcs = ["handedness_to_matrix_calculator.cc"],
@ -25,7 +37,7 @@ cc_library(
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer:handedness_util",
"//mediapipe/tasks/cc/vision/gesture_recognizer:handedness_util",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -53,11 +65,11 @@ cc_library(
name = "landmarks_to_matrix_calculator",
srcs = ["landmarks_to_matrix_calculator.cc"],
deps = [
":landmarks_to_matrix_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",

View File

@ -26,14 +26,16 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"
// TODO Update to use API2
namespace mediapipe {
namespace tasks {
namespace vision {
namespace api2 {
namespace {
using ::mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore;
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX";
@ -71,6 +73,8 @@ class HandednessToMatrixCalculator : public CalculatorBase {
return absl::OkStatus();
}
// TODO remove this after change to API2, because Setting offset
// to 0 is the default in API2
absl::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return absl::OkStatus();
@ -95,6 +99,5 @@ absl::Status HandednessToMatrixCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
} // namespace vision
} // namespace tasks
} // namespace api2
} // namespace mediapipe

View File

@ -28,8 +28,6 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace {
@ -95,6 +93,4 @@ INSTANTIATE_TEST_CASE_P(
} // namespace
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -27,13 +27,11 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
// TODO Update to use API2
namespace mediapipe {
namespace tasks {
namespace vision {
using proto::LandmarksToMatrixCalculatorOptions;
namespace api2 {
namespace {
@ -175,7 +173,7 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) {
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
// options {
// [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions.ext] {
// [mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
// object_normalization: true
// object_normalization_origin_offset: 0
// }
@ -221,6 +219,5 @@ absl::Status LandmarksToMatrixCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
} // namespace vision
} // namespace tasks
} // namespace api2
} // namespace mediapipe

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.vision.proto;
package mediapipe;
import "mediapipe/framework/calculator.proto";

View File

@ -28,8 +28,6 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace {
@ -72,8 +70,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) {
input_stream: "IMAGE_SIZE:image_size"
output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
options {
[mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions
.ext] {
[mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
object_normalization: $0
object_normalization_origin_offset: $1
}
@ -145,8 +142,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) {
input_stream: "IMAGE_SIZE:image_size"
output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
options {
[mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions
.ext] {
[mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
object_normalization: $0
object_normalization_origin_offset: $1
}
@ -202,6 +198,4 @@ INSTANTIATE_TEST_CASE_P(
} // namespace
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -34,14 +34,15 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
namespace {
@ -50,9 +51,8 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
HandGestureRecognizerSubgraphOptions;
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions;
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kLandmarksTag[] = "LANDMARKS";
@ -70,18 +70,6 @@ constexpr char kIndexTag[] = "INDEX";
constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END";
absl::Status SanityCheckOptions(
const HandGestureRecognizerSubgraphOptions& options) {
if (options.min_tracking_confidence() < 0 ||
options.min_tracking_confidence() > 1) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"Invalid `min_tracking_confidence` option: "
"value must be in the range [0.0, 1.0]",
MediaPipeTasksStatus::kInvalidArgumentError);
}
return absl::OkStatus();
}
Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
Graph& graph) {
auto& node = graph.AddNode("TensorConverterCalculator");
@ -91,9 +79,10 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
} // namespace
// A "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" performs
// single hand gesture recognition. This graph is used as a building block for
// mediapipe.tasks.vision.HandGestureRecognizerGraph.
// A
// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph"
// performs single hand gesture recognition. This graph is used as a building
// block for mediapipe.tasks.vision.GestureRecognizerGraph.
//
// Inputs:
// HANDEDNESS - ClassificationList
@ -113,14 +102,15 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph"
// calculator:
// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph"
// input_stream: "HANDEDNESS:handedness"
// input_stream: "LANDMARKS:landmarks"
// input_stream: "WORLD_LANDMARKS:world_landmarks"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "HAND_GESTURES:hand_gestures"
// options {
// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraphOptions.ext]
// [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext]
// {
// base_options {
// model_asset {
@ -130,19 +120,19 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
// }
// }
// }
class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<HandGestureRecognizerSubgraphOptions>(sc));
CreateModelResources<HandGestureRecognizerGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto hand_gestures,
BuildHandGestureRecognizerGraph(
sc->Options<HandGestureRecognizerSubgraphOptions>(),
*model_resources, graph[Input<ClassificationList>(kHandednessTag)],
BuildGestureRecognizerGraph(
sc->Options<HandGestureRecognizerGraphOptions>(), *model_resources,
graph[Input<ClassificationList>(kHandednessTag)],
graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
graph[Input<LandmarkList>(kWorldLandmarksTag)],
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
@ -151,15 +141,13 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
}
private:
absl::StatusOr<Source<ClassificationResult>> BuildHandGestureRecognizerGraph(
const HandGestureRecognizerSubgraphOptions& graph_options,
absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
const HandGestureRecognizerGraphOptions& graph_options,
const core::ModelResources& model_resources,
Source<ClassificationList> handedness,
Source<NormalizedLandmarkList> hand_landmarks,
Source<LandmarkList> hand_world_landmarks,
Source<std::pair<int, int>> image_size, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(graph_options));
// Converts the ClassificationList to a matrix.
auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator");
handedness >> handedness_to_matrix.In(kHandednessTag);
@ -235,12 +223,15 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::SingleHandGestureRecognizerSubgraph);
::mediapipe::tasks::vision::gesture_recognizer::SingleHandGestureRecognizerGraph); // NOLINT
// clang-format on
// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs multi
// hand gesture recognition. This graph is used as a building block for
// mediapipe.tasks.vision.HandGestureRecognizerGraph.
// A
// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph"
// performs multi hand gesture recognition. This graph is used as a building
// block for mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph.
//
// Inputs:
// HANDEDNESS - std::vector<ClassificationList>
@ -263,7 +254,8 @@ REGISTER_MEDIAPIPE_GRAPH(
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph"
// calculator:
// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph"
// input_stream: "HANDEDNESS:handedness"
// input_stream: "LANDMARKS:landmarks"
// input_stream: "WORLD_LANDMARKS:world_landmarks"
@ -271,7 +263,7 @@ REGISTER_MEDIAPIPE_GRAPH(
// input_stream: "HAND_TRACKING_IDS:hand_tracking_ids"
// output_stream: "HAND_GESTURES:hand_gestures"
// options {
// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraph.ext]
// [mediapipe.tasks.vision.gesture_recognizer.proto.MultipleHandGestureRecognizerGraph.ext]
// {
// base_options {
// model_asset {
@ -281,15 +273,15 @@ REGISTER_MEDIAPIPE_GRAPH(
// }
// }
// }
class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto multi_hand_gestures,
BuildMultiHandGestureRecognizerSubraph(
sc->Options<HandGestureRecognizerSubgraphOptions>(),
BuildMultiGestureRecognizerSubraph(
sc->Options<HandGestureRecognizerGraphOptions>(),
graph[Input<std::vector<ClassificationList>>(kHandednessTag)],
graph[Input<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
graph[Input<std::vector<LandmarkList>>(kWorldLandmarksTag)],
@ -302,8 +294,8 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
private:
absl::StatusOr<Source<std::vector<ClassificationResult>>>
BuildMultiHandGestureRecognizerSubraph(
const HandGestureRecognizerSubgraphOptions& graph_options,
BuildMultiGestureRecognizerSubraph(
const HandGestureRecognizerGraphOptions& graph_options,
Source<std::vector<ClassificationList>> multi_handedness,
Source<std::vector<NormalizedLandmarkList>> multi_hand_landmarks,
Source<std::vector<LandmarkList>> multi_hand_world_landmarks,
@ -341,17 +333,18 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag);
auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag);
auto& hand_gesture_recognizer_subgraph = graph.AddNode(
"mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph");
hand_gesture_recognizer_subgraph
.GetOptions<HandGestureRecognizerSubgraphOptions>()
auto& hand_gesture_recognizer_graph = graph.AddNode(
"mediapipe.tasks.vision.gesture_recognizer."
"SingleHandGestureRecognizerGraph");
hand_gesture_recognizer_graph
.GetOptions<HandGestureRecognizerGraphOptions>()
.CopyFrom(graph_options);
handedness >> hand_gesture_recognizer_subgraph.In(kHandednessTag);
hand_landmarks >> hand_gesture_recognizer_subgraph.In(kLandmarksTag);
handedness >> hand_gesture_recognizer_graph.In(kHandednessTag);
hand_landmarks >> hand_gesture_recognizer_graph.In(kLandmarksTag);
hand_world_landmarks >>
hand_gesture_recognizer_subgraph.In(kWorldLandmarksTag);
image_size_clone >> hand_gesture_recognizer_subgraph.In(kImageSizeTag);
auto hand_gestures = hand_gesture_recognizer_subgraph.Out(kHandGesturesTag);
hand_gesture_recognizer_graph.In(kWorldLandmarksTag);
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
auto& end_loop_classification_results =
graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
@ -364,9 +357,12 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::HandGestureRecognizerSubgraph);
::mediapipe::tasks::vision::gesture_recognizer::MultipleHandGestureRecognizerGraph); // NOLINT
// clang-format on
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"
#include <algorithm>
@ -25,6 +25,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
namespace {} // namespace
@ -58,6 +59,7 @@ absl::StatusOr<float> GetLeftHandScore(
}
}
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
@ -22,6 +22,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
bool IsLeftHand(const mediapipe::Classification& c);
@ -30,8 +31,9 @@ bool IsRightHand(const mediapipe::Classification& c);
absl::StatusOr<float> GetLeftHandScore(
const mediapipe::ClassificationList& classification_list);
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/gmock.h"
@ -23,6 +23,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
namespace {
TEST(GetLeftHandScore, SingleLeftHandClassification) {
@ -72,6 +73,7 @@ TEST(GetLeftHandScore, LeftAndRightLowerCaseHandClassification) {
}
} // namespace
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -21,8 +21,18 @@ package(default_visibility = [
licenses(["notice"])
mediapipe_proto_library(
name = "hand_gesture_recognizer_subgraph_options_proto",
srcs = ["hand_gesture_recognizer_subgraph_options.proto"],
name = "gesture_embedder_graph_options_proto",
srcs = ["gesture_embedder_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
mediapipe_proto_library(
name = "hand_gesture_recognizer_graph_options_proto",
srcs = ["hand_gesture_recognizer_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -30,12 +40,3 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
mediapipe_proto_library(
name = "landmarks_to_matrix_calculator_proto",
srcs = ["landmarks_to_matrix_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)

View File

@ -0,0 +1,30 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message GestureEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureEmbedderGraphOptions ext = 478825422;
}
// Base options for configuring the gesture embedder graph, such as
// specifying the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
}

View File

@ -15,15 +15,15 @@ limitations under the License.
// TODO Refactor naming and class structure of hand related Tasks.
syntax = "proto2";
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message HandGestureRecognizerSubgraphOptions {
message HandGestureRecognizerGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandGestureRecognizerSubgraphOptions ext = 463370452;
optional HandGestureRecognizerGraphOptions ext = 463370452;
}
// Base options for configuring hand gesture recognition subgraph, such as
// specifying the TfLite model file with metadata, accelerator options, etc.

View File

@ -51,7 +51,7 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",

View File

@ -40,12 +40,13 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_detector {
namespace {
@ -53,18 +54,23 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions;
using ::mediapipe::tasks::vision::hand_detector::proto::
HandDetectorGraphOptions;
constexpr char kImageTag[] = "IMAGE";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kHandRectsTag[] = "HAND_RECTS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
// Bundles the output streams produced by the hand detection graph builder.
struct HandDetectionOuts {
// Detected palms, at most `num_hands` entries (per the graph options).
Source<std::vector<Detection>> palm_detections;
// Detected hand bounding boxes in normalized coordinates.
Source<std::vector<NormalizedRect>> hand_rects;
// Detected palm bounding boxes in normalized coordinates.
Source<std::vector<NormalizedRect>> palm_rects;
// The input image passed through, with pixel data stored on the target
// storage (CPU vs GPU).
Source<Image> image;
};
void ConfigureTensorsToDetectionsCalculator(
const HandDetectorGraphOptions& tasks_options,
mediapipe::TensorsToDetectionsCalculatorOptions* options) {
// TODO use metadata to configure these fields.
options->set_num_classes(1);
@ -77,7 +83,7 @@ void ConfigureTensorsToDetectionsCalculator(
options->set_sigmoid_score(true);
options->set_score_clipping_thresh(100.0);
options->set_reverse_output_order(true);
options->set_min_score_thresh(0.5);
options->set_min_score_thresh(tasks_options.min_detection_confidence());
options->set_x_scale(192.0);
options->set_y_scale(192.0);
options->set_w_scale(192.0);
@ -134,9 +140,9 @@ void ConfigureRectTransformationCalculator(
} // namespace
// A "mediapipe.tasks.vision.HandDetectorGraph" performs hand detection. The
// Hand Detection Graph is based on palm detection model, and scale the detected
// palm bounding box to enclose the detected whole hand.
// A "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" performs hand
// detection. The Hand Detection Graph is based on a palm detection model, and
// scales the detected palm bounding box to enclose the detected whole hand.
// Accepts CPU input images and outputs Landmark on CPU.
//
// Inputs:
@ -144,19 +150,27 @@ void ConfigureRectTransformationCalculator(
// Image to perform detection on.
//
// Outputs:
// DETECTIONS - std::vector<Detection>
// PALM_DETECTIONS - std::vector<Detection>
// Detected palms with maximum `num_hands` specified in options.
// NORM_RECTS - std::vector<NormalizedRect>
// HAND_RECTS - std::vector<NormalizedRect>
// Detected hand bounding boxes in normalized coordinates.
// PALM_RECTS - std::vector<NormalizedRect>
// Detected palm bounding boxes in normalized coordinates.
// IMAGE - Image
// The input image that the hand detector runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.HandDetectorGraph"
// calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph"
// input_stream: "IMAGE:image"
// output_stream: "DETECTIONS:palm_detections"
// output_stream: "NORM_RECTS:hand_rects_from_palm_detections"
// output_stream: "PALM_DETECTIONS:palm_detections"
// output_stream: "HAND_RECTS:hand_rects_from_palm_detections"
// output_stream: "PALM_RECTS:palm_rects"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.hand_detector.proto.HandDetectorOptions.ext] {
// [mediapipe.tasks.vision.hand_detector.proto.HandDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "palm_detection.tflite"
@ -173,16 +187,20 @@ class HandDetectorGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<HandDetectorOptions>(sc));
CreateModelResources<HandDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto hand_detection_outs,
BuildHandDetectionSubgraph(
sc->Options<HandDetectorOptions>(), *model_resources,
ASSIGN_OR_RETURN(
auto hand_detection_outs,
BuildHandDetectionSubgraph(sc->Options<HandDetectorGraphOptions>(),
*model_resources,
graph[Input<Image>(kImageTag)], graph));
hand_detection_outs.palm_detections >>
graph[Output<std::vector<Detection>>(kDetectionsTag)];
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_detection_outs.hand_rects >>
graph[Output<std::vector<NormalizedRect>>(kNormRectsTag)];
graph[Output<std::vector<NormalizedRect>>(kHandRectsTag)];
hand_detection_outs.palm_rects >>
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
hand_detection_outs.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -196,7 +214,7 @@ class HandDetectorGraph : public core::ModelTaskGraph {
// image_in: image stream to run hand detection on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<HandDetectionOuts> BuildHandDetectionSubgraph(
const HandDetectorOptions& subgraph_options,
const HandDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Graph& graph) {
// Add image preprocessing subgraph. The model expects aspect ratio
@ -235,6 +253,7 @@ class HandDetectorGraph : public core::ModelTaskGraph {
auto& tensors_to_detections =
graph.AddNode("TensorsToDetectionsCalculator");
ConfigureTensorsToDetectionsCalculator(
subgraph_options,
&tensors_to_detections
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
model_output_tensors >> tensors_to_detections.In("TENSORS");
@ -281,7 +300,8 @@ class HandDetectorGraph : public core::ModelTaskGraph {
.GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
palm_detections >> detections_to_rects.In("DETECTIONS");
image_size >> detections_to_rects.In("IMAGE_SIZE");
auto palm_rects = detections_to_rects.Out("NORM_RECTS");
auto palm_rects =
detections_to_rects[Output<std::vector<NormalizedRect>>("NORM_RECTS")];
// Expands and shifts the rectangle that contains the palm so that it's
// likely to cover the entire hand.
@ -308,13 +328,18 @@ class HandDetectorGraph : public core::ModelTaskGraph {
clip_normalized_rect_vector_size[Output<std::vector<NormalizedRect>>(
"")];
return HandDetectionOuts{.palm_detections = palm_detections,
.hand_rects = clipped_hand_rects};
return HandDetectionOuts{
/* palm_detections= */ palm_detections,
/* hand_rects= */ clipped_hand_rects,
/* palm_rects= */ palm_rects,
/* image= */ preprocessing[Output<Image>(kImageTag)]};
}
};
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandDetectorGraph);
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::hand_detector::HandDetectorGraph);
} // namespace hand_detector
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -40,13 +40,14 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_detector {
namespace {
using ::file::Defaults;
@ -60,7 +61,8 @@ using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::core::proto::ExternalFile;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions;
using ::mediapipe::tasks::vision::hand_detector::proto::
HandDetectorGraphOptions;
using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorResult;
using ::testing::EqualsProto;
using ::testing::TestParamInfo;
@ -80,9 +82,9 @@ constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
constexpr char kPalmDetectionsTag[] = "DETECTIONS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmDetectionsName[] = "palm_detections";
constexpr char kHandNormRectsTag[] = "NORM_RECTS";
constexpr char kHandRectsTag[] = "HAND_RECTS";
constexpr char kHandNormRectsName[] = "hand_norm_rects";
constexpr float kPalmDetectionBboxMaxDiff = 0.01;
@ -104,22 +106,22 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
Graph graph;
auto& hand_detection =
graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph");
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
auto options = std::make_unique<HandDetectorOptions>();
auto options = std::make_unique<HandDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
options->set_min_detection_confidence(0.5);
options->set_num_hands(num_hands);
hand_detection.GetOptions<HandDetectorOptions>().Swap(options.get());
hand_detection.GetOptions<HandDetectorGraphOptions>().Swap(options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
hand_detection.In(kImageTag);
hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >>
graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)];
hand_detection.Out(kHandRectsTag).SetName(kHandNormRectsName) >>
graph[Output<std::vector<NormalizedRect>>(kHandRectsTag)];
return TaskRunner::Create(
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
@ -200,6 +202,7 @@ INSTANTIATE_TEST_SUITE_P(
});
} // namespace
} // namespace hand_detector
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -21,8 +21,8 @@ package(default_visibility = [
licenses(["notice"])
mediapipe_proto_library(
name = "hand_detector_options_proto",
srcs = ["hand_detector_options.proto"],
name = "hand_detector_graph_options_proto",
srcs = ["hand_detector_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",

View File

@ -21,24 +21,20 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.handdetector";
option java_outer_classname = "HandDetectorOptionsProto";
option java_outer_classname = "HandDetectorGraphOptionsProto";
message HandDetectorOptions {
message HandDetectorGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandDetectorOptions ext = 464864288;
optional HandDetectorGraphOptions ext = 464864288;
}
// Base options for configuring Task library, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English.
optional string display_names_locale = 2 [default = "en"];
// Minimum confidence value ([0.0, 1.0]) for confidence score to be considered
// successfully detecting a hand in the image.
optional float min_detection_confidence = 3 [default = 0.5];
optional float min_detection_confidence = 2 [default = 0.5];
// The maximum number of hands output by the detector.
optional int32 num_hands = 4;
optional int32 num_hands = 3;
}

View File

@ -19,10 +19,10 @@ package(default_visibility = [
licenses(["notice"])
cc_library(
name = "hand_landmarker_subgraph",
srcs = ["hand_landmarker_subgraph.cc"],
name = "hand_landmarks_detector_graph",
srcs = ["hand_landmarks_detector_graph.cc"],
deps = [
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"//mediapipe/calculators/core:split_vector_calculator",
@ -51,6 +51,7 @@ cc_library(
# TODO: move calculators in modules/hand_landmark/calculators to tasks dir.
"//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
@ -66,3 +67,41 @@ cc_library(
)
# TODO: Enable this test
cc_library(
name = "hand_landmarker_graph",
srcs = ["hand_landmarker_graph.cc"],
deps = [
":hand_landmarks_detector_graph",
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:gate_calculator_cc_proto",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
],
alwayslink = 1,
)
# TODO: Enable this test

View File

@ -0,0 +1,286 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/calculators/core/gate_calculator.pb.h"
#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::DisallowIf;
using ::mediapipe::tasks::vision::hand_detector::proto::
HandDetectorGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarkerGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarksDetectorGraphOptions;
constexpr char kImageTag[] = "IMAGE";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
// Bundles the output streams produced by the hand landmarker graph builder.
struct HandLandmarkerOutputs {
// Detected hand landmarks, one NormalizedLandmarkList per hand.
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
// Detected hand landmarks in world coordinates, one LandmarkList per hand.
Source<std::vector<LandmarkList>> world_landmark_lists;
// Predicted rects enclosing the same hand RoI for landmark detection on
// the next frame.
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
// Handedness classification, one ClassificationList per hand.
Source<std::vector<ClassificationList>> handednesses;
// Detected palm bounding boxes in normalized coordinates.
Source<std::vector<NormalizedRect>> palm_rects;
// Detected palms from the internal hand detector.
Source<std::vector<Detection>> palm_detections;
// The input image passed through, with pixel data stored on the target
// storage (CPU vs GPU).
Source<Image> image;
};
} // namespace
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
// landmarks detection. The HandLandmarkerGraph consists of two subgraphs:
// HandDetectorGraph and MultipleHandLandmarksDetectorGraph.
// MultipleHandLandmarksDetectorGraph detects landmarks from bounding boxes
// produced by HandDetectorGraph. HandLandmarkerGraph tracks the landmarks over
// time, and skips the HandDetectorGraph. If the tracking is lost or the
// detected hands are fewer than the configured maximum number of hands,
// HandDetectorGraph would be
// triggered to detect hands.
//
// Accepts CPU input images and outputs Landmarks on CPU.
//
// Inputs:
// IMAGE - Image
// Image to perform hand landmarks detection on.
//
// Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList>
// Vector of detected hand landmarks.
// WORLD_LANDMARKS - std::vector<LandmarkList>
// Vector of detected hand landmarks in world coordinates.
// HAND_RECT_NEXT_FRAME - std::vector<NormalizedRect>
// Vector of the predicted rects enclosing the same hand RoI for landmark
// detection on the next frame.
// HANDEDNESS - std::vector<ClassificationList>
// Vector of classification of handedness.
// PALM_RECTS - std::vector<NormalizedRect>
// Detected palm bounding boxes in normalized coordinates.
// PALM_DETECTIONS - std::vector<Detection>
// Detected palms with maximum `num_hands` specified in options.
// IMAGE - Image
// The input image that the hand landmarker runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"
// input_stream: "IMAGE:image_in"
// output_stream: "LANDMARKS:hand_landmarks"
// output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
// output_stream: "HANDEDNESS:handedness"
// output_stream: "PALM_RECTS:palm_rects"
// output_stream: "PALM_DETECTIONS:palm_detections"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.hand_landmarker.proto.HandLandmarkerGraphOptions.ext] {
// base_options {
// model_asset {
// file_name: "hand_landmarker.task"
// }
// }
// hand_detector_graph_options {
// base_options {
// model_asset {
// file_name: "palm_detection.tflite"
// }
// }
// min_detection_confidence: 0.5
// num_hands: 2
// }
// hand_landmarks_detector_graph_options {
// base_options {
// model_asset {
// file_name: "hand_landmark_lite.tflite"
// }
// }
// min_detection_confidence: 0.5
// }
// }
// }
// }
class HandLandmarkerGraph : public core::ModelTaskGraph {
public:
// Builds the CalculatorGraphConfig for the hand landmarker task: constructs
// the inner graph via BuildHandLandmarkerGraph, wires its outputs to this
// subgraph's tagged output streams, and then patches the generated config to
// declare the PreviousLoopbackCalculator's back edge.
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto hand_landmarker_outputs,
BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)], graph));
// Expose each output of the inner graph as an output stream of this
// subgraph.
hand_landmarker_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_outputs.world_landmark_lists >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
hand_landmarker_outputs.hand_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kHandRectNextFrameTag)];
hand_landmarker_outputs.handednesses >>
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
hand_landmarker_outputs.palm_rects >>
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
hand_landmarker_outputs.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];
// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
CalculatorGraphConfig config = graph.GetConfig();
// Mark the PreviousLoopbackCalculator's "LOOP" input as a back edge so the
// cycle introduced by the loopback is accepted by the framework. Only the
// first matching node is patched (hence the break).
for (int i = 0; i < config.node_size(); ++i) {
if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
auto* info = config.mutable_node(i)->add_input_stream_info();
info->set_tag_index("LOOP");
info->set_back_edge(true);
break;
}
}
return config;
}
private:
// Adds a mediapipe hand landmark detection graph into the provided
// builder::Graph instance.
//
// tasks_options: the mediapipe tasks module HandLandmarkerGraphOptions.
// image_in: (mediapipe::Image) stream to run hand landmark detection on.
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerGraph(
const HandLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
Graph& graph) {
const int max_num_hands =
tasks_options.hand_detector_graph_options().num_hands();
auto& previous_loopback = graph.AddNode(kPreviousLoopbackCalculatorName);
image_in >> previous_loopback.In("MAIN");
auto prev_hand_rects_from_landmarks =
previous_loopback[Output<std::vector<NormalizedRect>>("PREV_LOOP")];
auto& min_size_node =
graph.AddNode("NormalizedRectVectorHasMinSizeCalculator");
prev_hand_rects_from_landmarks >> min_size_node.In("ITERABLE");
min_size_node.GetOptions<CollectionHasMinSizeCalculatorOptions>()
.set_min_size(max_num_hands);
auto has_enough_hands = min_size_node.Out("").Cast<bool>();
auto image_for_hand_detector =
DisallowIf(image_in, has_enough_hands, graph);
auto& hand_detector =
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
tasks_options.hand_detector_graph_options());
image_for_hand_detector >> hand_detector.In("IMAGE");
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
auto& hand_association = graph.AddNode("HandAssociationCalculator");
hand_association.GetOptions<HandAssociationCalculatorOptions>()
.set_min_similarity_threshold(tasks_options.min_tracking_confidence());
prev_hand_rects_from_landmarks >>
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
hand_rects_from_hand_detector >>
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
auto hand_rects = hand_association.Out("");
auto& clip_hand_rects =
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
.set_max_vec_size(max_num_hands);
hand_rects >> clip_hand_rects.In("");
auto clipped_hand_rects = clip_hand_rects.Out("");
auto& hand_landmarks_detector_graph = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker."
"MultipleHandLandmarksDetectorGraph");
hand_landmarks_detector_graph
.GetOptions<HandLandmarksDetectorGraphOptions>()
.CopyFrom(tasks_options.hand_landmarks_detector_graph_options());
image_in >> hand_landmarks_detector_graph.In("IMAGE");
clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT");
auto hand_rects_for_next_frame =
hand_landmarks_detector_graph[Output<std::vector<NormalizedRect>>(
kHandRectNextFrameTag)];
// Back edge.
hand_rects_for_next_frame >> previous_loopback.In("LOOP");
// TODO: Replace PassThroughCalculator with a calculator that
// converts the pixel data to be stored on the target storage (CPU vs GPU).
auto& pass_through = graph.AddNode("PassThroughCalculator");
image_in >> pass_through.In("");
return {{
/* landmark_lists= */ hand_landmarks_detector_graph
[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
/* world_landmark_lists= */
hand_landmarks_detector_graph[Output<std::vector<LandmarkList>>(
kWorldLandmarksTag)],
/* hand_rects_next_frame= */ hand_rects_for_next_frame,
hand_landmarks_detector_graph[Output<std::vector<ClassificationList>>(
kHandednessTag)],
/* palm_rects= */
hand_detector[Output<std::vector<NormalizedRect>>(kPalmRectsTag)],
/* palm_detections */
hand_detector[Output<std::vector<Detection>>(kPalmDetectionsTag)],
/* image */
pass_through[Output<Image>("")],
}};
}
};
// Registers the graph so it can be instantiated by its fully-qualified name
// from a CalculatorGraphConfig.
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerGraph);
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,167 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarkerGraphOptions;
using ::testing::EqualsProto;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
// Test asset locations.
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite";
constexpr char kLeftHandsImage[] = "left_hands.jpg";

// Graph input/output stream tags and the packet names bound to them.
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image_in";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kLandmarksName[] = "landmarks";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kWorldLandmarksName[] = "world_landmarks";
constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
constexpr char kHandRectNextFrameName[] = "hand_rect_next_frame";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kHandednessName[] = "handedness";

// Expected hand landmarks positions, in text proto format.
constexpr char kExpectedLeftUpHandLandmarksFilename[] =
    "expected_left_up_hand_landmarks.prototxt";
constexpr char kExpectedLeftDownHandLandmarksFilename[] =
    "expected_left_down_hand_landmarks.prototxt";

// Comparison tolerances and graph configuration used by the test.
constexpr float kFullModelFractionDiff = 0.03;  // percentage
constexpr float kAbsMargin = 0.03;
constexpr int kMaxNumHands = 2;
constexpr float kMinTrackingConfidence = 0.5;
// Loads the golden landmark list stored as a text proto under the test data
// directory.
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
  const std::string path = file::JoinPath("./", kTestDataDirectory, filename);
  NormalizedLandmarkList landmarks;
  MP_EXPECT_OK(GetTextProto(path, &landmarks, Defaults()));
  return landmarks;
}
// Helper function to create a Hand Landmarker TaskRunner.
// Helper function to create a Hand Landmarker TaskRunner.
//
// Builds a graph with a single HandLandmarkerGraph node configured with the
// palm detection and full hand landmark models from the test data directory,
// then exposes the IMAGE input and the LANDMARKS / WORLD_LANDMARKS /
// HANDEDNESS / HAND_RECT_NEXT_FRAME output streams.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
  Graph graph;
  auto& hand_landmarker_graph = graph.AddNode(
      "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
  auto& options =
      hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
  // Hand detector options: palm detection model and max number of hands.
  options.mutable_hand_detector_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel));
  options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
  // Hand landmarks detector options: full landmark model.
  options.mutable_hand_landmarks_detector_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(
          JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel));
  options.set_min_tracking_confidence(kMinTrackingConfidence);

  // Wire the graph-level streams to the subgraph node.
  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
      hand_landmarker_graph.In(kImageTag);
  hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >>
      graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
  hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >>
      graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
  hand_landmarker_graph.Out(kHandednessTag).SetName(kHandednessName) >>
      graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
  hand_landmarker_graph.Out(kHandRectNextFrameTag)
          .SetName(kHandRectNextFrameName) >>
      graph[Output<std::vector<NormalizedRect>>(kHandRectNextFrameTag)];
  return TaskRunner::Create(
      graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
}
// Test fixture; inherits TFLite-shims test setup, no extra state needed.
class HandLandmarkerTest : public tflite_shims::testing::Test {};
// Runs HandLandmarkerGraph on an image with two left hands and compares the
// detected landmarks against golden text protos.
TEST_F(HandLandmarkerTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage)));
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
  auto output_packets =
      task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
  // Fail the test with the underlying status instead of crashing on a bad
  // StatusOr dereference when the graph run fails.
  MP_ASSERT_OK(output_packets);
  const auto& landmarks = (*output_packets)[kLandmarksName]
                              .Get<std::vector<NormalizedLandmarkList>>();
  ASSERT_EQ(landmarks.size(), kMaxNumHands);
  std::vector<NormalizedLandmarkList> expected_landmarks = {
      GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename),
      GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename)};
  // Compare with both an absolute margin and a per-value fractional
  // tolerance suitable for the full model's precision.
  EXPECT_THAT(landmarks[0],
              Approximately(Partially(EqualsProto(expected_landmarks[0])),
                            /*margin=*/kAbsMargin,
                            /*fraction=*/kFullModelFractionDiff));
  EXPECT_THAT(landmarks[1],
              Approximately(Partially(EqualsProto(expected_landmarks[1])),
                            /*margin=*/kAbsMargin,
                            /*fraction=*/kFullModelFractionDiff));
}
} // namespace
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -34,12 +34,13 @@ limitations under the License.
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
@ -48,6 +49,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
@ -55,9 +57,10 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::AllowIf;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarkerSubgraphOptions;
HandLandmarksDetectorGraphOptions;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
constexpr char kImageTag[] = "IMAGE";
@ -82,7 +85,6 @@ struct SingleHandLandmarkerOutputs {
Source<bool> hand_presence;
Source<float> hand_presence_score;
Source<ClassificationList> handedness;
Source<std::pair<int, int>> image_size;
};
struct HandLandmarkerOutputs {
@ -92,10 +94,10 @@ struct HandLandmarkerOutputs {
Source<std::vector<bool>> presences;
Source<std::vector<float>> presence_scores;
Source<std::vector<ClassificationList>> handednesses;
Source<std::pair<int, int>> image_size;
};
absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) {
absl::Status SanityCheckOptions(
const HandLandmarksDetectorGraphOptions& options) {
if (options.min_detection_confidence() < 0 ||
options.min_detection_confidence() > 1) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
@ -182,8 +184,8 @@ void ConfigureHandRectTransformationCalculator(
} // namespace
// A "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" performs hand
// landmark detection.
// A "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph"
// performs hand landmarks detection.
// - Accepts CPU input images and outputs Landmark on CPU.
//
// Inputs:
@ -208,12 +210,11 @@ void ConfigureHandRectTransformationCalculator(
// Float value indicates the probability that the hand is present.
// HANDEDNESS - ClassificationList
// Classification of handedness.
// IMAGE_SIZE - std::vector<int, int>
// The size of input image.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"
// calculator:
// "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph"
// input_stream: "IMAGE:input_image"
// input_stream: "HAND_RECT:hand_rect"
// output_stream: "LANDMARKS:hand_landmarks"
@ -221,10 +222,8 @@ void ConfigureHandRectTransformationCalculator(
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
// output_stream: "PRESENCE:hand_presence"
// output_stream: "PRESENCE_SCORE:hand_presence_score"
// output_stream: "HANDEDNESS:handedness"
// output_stream: "IMAGE_SIZE:image_size"
// options {
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext]
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
@ -235,16 +234,17 @@ void ConfigureHandRectTransformationCalculator(
// }
// }
// }
class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<HandLandmarkerSubgraphOptions>(sc));
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<HandLandmarksDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto hand_landmark_detection_outs,
BuildSingleHandLandmarkerSubgraph(
sc->Options<HandLandmarkerSubgraphOptions>(),
BuildSingleHandLandmarksDetectorGraph(
sc->Options<HandLandmarksDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kHandRectTag)], graph));
hand_landmark_detection_outs.hand_landmarks >>
@ -259,8 +259,6 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
graph[Output<float>(kPresenceScoreTag)];
hand_landmark_detection_outs.handedness >>
graph[Output<ClassificationList>(kHandednessTag)];
hand_landmark_detection_outs.image_size >>
graph[Output<std::pair<int, int>>(kImageSizeTag)];
return graph.GetConfig();
}
@ -269,14 +267,16 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
// Adds a mediapipe hand landmark detection graph into the provided
// builder::Graph instance.
//
// subgraph_options: the mediapipe tasks module HandLandmarkerSubgraphOptions.
// model_resources: the ModelSources object initialized from a hand landmark
// subgraph_options: the mediapipe tasks module
// HandLandmarksDetectorGraphOptions. model_resources: the ModelSources object
// initialized from a hand landmark
// detection model file with model metadata.
// image_in: (mediapipe::Image) stream to run hand landmark detection on.
// rect: (NormalizedRect) stream to run on the RoI of image.
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<SingleHandLandmarkerOutputs> BuildSingleHandLandmarkerSubgraph(
const HandLandmarkerSubgraphOptions& subgraph_options,
absl::StatusOr<SingleHandLandmarkerOutputs>
BuildSingleHandLandmarksDetectorGraph(
const HandLandmarksDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> hand_rect, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
@ -332,18 +332,7 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
// score of hand presence.
auto& tensors_to_hand_presence = graph.AddNode("TensorsToFloatsCalculator");
hand_flag_tensors >> tensors_to_hand_presence.In("TENSORS");
// Converts the handedness tensor into a float that represents the
// classification score of handedness.
auto& tensors_to_handedness =
graph.AddNode("TensorsToClassificationCalculator");
ConfigureTensorsToHandednessCalculator(
&tensors_to_handedness.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>());
handedness_tensors >> tensors_to_handedness.In("TENSORS");
auto hand_presence_score = tensors_to_hand_presence[Output<float>("FLOAT")];
auto handedness =
tensors_to_handedness[Output<ClassificationList>("CLASSIFICATIONS")];
// Applies a threshold to the confidence score to determine whether a
// hand is present.
@ -354,6 +343,18 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
hand_presence_score >> hand_presence_thresholding.In("FLOAT");
auto hand_presence = hand_presence_thresholding[Output<bool>("FLAG")];
// Converts the handedness tensor into a float that represents the
// classification score of handedness.
auto& tensors_to_handedness =
graph.AddNode("TensorsToClassificationCalculator");
ConfigureTensorsToHandednessCalculator(
&tensors_to_handedness.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>());
handedness_tensors >> tensors_to_handedness.In("TENSORS");
auto handedness = AllowIf(
tensors_to_handedness[Output<ClassificationList>("CLASSIFICATIONS")],
hand_presence, graph);
// Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed
// hand image (after image transformation with the FIT scale mode) to the
// corresponding locations on the same image with the letterbox removed
@ -371,8 +372,9 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
landmark_letterbox_removal.Out("LANDMARKS") >>
landmark_projection.In("NORM_LANDMARKS");
hand_rect >> landmark_projection.In("NORM_RECT");
auto projected_landmarks =
landmark_projection[Output<NormalizedLandmarkList>("NORM_LANDMARKS")];
auto projected_landmarks = AllowIf(
landmark_projection[Output<NormalizedLandmarkList>("NORM_LANDMARKS")],
hand_presence, graph);
// Projects the world landmarks from the cropped hand image to the
// corresponding locations on the full image before cropping (input to the
@ -383,7 +385,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
world_landmark_projection.In("LANDMARKS");
hand_rect >> world_landmark_projection.In("NORM_RECT");
auto projected_world_landmarks =
world_landmark_projection[Output<LandmarkList>("LANDMARKS")];
AllowIf(world_landmark_projection[Output<LandmarkList>("LANDMARKS")],
hand_presence, graph);
// Converts the hand landmarks into a rectangle (normalized by image size)
// that encloses the hand.
@ -403,7 +406,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
hand_landmarks_to_rect.Out("NORM_RECT") >>
hand_rect_transformation.In("NORM_RECT");
auto hand_rect_next_frame =
hand_rect_transformation[Output<NormalizedRect>("")];
AllowIf(hand_rect_transformation[Output<NormalizedRect>("")],
hand_presence, graph);
return {{
/* hand_landmarks= */ projected_landmarks,
@ -412,16 +416,17 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
/* hand_presence= */ hand_presence,
/* hand_presence_score= */ hand_presence_score,
/* handedness= */ handedness,
/* image_size= */ image_size,
}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::SingleHandLandmarkerSubgraph);
::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarksDetectorGraph); // NOLINT
// clang-format on
// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi
// hand landmark detection.
// A "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph"
// performs multi hand landmark detection.
// - Accepts CPU input image and a vector of hand rect RoIs to detect the
// multiple hands landmarks enclosed by the RoIs. Output vectors of
// hand landmarks related results, where each element in the vectors
@ -449,12 +454,11 @@ REGISTER_MEDIAPIPE_GRAPH(
// Vector of float value indicates the probability that the hand is present.
// HANDEDNESS - std::vector<ClassificationList>
// Vector of classification of handedness.
// IMAGE_SIZE - std::vector<int, int>
// The size of input image.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.HandLandmarkerSubgraph"
// calculator:
// "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph"
// input_stream: "IMAGE:input_image"
// input_stream: "HAND_RECT:hand_rect"
// output_stream: "LANDMARKS:hand_landmarks"
@ -463,9 +467,8 @@ REGISTER_MEDIAPIPE_GRAPH(
// output_stream: "PRESENCE:hand_presence"
// output_stream: "PRESENCE_SCORE:hand_presence_score"
// output_stream: "HANDEDNESS:handedness"
// output_stream: "IMAGE_SIZE:image_size"
// options {
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext]
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
@ -476,15 +479,15 @@ REGISTER_MEDIAPIPE_GRAPH(
// }
// }
// }
class HandLandmarkerSubgraph : public core::ModelTaskGraph {
class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto hand_landmark_detection_outputs,
BuildHandLandmarkerSubgraph(
sc->Options<HandLandmarkerSubgraphOptions>(),
BuildHandLandmarksDetectorGraph(
sc->Options<HandLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<std::vector<NormalizedRect>>(kHandRectTag)], graph));
hand_landmark_detection_outputs.landmark_lists >>
@ -499,21 +502,20 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph {
graph[Output<std::vector<float>>(kPresenceScoreTag)];
hand_landmark_detection_outputs.handednesses >>
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
hand_landmark_detection_outputs.image_size >>
graph[Output<std::pair<int, int>>(kImageSizeTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerSubgraph(
const HandLandmarkerSubgraphOptions& subgraph_options,
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarksDetectorGraph(
const HandLandmarksDetectorGraphOptions& subgraph_options,
Source<Image> image_in,
Source<std::vector<NormalizedRect>> multi_hand_rects, Graph& graph) {
auto& hand_landmark_subgraph =
graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph");
hand_landmark_subgraph.GetOptions<HandLandmarkerSubgraphOptions>().CopyFrom(
subgraph_options);
auto& hand_landmark_subgraph = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker."
"SingleHandLandmarksDetectorGraph");
hand_landmark_subgraph.GetOptions<HandLandmarksDetectorGraphOptions>()
.CopyFrom(subgraph_options);
auto& begin_loop_multi_hand_rects =
graph.AddNode("BeginLoopNormalizedRectCalculator");
@ -533,8 +535,6 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph {
hand_landmark_subgraph.Out("HAND_RECT_NEXT_FRAME");
auto landmarks = hand_landmark_subgraph.Out("LANDMARKS");
auto world_landmarks = hand_landmark_subgraph.Out("WORLD_LANDMARKS");
auto image_size =
hand_landmark_subgraph[Output<std::pair<int, int>>("IMAGE_SIZE")];
auto& end_loop_handedness =
graph.AddNode("EndLoopClassificationListCalculator");
@ -585,13 +585,16 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph {
/* presences= */ presences,
/* presence_scores= */ presence_scores,
/* handednesses= */ handednesses,
/* image_size= */ image_size,
}};
}
};
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkerSubgraph);
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::hand_landmarker::MultipleHandLandmarksDetectorGraph); // NOLINT
// clang-format on
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -39,12 +39,13 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
namespace {
using ::file::Defaults;
@ -57,7 +58,7 @@ using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarkerSubgraphOptions;
HandLandmarksDetectorGraphOptions;
using ::testing::ElementsAreArray;
using ::testing::EqualsProto;
using ::testing::Pointwise;
@ -112,13 +113,14 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& hand_landmark_detection =
graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph");
auto& hand_landmark_detection = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker."
"SingleHandLandmarksDetectorGraph");
auto options = std::make_unique<HandLandmarkerSubgraphOptions>();
auto options = std::make_unique<HandLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
hand_landmark_detection.GetOptions<HandLandmarkerSubgraphOptions>().Swap(
hand_landmark_detection.GetOptions<HandLandmarksDetectorGraphOptions>().Swap(
options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
@ -151,13 +153,14 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& multi_hand_landmark_detection =
graph.AddNode("mediapipe.tasks.vision.HandLandmarkerSubgraph");
auto& multi_hand_landmark_detection = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker."
"MultipleHandLandmarksDetectorGraph");
auto options = std::make_unique<HandLandmarkerSubgraphOptions>();
auto options = std::make_unique<HandLandmarksDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
multi_hand_landmark_detection.GetOptions<HandLandmarkerSubgraphOptions>()
multi_hand_landmark_detection.GetOptions<HandLandmarksDetectorGraphOptions>()
.Swap(options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
@ -462,6 +465,7 @@ INSTANTIATE_TEST_SUITE_P(
});
} // namespace
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -21,8 +21,8 @@ package(default_visibility = [
licenses(["notice"])
mediapipe_proto_library(
name = "hand_landmarker_subgraph_options_proto",
srcs = ["hand_landmarker_subgraph_options.proto"],
name = "hand_landmarks_detector_graph_options_proto",
srcs = ["hand_landmarks_detector_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
@ -31,13 +31,13 @@ mediapipe_proto_library(
)
mediapipe_proto_library(
name = "hand_landmarker_options_proto",
srcs = ["hand_landmarker_options.proto"],
name = "hand_landmarker_graph_options_proto",
srcs = ["hand_landmarker_graph_options.proto"],
deps = [
":hand_landmarker_subgraph_options_proto",
":hand_landmarks_detector_graph_options_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
],
)

View File

@ -19,22 +19,26 @@ package mediapipe.tasks.vision.hand_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto";
import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";
message HandLandmarkerOptions {
message HandLandmarkerGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandLandmarkerOptions ext = 462713202;
optional HandLandmarkerGraphOptions ext = 462713202;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English.
optional string display_names_locale = 2 [default = "en"];
// Options for hand detector graph.
optional hand_detector.proto.HandDetectorGraphOptions
hand_detector_graph_options = 2;
optional hand_detector.proto.HandDetectorOptions hand_detector_options = 3;
// Options for hand landmarker subgraph.
optional HandLandmarksDetectorGraphOptions
hand_landmarks_detector_graph_options = 3;
optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 4;
// Minimum confidence for hand landmarks tracking to be considered
// successfully.
optional float min_tracking_confidence = 4 [default = 0.5];
}

View File

@ -20,19 +20,15 @@ package mediapipe.tasks.vision.hand_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message HandLandmarkerSubgraphOptions {
message HandLandmarksDetectorGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandLandmarkerSubgraphOptions ext = 474472470;
optional HandLandmarksDetectorGraphOptions ext = 474472470;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English.
optional string display_names_locale = 2 [default = "en"];
// Minimum confidence value ([0.0, 1.0]) for the hand presence score to be
// considered a successful hand detection in the image.
optional float min_detection_confidence = 3 [default = 0.5];
optional float min_detection_confidence = 2 [default = 0.5];
}

View File

@ -0,0 +1,21 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
filegroup(
name = "resource_files",
srcs = glob(["res/**"]),
visibility = ["//mediapipe/tasks/examples/android:__subpackages__"],
)

View File

@ -0,0 +1,37 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.examples.objectdetector">
<uses-sdk
android:minSdkVersion="28"
android:targetSdkVersion="30" />
<!-- For loading images from gallery -->
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<!-- For logging solution events -->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="MediaPipe Tasks Object Detector"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/AppTheme"
android:exported="false">
<activity android:name=".MainActivity"
android:screenOrientation="portrait"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@ -0,0 +1,48 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:private"])
android_binary(
name = "objectdetector",
srcs = glob(["**/*.java"]),
assets = [
"//mediapipe/tasks/testdata/vision:test_models",
],
assets_dir = "",
custom_package = "com.google.mediapipe.tasks.examples.objectdetector",
manifest = "AndroidManifest.xml",
manifest_values = {
"applicationId": "com.google.mediapipe.tasks.examples.objectdetector",
},
multidex = "native",
resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector",
"//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout",
"//third_party:opencv",
"@maven//:androidx_activity_activity",
"@maven//:androidx_concurrent_concurrent_futures",
"@maven//:androidx_exifinterface_exifinterface",
"@maven//:androidx_fragment_fragment",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,236 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.examples.objectdetector;
import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.MediaMetadataRetriever;
import android.os.Bundle;
import android.provider.MediaStore;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.FrameLayout;
import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import androidx.exifinterface.media.ExifInterface;
// ContentResolver dependency
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.io.IOException;
import java.io.InputStream;
/** Main activity of MediaPipe Task Object Detector reference app. */
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity";
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
private ObjectDetector objectDetector;
private enum InputSource {
UNKNOWN,
IMAGE,
VIDEO,
CAMERA,
}
private InputSource inputSource = InputSource.UNKNOWN;
// Image mode demo component.
private ActivityResultLauncher<Intent> imageGetter;
// Video mode demo component.
private ActivityResultLauncher<Intent> videoGetter;
private ObjectDetectionResultImageView imageView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
setupImageModeDemo();
setupVideoModeDemo();
// TODO: Adds live camera demo.
}
/** Sets up the image mode demo. */
private void setupImageModeDemo() {
imageView = new ObjectDetectionResultImageView(this);
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
downscaleBitmap(
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData);
} catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e);
}
if (bitmap != null) {
Image image = new BitmapImageBuilder(bitmap).build();
ObjectDetectionResult detectionResult = objectDetector.detect(image);
imageView.setData(image, detectionResult);
runOnUiThread(() -> imageView.update());
}
}
}
});
Button loadImageButton = findViewById(R.id.button_load_picture);
loadImageButton.setOnClickListener(
v -> {
if (inputSource != InputSource.IMAGE) {
createObjectDetector(RunningMode.IMAGE);
this.inputSource = InputSource.IMAGE;
updateLayout();
}
// Reads images from gallery.
Intent pickImageIntent = new Intent(Intent.ACTION_PICK);
pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*");
imageGetter.launch(pickImageIntent);
});
}
/** Sets up the video mode demo. */
private void setupVideoModeDemo() {
imageView = new ObjectDetectionResultImageView(this);
// The Intent to access gallery and read a video file.
videoGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever();
metaRetriever.setDataSource(this, resultIntent.getData());
long duration =
Long.parseLong(
metaRetriever.extractMetadata(
MediaMetadataRetriever.METADATA_KEY_DURATION));
int numFrames =
Integer.parseInt(
metaRetriever.extractMetadata(
MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT));
long frameIntervalMs = duration / numFrames;
for (int i = 0; i < numFrames; ++i) {
Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
ObjectDetectionResult detectionResult =
objectDetector.detectForVideo(image, frameIntervalMs * i);
// Currently only annotates the detection result on the first video frame and
// display it to verify the correctness.
// TODO: Annotates the detection result on every frame, save the
// annotated frames as a video file, and play back the video afterwards.
if (i == 0) {
imageView.setData(image, detectionResult);
runOnUiThread(() -> imageView.update());
}
}
}
}
});
Button loadVideoButton = findViewById(R.id.button_load_video);
loadVideoButton.setOnClickListener(
v -> {
createObjectDetector(RunningMode.VIDEO);
updateLayout();
this.inputSource = InputSource.VIDEO;
// Reads a video from gallery.
Intent pickVideoIntent = new Intent(Intent.ACTION_PICK);
pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*");
videoGetter.launch(pickVideoIntent);
});
}
private void createObjectDetector(RunningMode mode) {
if (objectDetector != null) {
objectDetector.close();
}
// Initializes a new MediaPipe ObjectDetector instance
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setScoreThreshold(0.5f)
.setMaxResults(5)
.setRunningMode(mode)
.build();
objectDetector = ObjectDetector.createFromOptions(this, options);
}
private void updateLayout() {
// Updates the preview layout.
FrameLayout frameLayout = findViewById(R.id.preview_display_layout);
frameLayout.removeAllViewsInLayout();
imageView.setImageDrawable(null);
frameLayout.addView(imageView);
imageView.setVisibility(View.VISIBLE);
}
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
}

View File

@ -0,0 +1,77 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.examples.objectdetector;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import androidx.appcompat.widget.AppCompatImageView;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
/** An ImageView implementation for displaying {@link ObjectDetectionResult}. */
public class ObjectDetectionResultImageView extends AppCompatImageView {
  private static final String TAG = "ObjectDetectionResultImageView";
  // Bounding-box rendering style.
  private static final int BBOX_COLOR = Color.GREEN;
  private static final int BBOX_THICKNESS = 5; // Pixels
  // Most recently annotated bitmap; pushed to the view by update().
  private Bitmap latest;
  public ObjectDetectionResultImageView(Context context) {
    super(context);
    setScaleType(AppCompatImageView.ScaleType.FIT_CENTER);
  }
  /**
   * Sets an {@link Image} and an {@link ObjectDetectionResult} to render.
   *
   * <p>Annotates the extracted bitmap in place; call {@link #update()} afterwards to show it.
   * Silently does nothing if either argument is null.
   *
   * @param image an {@link Image} object for annotation.
   * @param result an {@link ObjectDetectionResult} object that contains the detection result.
   */
  public void setData(Image image, ObjectDetectionResult result) {
    if (image == null || result == null) {
      return;
    }
    // NOTE(review): assumes BitmapExtractor returns a mutable bitmap — Canvas
    // requires one; confirm BitmapExtractor's contract.
    latest = BitmapExtractor.extract(image);
    Canvas canvas = new Canvas(latest);
    // NOTE(review): this draws the bitmap onto a canvas backed by the same
    // bitmap at identity — looks redundant; confirm before removing.
    canvas.drawBitmap(latest, new Matrix(), null);
    for (int i = 0; i < result.detections().size(); ++i) {
      drawDetectionOnCanvas(result.detections().get(i), canvas);
    }
  }
  /** Updates the image view with the latest {@link ObjectDetectionResult}. */
  public void update() {
    // Safe to call from a background thread; setImageBitmap runs after postInvalidate.
    postInvalidate();
    if (latest != null) {
      setImageBitmap(latest);
    }
  }
  /** Draws one detection's bounding box onto the given canvas. */
  private void drawDetectionOnCanvas(Detection detection, Canvas canvas) {
    // TODO: Draws the category and the score per bounding box.
    // Draws bounding box.
    Paint bboxPaint = new Paint();
    bboxPaint.setColor(BBOX_COLOR);
    bboxPaint.setStyle(Paint.Style.STROKE);
    bboxPaint.setStrokeWidth(BBOX_THICKNESS);
    canvas.drawRect(detection.boundingBox(), bboxPaint);
  }
}

View File

@ -0,0 +1,34 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportHeight="108"
android:viewportWidth="108">
<path
android:fillType="evenOdd"
android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
android:strokeColor="#00000000"
android:strokeWidth="1">
<aapt:attr name="android:fillColor">
<gradient
android:endX="78.5885"
android:endY="90.9159"
android:startX="48.7653"
android:startY="61.0927"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
android:strokeColor="#00000000"
android:strokeWidth="1" />
</vector>

View File

@ -0,0 +1,74 @@
<?xml version="1.0" encoding="utf-8"?>
<vector
android:height="108dp"
android:width="108dp"
android:viewportHeight="108"
android:viewportWidth="108"
xmlns:android="http://schemas.android.com/apk/res/android">
<path android:fillColor="#26A69A"
android:pathData="M0,0h108v108h-108z"/>
<path android:fillColor="#00000000" android:pathData="M9,0L9,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,0L19,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,0L29,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,0L39,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,0L49,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,0L59,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,0L69,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,0L79,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M89,0L89,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M99,0L99,108"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,9L108,9"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,19L108,19"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,29L108,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,39L108,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,49L108,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,59L108,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,69L108,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,79L108,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,89L108,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M0,99L108,99"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,29L89,29"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,39L89,39"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,49L89,49"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,59L89,59"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,69L89,69"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M19,79L89,79"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M29,19L29,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M39,19L39,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M49,19L49,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M59,19L59,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M69,19L69,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
<path android:fillColor="#00000000" android:pathData="M79,19L79,89"
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
</vector>

View File

@ -0,0 +1,40 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical">
<LinearLayout
android:id="@+id/buttons"
android:layout_width="match_parent"
android:layout_height="wrap_content"
style="?android:attr/buttonBarStyle" android:gravity="center"
android:orientation="horizontal">
<Button
android:id="@+id/button_load_picture"
android:layout_width="wrap_content"
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
android:text="@string/load_picture" />
<Button
android:id="@+id/button_load_video"
android:layout_width="wrap_content"
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
android:text="@string/load_video" />
<Button
android:id="@+id/button_start_camera"
android:layout_width="wrap_content"
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
android:text="@string/start_camera" />
</LinearLayout>
<FrameLayout
android:id="@+id/preview_display_layout"
android:layout_width="match_parent"
android:layout_height="match_parent">
<TextView
android:id="@+id/no_view"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center"
android:text="@string/instruction" />
</FrameLayout>
</LinearLayout>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

View File

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background"/>
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
</adaptive-icon>

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 959 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 900 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@ -0,0 +1,6 @@
<resources>
<string name="load_picture" translatable="false">Load Picture</string>
<string name="load_video" translatable="false">Load Video</string>
<string name="start_camera" translatable="false">Start Camera</string>
<string name="instruction" translatable="false">Please press any button above to start</string>
</resources>

View File

@ -0,0 +1,11 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@ -550,6 +550,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1.tflite?generation=1661875947436302"],
)
http_file(
name = "com_google_mediapipe_test_jpg",
sha256 = "798a12a466933842528d8438f553320eebe5137f02650f12dd68706a2f94fb4f",
urls = ["https://storage.googleapis.com/mediapipe-assets/test.jpg?generation=1664672140191116"],
)
http_file(
name = "com_google_mediapipe_test_model_add_op_tflite",
sha256 = "298300ca8a9193b80ada1dca39d36f20bffeebde09e85385049b3bfe7be2272f",