diff --git a/Dockerfile b/Dockerfile index 462dacbd4..4d6c68e7e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,7 @@ RUN pip3 install wheel RUN pip3 install future RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1 RUN pip3 install six==1.14.0 -RUN pip3 install tensorflow==2.2.0 +RUN pip3 install tensorflow RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index f951b506d..b20a87467 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`. } ``` +## Graph Options + +It is possible to specify a "graph options" protobuf for a MediaPipe graph +similar to the [`Calculator Options`](calculators.md#calculator-options) +protobuf specified for a MediaPipe calculator. These "graph options" can be +specified where a graph is invoked, and used to populate calculator options and +subgraph options within the graph. + +In a CalculatorGraphConfig, graph options can be specified for a subgraph +exactly like calculator options, as shown below: + +``` +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + output_stream: "throttled_image" + node_options: { + [type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] { + max_in_flight: 1 + } + } +} + +node { + calculator: "FaceDetectionSubgraph" + input_stream: "IMAGE:throttled_image" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + tensor_width: 192 + tensor_height: 192 + } + } +} +``` + +In a CalculatorGraphConfig, graph options can be accepted and used to populate +calculator options, as shown below: + +``` +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} +} + +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:multi_backend_image" + node_options: { + [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] { + keep_aspect_ratio: true + border_mode: BORDER_ZERO + } + } + option_value: "output_tensor_width:options/tensor_width" + option_value: "output_tensor_height:options/tensor_height" +} + +node { + calculator: "InferenceCalculator" + node_options: { + [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {} + } + option_value: "delegate:options/delegate" + option_value: "model_path:options/model_path" +} +``` + +In this example, the `FaceDetectionSubgraph` accepts graph option protobuf +`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field +values in the calculator options `ImageToTensorCalculatorOptions` and some field +values in the subgraph options `InferenceCalculatorOptions`. The field values +are defined using the `option_value:` syntax. + +In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and +`option_value:` together define the option values for a calculator such as +`ImageToTensorCalculator`. The `node_options:` field defines a set of literal +constant values using the text protobuf syntax. Each `option_value:` field +defines the value for one protobuf field using information from the enclosing +graph, specifically from field values of the graph options of the enclosing +graph. In the example above, the `option_value:` +`"output_tensor_width:options/tensor_width"` defines the field +`ImageToTensorCalculatorOptions.output_tensor_width` using the value of +`FaceDetectionOptions.tensor_width`. 
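For reference, the graph options message used above must declare every field that appears on the right-hand side of an `option_value:` entry. A rough sketch of such a message is shown below; it is illustrative only, and the field names, types, and numbers of the actual `FaceDetectionOptions` proto in the MediaPipe repository may differ:

```
message FaceDetectionOptions {
  // Consumed via "options/tensor_width" and "options/tensor_height" above.
  optional int32 tensor_width = 1;
  optional int32 tensor_height = 2;

  // Additional fields such as delegate and model_path would be declared
  // here, with types matching the destination calculator option fields.
}
```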
+ +The syntax of `option_value:` is similar to the syntax of `input_stream:`. The +syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option +field and the RHS identifies a graph option field. More specifically, the LHS +and RHS each consists of a series of protobuf field names identifying nested +protobuf messages and fields separated by '/'. This is known as the "ProtoPath" +syntax. Nested messages that are referenced in the LHS or RHS must already be +defined in the enclosing protobuf in order to be traversed using +`option_value:`. + ## Cycles diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 1b733ed82..10e6422ba 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import androidx.appcompat.widget.AppCompatEditText; +import android.support.v7.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index c839cf5a2..b11f6b55b 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const { if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) { // EGLSync is failed. Use another synchronization method. // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. 
- glFinish(); + gl_context_->Run([]() { glFinish(); }); } else if (valid_ & kValidAHardwareBuffer) { CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the " "completion function to be set"; diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index ed1686954..7f2cb146c 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -89,10 +89,6 @@ def mediapipe_aar( calculators = calculators, ) - _mediapipe_proto( - name = name + "_proto", - ) - native.genrule( name = name + "_aar_manifest_generator", outs = ["AndroidManifest.xml"], @@ -115,19 +111,10 @@ EOF "//mediapipe/java/com/google/mediapipe/components:java_src", "//mediapipe/java/com/google/mediapipe/framework:java_src", "//mediapipe/java/com/google/mediapipe/glutil:java_src", - "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", - "com/google/mediapipe/formats/proto/ClassificationProto.java", - "com/google/mediapipe/formats/proto/DetectionProto.java", - "com/google/mediapipe/formats/proto/LandmarkProto.java", - "com/google/mediapipe/formats/proto/LocationDataProto.java", - "com/google/mediapipe/proto/CalculatorProto.java", - ] + + ] + mediapipe_java_proto_srcs() + select({ "//conditions:default": [], - "enable_stats_logging": [ - "com/google/mediapipe/proto/MediaPipeLoggingProto.java", - "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", - ], + "enable_stats_logging": mediapipe_logging_java_proto_srcs(), }), manifest = "AndroidManifest.xml", proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], @@ -179,93 +166,6 @@ EOF _aar_with_jni(name, name + "_android_lib") -def _mediapipe_proto(name): - """Generates MediaPipe java proto libraries. - - Args: - name: the name of the target. 
- """ - _proto_java_src_generator( - name = "mediapipe_log_extension_proto", - proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto", - java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java", - srcs = ["//mediapipe/util/analytics:protos_src"], - ) - - _proto_java_src_generator( - name = "mediapipe_logging_enums_proto", - proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto", - java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", - srcs = ["//mediapipe/util/analytics:protos_src"], - ) - - _proto_java_src_generator( - name = "calculator_proto", - proto_src = "mediapipe/framework/calculator.proto", - java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java", - srcs = ["//mediapipe/framework:protos_src"], - ) - - _proto_java_src_generator( - name = "landmark_proto", - proto_src = "mediapipe/framework/formats/landmark.proto", - java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", - srcs = ["//mediapipe/framework/formats:protos_src"], - ) - - _proto_java_src_generator( - name = "rasterization_proto", - proto_src = "mediapipe/framework/formats/annotation/rasterization.proto", - java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", - srcs = ["//mediapipe/framework/formats/annotation:protos_src"], - ) - - _proto_java_src_generator( - name = "location_data_proto", - proto_src = "mediapipe/framework/formats/location_data.proto", - java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - "//mediapipe/framework/formats/annotation:protos_src", - ], - ) - - _proto_java_src_generator( - name = "detection_proto", - proto_src = "mediapipe/framework/formats/detection.proto", - java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - "//mediapipe/framework/formats/annotation:protos_src", - ], - ) - - _proto_java_src_generator( - name = "classification_proto", - proto_src = "mediapipe/framework/formats/classification.proto", - java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - ], - ) - -def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []): - native.genrule( - name = name + "_proto_java_src_generator", - srcs = srcs + [ - "@com_google_protobuf//:lite_well_known_protos", - ], - outs = [java_lite_out], - cmd = "$(location @com_google_protobuf//:protoc) " + - "--proto_path=. --proto_path=$(GENDIR) " + - "--proto_path=$$(pwd)/external/com_google_protobuf/src " + - "--java_out=lite:$(GENDIR) " + proto_src + " && " + - "mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))", - tools = [ - "@com_google_protobuf//:protoc", - ], - ) - def _mediapipe_jni(name, gen_libmediapipe, calculators = []): """Generates MediaPipe jni library. @@ -345,3 +245,93 @@ cp -r lib jni zip -r $$origdir/$(location :{}.aar) jni/*/*.so """.format(android_library, name, name, name, name), ) + +def mediapipe_java_proto_src_extractor(target, src_out, name = ""): + """Extracts the generated MediaPipe java proto source code from the target. + + Args: + target: The java proto lite target to be built and extracted. + src_out: The output java proto src code path. + name: The optional bazel target name. + + Returns: + The output java proto src code path. 
+ """ + + if not name: + name = target.split(":")[-1] + "_proto_java_src_extractor" + src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "") + native.genrule( + name = name + "_proto_java_src_extractor", + srcs = [target], + outs = [src_out], + cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" + + src_out + " $$(dirname $(location " + src_out + "))", + ) + return src_out + +def mediapipe_java_proto_srcs(name = ""): + """Extracts the generated MediaPipe framework java proto source code. + + Args: + name: The optional bazel target name. + + Returns: + The list of the extrated MediaPipe java proto source code. + """ + + proto_src_list = [] + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:calculator_java_proto_lite", + src_out = "com/google/mediapipe/proto/CalculatorProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:landmark_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", + src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:location_data_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:detection_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/DetectionProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:classification_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", + )) + return proto_src_list + +def mediapipe_logging_java_proto_srcs(name = ""): + """Extracts the generated logging-related MediaPipe java proto source code. + + Args: + name: The optional bazel target name. + + Returns: + The list of the extrated MediaPipe logging-related java proto source code. + """ + + proto_src_list = [] + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", + )) + return proto_src_list diff --git a/mediapipe/model_maker/BUILD b/mediapipe/model_maker/BUILD new file mode 100644 index 000000000..cb312072f --- /dev/null +++ b/mediapipe/model_maker/BUILD @@ -0,0 +1,22 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +package_group( + name = "internal", + packages = [ + "//mediapipe/model_maker/...", + ], +) diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/BUILD b/mediapipe/model_maker/python/BUILD new file mode 100644 index 000000000..cb312072f --- /dev/null +++ b/mediapipe/model_maker/python/BUILD @@ -0,0 +1,22 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +package_group( + name = "internal", + packages = [ + "//mediapipe/model_maker/...", + ], +) diff --git a/mediapipe/model_maker/python/__init__.py b/mediapipe/model_maker/python/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/BUILD b/mediapipe/model_maker/python/core/BUILD new file mode 100644 index 000000000..10aef8c33 --- /dev/null +++ b/mediapipe/model_maker/python/core/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) diff --git a/mediapipe/model_maker/python/core/__init__.py b/mediapipe/model_maker/python/core/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD new file mode 100644 index 000000000..c4c659d56 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -0,0 +1,68 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. 
+ +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "data_util", + srcs = ["data_util.py"], + srcs_version = "PY3", +) + +py_test( + name = "data_util_test", + srcs = ["data_util_test.py"], + data = ["//mediapipe/model_maker/python/core/data/testdata"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":data_util"], +) + +py_library( + name = "dataset", + srcs = ["dataset.py"], + srcs_version = "PY3", +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":dataset", + "//mediapipe/model_maker/python/core/utils:test_util", + ], +) + +py_library( + name = "classification_dataset", + srcs = ["classification_dataset.py"], + srcs_version = "PY3", + deps = [":dataset"], +) + +py_test( + name = "classification_dataset_test", + srcs = ["classification_dataset_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":classification_dataset"], +) diff --git a/mediapipe/model_maker/python/core/data/__init__.py b/mediapipe/model_maker/python/core/data/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/data/classification_dataset.py b/mediapipe/model_maker/python/core/data/classification_dataset.py new file mode 100644 index 000000000..9075e46eb --- /dev/null +++ b/mediapipe/model_maker/python/core/data/classification_dataset.py @@ -0,0 +1,47 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common classification dataset library.""" + +from typing import Any, Tuple + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds + + +class ClassificationDataset(ds.Dataset): + """DataLoader for classification models.""" + + def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any): + super().__init__(dataset, size) + self.index_to_label = index_to_label + + @property + def num_classes(self: ds._DatasetT) -> int: + return len(self.index_to_label) + + def split(self: ds._DatasetT, + fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]: + """Splits dataset into two sub-datasets with the given fraction. + + Primarily used for splitting the data set into training and testing sets. 
+ + Args: + fraction: float, demonstrates the fraction of the first returned + subdataset in the original data. + + Returns: + The splitted two sub datasets. + """ + return self._split(fraction, self.index_to_label) diff --git a/mediapipe/model_maker/python/core/data/classification_dataset_test.py b/mediapipe/model_maker/python/core/data/classification_dataset_test.py new file mode 100644 index 000000000..f8688ab14 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/classification_dataset_test.py @@ -0,0 +1,68 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset + + +class ClassificationDataLoaderTest(tf.test.TestCase): + + def test_split(self): + + class MagicClassificationDataLoader( + classification_dataset.ClassificationDataset): + + def __init__(self, dataset, size, index_to_label, value): + super(MagicClassificationDataLoader, + self).__init__(dataset, size, index_to_label) + self.value = value + + def split(self, fraction): + return self._split(fraction, self.index_to_label, self.value) + + # Some dummy inputs. + magic_value = 42 + num_classes = 2 + index_to_label = (False, True) + + # Create data loader from sample data. + ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) + data = MagicClassificationDataLoader(ds, len(ds), index_to_label, + magic_value) + + # Train/Test data split. + fraction = .25 + train_data, test_data = data.split(fraction) + + # `split` should return instances of child DataLoader. + self.assertIsInstance(train_data, MagicClassificationDataLoader) + self.assertIsInstance(test_data, MagicClassificationDataLoader) + + # Make sure number of entries are right. + self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data)) + self.assertLen(train_data, fraction * len(ds)) + self.assertLen(test_data, len(ds) - len(train_data)) + + # Make sure attributes propagated correctly. + self.assertEqual(train_data.num_classes, num_classes) + self.assertEqual(test_data.index_to_label, index_to_label) + self.assertEqual(train_data.value, magic_value) + self.assertEqual(test_data.value, magic_value) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/data_util.py b/mediapipe/model_maker/python/core/data/data_util.py new file mode 100644 index 000000000..8c6b9145f --- /dev/null +++ b/mediapipe/model_maker/python/core/data/data_util.py @@ -0,0 +1,35 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Data utility library.""" + +import cv2 +import numpy as np +import tensorflow as tf + + +def load_image(path: str) -> np.ndarray: + """Loads an image as an RGB numpy array. + + Args: + path: input image file absolute path. + + Returns: + An RGB image in numpy.ndarray. + """ + tf.compat.v1.logging.info('Loading RGB image %s', path) + # TODO Replace the OpenCV image load and conversion library by + # MediaPipe image utility library once it is ready. + image = cv2.imread(path) + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) diff --git a/mediapipe/model_maker/python/core/data/data_util_test.py b/mediapipe/model_maker/python/core/data/data_util_test.py new file mode 100644 index 000000000..56ac832c3 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/data_util_test.py @@ -0,0 +1,44 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +from absl import flags +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import data_util + +_WORKSPACE = "mediapipe" +_TEST_DATA_DIR = os.path.join( + _WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata') + +FLAGS = flags.FLAGS + + +class DataUtilTest(tf.test.TestCase): + + def test_load_rgb_image(self): + image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg') + image_data = data_util.load_image(image_path) + self.assertEqual(image_data.shape, (5184, 3456, 3)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py new file mode 100644 index 000000000..a92b05c0d --- /dev/null +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -0,0 +1,164 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
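As an aside on the `data_util` module above, here is a minimal usage sketch for `load_image`; the image path is hypothetical and any readable image file would do:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import data_util

# Hypothetical path; load_image returns an RGB numpy array.
rgb_array = data_util.load_image('/tmp/example.jpg')
print(rgb_array.shape)  # (height, width, 3)

# The numpy array can be handed straight to TensorFlow if needed.
image_tensor = tf.convert_to_tensor(rgb_array)
```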
+"""Common dataset for model training and evaluation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from typing import Callable, Optional, Tuple, TypeVar + +# Dependency imports +import tensorflow as tf + +_DatasetT = TypeVar('_DatasetT', bound='Dataset') + + +class Dataset(object): + """A generic dataset class for loading model training and evaluation dataset. + + For each ML task, such as image classification, text classification etc., a + subclass can be derived from this class to provide task-specific data loading + utilities. + """ + + def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None): + """Initializes Dataset class. + + To build dataset from raw data, consider using the task specific utilities, + e.g. from_folder(). + + Args: + tf_dataset: A tf.data.Dataset object that contains a potentially large set + of elements, where each element is a pair of (input_data, target). The + `input_data` means the raw input data, like an image, a text etc., while + the `target` means the ground truth of the raw input data, e.g. the + classification label of the image etc. + size: The size of the dataset. tf.data.Dataset donesn't support a function + to get the length directly since it's lazy-loaded and may be infinite. + """ + self._dataset = tf_dataset + self._size = size + + @property + def size(self) -> Optional[int]: + """Returns the size of the dataset. + + Note that this function may return None becuase the exact size of the + dataset isn't a necessary parameter to create an instance of this class, + and tf.data.Dataset donesn't support a function to get the length directly + since it's lazy-loaded and may be infinite. + In most cases, however, when an instance of this class is created by helper + functions like 'from_folder', the size of the dataset will be preprocessed, + and this function can return an int representing the size of the dataset. + """ + return self._size + + def gen_tf_dataset(self, + batch_size: int = 1, + is_training: bool = False, + shuffle: bool = False, + preprocess: Optional[Callable[..., bool]] = None, + drop_remainder: bool = False) -> tf.data.Dataset: + """Generates a batched tf.data.Dataset for training/evaluation. + + Args: + batch_size: An integer, the returned dataset will be batched by this size. + is_training: A boolean, when True, the returned dataset will be optionally + shuffled and repeated as an endless dataset. + shuffle: A boolean, when True, the returned dataset will be shuffled to + create randomness during model training. + preprocess: A function taking three arguments in order, feature, label and + boolean is_training. + drop_remainder: boolean, whether the finaly batch drops remainder. + + Returns: + A TF dataset ready to be consumed by Keras model. + """ + dataset = self._dataset + + if preprocess: + preprocess = functools.partial(preprocess, is_training=is_training) + dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE) + + if is_training: + if shuffle: + # Shuffle size should be bigger than the batch_size. Otherwise it's only + # shuffling within the batch, which equals to not having shuffle. + buffer_size = 3 * batch_size + # But since we are doing shuffle before repeat, it doesn't make sense to + # shuffle more than total available entries. + # TODO: Investigate if shuffling before / after repeat + # dataset can get a better performance? 
+ # Shuffle after repeat will give a more randomized dataset and mix the + # epoch boundary: https://www.tensorflow.org/guide/data + if self._size: + buffer_size = min(self._size, buffer_size) + dataset = dataset.shuffle(buffer_size=buffer_size) + + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + # TODO: Consider converting dataset to distributed dataset + # here. + return dataset + + def __len__(self): + """Returns the number of element of the dataset.""" + if self._size is not None: + return self._size + else: + return len(self._dataset) + + def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]: + """Splits dataset into two sub-datasets with the given fraction. + + Primarily used for splitting the data set into training and testing sets. + + Args: + fraction: A float value defines the fraction of the first returned + subdataset in the original data. + + Returns: + The splitted two sub datasets. + """ + return self._split(fraction) + + def _split(self: _DatasetT, fraction: float, + *args) -> Tuple[_DatasetT, _DatasetT]: + """Implementation for `split` method and returns sub-class instances. + + Child DataLoader classes, if requires additional constructor arguments, + should implement their own `split` method by calling `_split` with all + arguments to the constructor. + + Args: + fraction: A float value defines the fraction of the first returned + subdataset in the original data. + *args: additional arguments passed to the sub-class constructor. + + Returns: + The splitted two sub datasets. + """ + assert (fraction > 0 and fraction < 1) + + dataset = self._dataset + + train_size = int(self._size * fraction) + trainset = self.__class__(dataset.take(train_size), train_size, *args) + + test_size = self._size - train_size + testset = self.__class__(dataset.skip(train_size), test_size, *args) + + return trainset, testset diff --git a/mediapipe/model_maker/python/core/data/dataset_test.py b/mediapipe/model_maker/python/core/data/dataset_test.py new file mode 100644 index 000000000..9adff127d --- /dev/null +++ b/mediapipe/model_maker/python/core/data/dataset_test.py @@ -0,0 +1,78 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
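As an aside on the `Dataset` class defined above, the sketch below shows how it is typically driven end to end. This is a minimal illustration with synthetic tensors (the variable names are not part of the API), assuming the class behaves as written above:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds

# Four (feature, label) pairs wrapped into the generic Dataset class.
raw_ds = tf.data.Dataset.from_tensor_slices(
    ([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]], [0, 1, 0, 1]))
data = ds.Dataset(raw_ds, size=4)

# 75%/25% train/test split, then a batched dataset ready for Keras.
train_data, test_data = data.split(0.75)
print(len(train_data), len(test_data))  # 3 1

train_tf_ds = train_data.gen_tf_dataset(
    batch_size=2, is_training=True, shuffle=True)
for features, labels in train_tf_ds:
  print(features.shape, labels.shape)  # (2, 2) (2,) then (1, 2) (1,)
```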
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds +from mediapipe.model_maker.python.core.utils import test_util + + +class DatasetTest(tf.test.TestCase): + + def test_split(self): + dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]) + data = ds.Dataset(dataset, 4) + train_data, test_data = data.split(0.5) + + self.assertLen(train_data, 2) + self.assertIsInstance(train_data, ds.Dataset) + self.assertIsInstance(test_data, ds.Dataset) + for i, elem in enumerate(train_data.gen_tf_dataset()): + self.assertTrue((elem.numpy() == np.array([i, 1])).all()) + + self.assertLen(test_data, 2) + for i, elem in enumerate(test_data.gen_tf_dataset()): + self.assertTrue((elem.numpy() == np.array([i, 0])).all()) + + def test_len(self): + size = 4 + dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]) + data = ds.Dataset(dataset, size) + self.assertLen(data, size) + + def test_gen_tf_dataset(self): + input_dim = 8 + data = test_util.create_dataset( + data_size=2, input_shape=[input_dim], num_classes=2) + + dataset = data.gen_tf_dataset() + self.assertLen(dataset, 2) + for (feature, label) in dataset: + self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all()) + self.assertTrue((tf.shape(label).numpy() == np.array([1])).all()) + + dataset2 = data.gen_tf_dataset(batch_size=2) + self.assertLen(dataset2, 1) + for (feature, label) in dataset2: + self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all()) + self.assertTrue((tf.shape(label).numpy() == np.array([2])).all()) + + dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True) + self.assertEqual(dataset3.cardinality(), 1) + for (feature, label) in dataset3.take(10): + self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all()) + self.assertTrue((tf.shape(label).numpy() == np.array([2])).all()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/testdata/BUILD b/mediapipe/model_maker/python/core/data/testdata/BUILD new file mode 100644 index 000000000..54e562d41 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/testdata/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +package( + default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +mediapipe_files(srcs = ["test.jpg"]) + +filegroup( + name = "testdata", + srcs = ["test.jpg"], +) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD new file mode 100644 index 000000000..e4b18b395 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -0,0 +1,100 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.py"], + srcs_version = "PY3", + deps = [ + ":model_util", + "//mediapipe/model_maker/python/core/data:dataset", + ], +) + +py_library( + name = "image_preprocessing", + srcs = ["image_preprocessing.py"], + srcs_version = "PY3", +) + +py_test( + name = "image_preprocessing_test", + srcs = ["image_preprocessing_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":image_preprocessing"], +) + +py_library( + name = "model_util", + srcs = ["model_util.py"], + srcs_version = "PY3", + deps = [ + ":quantization", + "//mediapipe/model_maker/python/core/data:dataset", + ], +) + +py_test( + name = "model_util_test", + srcs = ["model_util_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":model_util", + ":quantization", + ":test_util", + ], +) + +py_library( + name = "loss_functions", + srcs = ["loss_functions.py"], + srcs_version = "PY3", +) + +py_test( + name = "loss_functions_test", + srcs = ["loss_functions_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":loss_functions"], +) + +py_library( + name = "quantization", + srcs = ["quantization.py"], + srcs_version = "PY3", + deps = ["//mediapipe/model_maker/python/core/data:dataset"], +) + +py_test( + name = "quantization_test", + srcs = ["quantization_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":quantization", + ":test_util", + ], +) diff --git a/mediapipe/model_maker/python/core/utils/__init__.py b/mediapipe/model_maker/python/core/utils/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/utils/image_preprocessing.py b/mediapipe/model_maker/python/core/utils/image_preprocessing.py new file mode 100644 index 000000000..62b34fb27 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/image_preprocessing.py @@ -0,0 +1,228 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +import tensorflow as tf + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +class Preprocessor(object): + """Preprocessor for image classification.""" + + def __init__(self, + input_shape, + num_classes, + mean_rgb, + stddev_rgb, + use_augmentation=False): + self.input_shape = input_shape + self.num_classes = num_classes + self.mean_rgb = mean_rgb + self.stddev_rgb = stddev_rgb + self.use_augmentation = use_augmentation + + def __call__(self, image, label, is_training=True): + if self.use_augmentation: + return self._preprocess_with_augmentation(image, label, is_training) + return self._preprocess_without_augmentation(image, label) + + def _preprocess_with_augmentation(self, image, label, is_training): + """Image preprocessing method with data augmentation.""" + image_size = self.input_shape[0] + if is_training: + image = preprocess_for_train(image, image_size) + else: + image = preprocess_for_eval(image, image_size) + + image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) + + label = tf.one_hot(label, depth=self.num_classes) + return image, label + + # TODO: Changes to preprocess to support batch input. + def _preprocess_without_augmentation(self, image, label): + """Image preprocessing method without data augmentation.""" + image = tf.cast(image, tf.float32) + + image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) + + image = tf.compat.v1.image.resize(image, self.input_shape) + label = tf.one_hot(label, depth=self.num_classes) + return image, label + + +def _distorted_bounding_box_crop(image, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. 
+ + Args: + image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of + shape [height, width, channels]. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where + each coordinate is [0, 1) and the coordinates are arranged as `[ymin, + xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area + of the image must contain at least this fraction of any bounding box + supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image must + contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + + Returns: + A cropped image `Tensor` + """ + with tf.name_scope('distorted_bounding_box_crop'): + shape = tf.shape(image) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + image = tf.image.crop_to_bounding_box(image, offset_y, offset_x, + target_height, target_width) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _resize_image(image, image_size, method=None): + if method is not None: + tf.compat.v1.logging.info('Use customized resize method {}'.format(method)) + return tf.compat.v1.image.resize([image], [image_size, image_size], + method)[0] + tf.compat.v1.logging.info('Use default resize_bicubic.') + return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0] + + +def _decode_and_random_crop(original_image, image_size, resize_method=None): + """Makes a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = _distorted_bounding_box_crop( + original_image, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. 
/ 3.), + area_range=(0.08, 1.0), + max_attempts=10) + original_shape = tf.shape(original_image) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond(bad, + lambda: _decode_and_center_crop(original_image, image_size), + lambda: _resize_image(image, image_size, resize_method)) + + return image + + +def _decode_and_center_crop(image, image_size, resize_method=None): + """Crops to center of image with padding then scales image_size.""" + shape = tf.shape(image) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, + padded_center_crop_size, + padded_center_crop_size) + image = _resize_image(image, image_size, resize_method) + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train( + image: tf.Tensor, + image_size: int = IMAGE_SIZE, + resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor: + """Preprocesses the given image for evaluation. + + Args: + image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of + shape [height, width, channels]. + image_size: image size. + resize_method: resize method. If none, use bicubic. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_random_crop(image, image_size, resize_method) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + + return image + + +def preprocess_for_eval( + image: tf.Tensor, + image_size: int = IMAGE_SIZE, + resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor: + """Preprocesses the given image for evaluation. + + Args: + image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of + shape [height, width, channels]. + image_size: image size. + resize_method: if None, use bicubic. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image, image_size, resize_method) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + return image diff --git a/mediapipe/model_maker/python/core/utils/image_preprocessing_test.py b/mediapipe/model_maker/python/core/utils/image_preprocessing_test.py new file mode 100644 index 000000000..bc4b44569 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/image_preprocessing_test.py @@ -0,0 +1,85 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
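As an aside on `image_preprocessing.py` above: the `Preprocessor` is intended to be handed to `Dataset.gen_tf_dataset` as its `preprocess` callable, which forwards `is_training` via `functools.partial`. The sketch below wires the two together with dummy data; the parameter values are illustrative, not recommended settings:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import image_preprocessing

preprocessor = image_preprocessing.Preprocessor(
    input_shape=[224, 224],
    num_classes=2,
    mean_rgb=[0.0],
    stddev_rgb=[255.0],
    use_augmentation=False)

# Two dummy RGB images with integer class labels.
images = tf.random.uniform([2, 300, 300, 3], maxval=256, dtype=tf.float32)
labels = tf.constant([0, 1])
data = ds.Dataset(tf.data.Dataset.from_tensor_slices((images, labels)), size=2)

# gen_tf_dataset maps the preprocessor over every (image, label) pair.
train_ds = data.gen_tf_dataset(
    batch_size=2, is_training=True, preprocess=preprocessor)
for batch_images, batch_labels in train_ds.take(1):
  print(batch_images.shape, batch_labels.shape)  # (2, 224, 224, 3) (2, 2)
```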
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import image_preprocessing + + +def _get_preprocessed_image(preprocessor, is_training=False): + image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3]) + label_placeholder = tf.compat.v1.placeholder(tf.int32, [1]) + image_tensor, _ = preprocessor(image_placeholder, label_placeholder, + is_training) + + with tf.compat.v1.Session() as sess: + input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3]) + image = sess.run( + image_tensor, + feed_dict={ + image_placeholder: input_image, + label_placeholder: [0] + }) + return image + + +class PreprocessorTest(tf.test.TestCase): + + def test_preprocess_without_augmentation(self): + preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2], + num_classes=2, + mean_rgb=[0.0], + stddev_rgb=[255.0], + use_augmentation=False) + actual_image = np.array([[[0., 0.00392157, 0.00784314], + [0.14117648, 0.14509805, 0.14901961]], + [[0.37647063, 0.3803922, 0.38431376], + [0.5176471, 0.52156866, 0.5254902]]]) + + image = _get_preprocessed_image(preprocessor) + self.assertTrue(np.allclose(image, actual_image, atol=1e-05)) + + def test_preprocess_with_augmentation(self): + image_preprocessing.CROP_PADDING = 1 + preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2], + num_classes=2, + mean_rgb=[0.0], + stddev_rgb=[255.0], + use_augmentation=True) + # Tests validation image. + actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216], + [0.26666668, 0.27058825, 0.27450982]], + [[0.42352945, 0.427451, 0.43137258], + [0.5176471, 0.52156866, 0.5254902]]]) + + image = _get_preprocessed_image(preprocessor, is_training=False) + self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05)) + + # Tests training image. + image1 = _get_preprocessed_image(preprocessor, is_training=True) + image2 = _get_preprocessed_image(preprocessor, is_training=True) + self.assertFalse(np.allclose(image1, image2, atol=1e-05)) + self.assertEqual(image1.shape, (2, 2, 3)) + self.assertEqual(image2.shape, (2, 2, 3)) + + +if __name__ == '__main__': + tf.compat.v1.disable_eager_execution() + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py new file mode 100644 index 000000000..17c738a14 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/loss_functions.py @@ -0,0 +1,105 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss function utility library.""" + +from typing import Optional, Sequence + +import tensorflow as tf + + +class FocalLoss(tf.keras.losses.Loss): + """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf). + + This class computes the focal loss between labels and prediction. 
Focal loss + is a weighted loss function that modulates the standard cross-entropy loss + based on how well the neural network performs on a specific example of a + class. The labels should be provided in a `one_hot` vector representation. + There should be `#classes` floating point values per prediction. + The loss is reduced across all samples using 'sum_over_batch_size' reduction + (see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction). + + Example usage: + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> gamma = 2 + >>> focal_loss = FocalLoss(gamma) + >>> focal_loss(y_true, y_pred).numpy() + 0.9326 + + >>> # Calling with 'sample_weight'. + >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() + 0.6528 + + Usage with the `compile()` API: + ```python + model.compile(optimizer='sgd', loss=FocalLoss(gamma)) + ``` + + """ + + def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None): + """Constructor. + + Args: + gamma: Focal loss gamma, as described in class docs. + class_weight: A weight to apply to the loss, one for each class. The + weight is applied for each input where the ground truth label matches. + """ + super(tf.keras.losses.Loss, self).__init__() + # Used for clipping min/max values of probability values in y_pred to avoid + # NaNs and Infs in computation. + self._epsilon = 1e-7 + # This is a tunable "focusing parameter"; should be >= 0. + # When gamma = 0, the loss returned is the standard categorical + # cross-entropy loss. + self._gamma = gamma + self._class_weight = class_weight + # tf.keras.losses.Loss class implementation requires a Reduction specified + # in self.reduction. To use this reduction, we should use tensorflow's + # compute_weighted_loss function however it is only compatible with v1 of + # Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long + # So even though it is specified here, we don't use self.reduction in the + # loss function call. + self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE + + def __call__(self, + y_true: tf.Tensor, + y_pred: tf.Tensor, + sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor: + if self._class_weight: + class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32) + label = tf.argmax(y_true, axis=1) + loss_weight = tf.gather(class_weight, label) + else: + loss_weight = tf.ones(tf.shape(y_true)[0]) + y_true = tf.cast(y_true, y_pred.dtype) + y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon) + batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype) + if sample_weight is None: + sample_weight = tf.constant(1.0) + weight_shape = sample_weight.shape + weight_rank = weight_shape.ndims + y_pred_rank = y_pred.shape.ndims + if y_pred_rank - weight_rank == 1: + sample_weight = tf.expand_dims(sample_weight, [-1]) + elif weight_rank != 0: + raise ValueError(f'Unexpected sample_weights, should be either a scalar' + f'or a vector of batch_size:{batch_size.numpy()}') + ce = -tf.math.log(y_pred) + modulating_factor = tf.math.pow(1 - y_pred, self._gamma) + losses = y_true * modulating_factor * ce * sample_weight + losses = losses * loss_weight[:, tf.newaxis] + # By default, this function uses "sum_over_batch_size" reduction for the + # loss per batch. 
+ return tf.reduce_sum(losses) / batch_size diff --git a/mediapipe/model_maker/python/core/utils/loss_functions_test.py b/mediapipe/model_maker/python/core/utils/loss_functions_test.py new file mode 100644 index 000000000..716c329ef --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/loss_functions_test.py @@ -0,0 +1,103 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import loss_functions + + +class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='no_sample_weight', sample_weight=None), + dict( + testcase_name='with_sample_weight', + sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2]))) + def test_focal_loss_gamma_0_is_cross_entropy( + self, sample_weight: Optional[tf.Tensor]): + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, + 0]]) + y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4], + [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) + + tf_cce = tf.keras.losses.CategoricalCrossentropy( + from_logits=False, + reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE) + focal_loss = loss_functions.FocalLoss(gamma=0) + self.assertAllClose( + tf_cce(y_true, y_pred, sample_weight=sample_weight), + focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4) + + def test_focal_loss_with_sample_weight(self): + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, + 0]]) + y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4], + [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) + + focal_loss = loss_functions.FocalLoss(gamma=0) + + sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2]) + + self.assertGreater( + focal_loss(y_true=y_true, y_pred=y_pred), + focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)) + + @parameterized.named_parameters( + dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])), + dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])), + dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])), + dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])), + dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])), + ) + def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor): + y_true = tf.constant([[1, 0]]) + + focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0) + loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred) + focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5) + loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred) + focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1) + loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred) + focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2) + loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred) + focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5) + loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred) + + 
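The ordering asserted just below falls out of the modulating factor `(1 - p_t)**gamma`. As a standalone numeric check — a sketch that assumes only TensorFlow and recomputes the loss by hand rather than calling `loss_functions.FocalLoss`:

```python
import tensorflow as tf

y_true = tf.constant([[1.0, 0.0]])   # one-hot ground truth
y_pred = tf.constant([[0.7, 0.3]])   # predicted probabilities

def focal_loss(y_true, y_pred, gamma):
  """-(1 - p_t)**gamma * log(p_t), averaged over the batch."""
  p_t = tf.reduce_sum(y_true * y_pred, axis=-1)
  return tf.reduce_mean(-((1.0 - p_t) ** gamma) * tf.math.log(p_t))

# gamma = 0 reduces to the standard categorical cross-entropy (~0.3567 here).
print(float(focal_loss(y_true, y_pred, gamma=0.0)))
print(float(tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)))

# Increasing gamma shrinks the modulating factor (1 - 0.7)**gamma, so the
# loss decreases monotonically for this well-classified example.
for gamma in (0.0, 0.5, 1.0, 2.0, 5.0):
  print(gamma, float(focal_loss(y_true, y_pred, gamma)))
```

With `y_pred = [0.7, 0.3]` the modulating factor is `0.3**gamma`, so each increase in gamma scales the same `-log(0.7)` term further down, which is the behavior the assertions below verify.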
self.assertGreater(loss_gamma_0, loss_gamma_0p5) + self.assertGreater(loss_gamma_0p5, loss_gamma_1) + self.assertGreater(loss_gamma_1, loss_gamma_2) + self.assertGreater(loss_gamma_2, loss_gamma_5) + + @parameterized.named_parameters( + dict(testcase_name='index_0', true_class=0), + dict(testcase_name='index_1', true_class=1), + dict(testcase_name='index_2', true_class=2), + ) + def test_focal_loss_class_weight_is_applied(self, true_class: int): + class_weight = [1.0, 3.0, 10.0] + y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0 + y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :] + expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class] + + loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight) + loss = loss_fn(y_true, y_pred) + self.assertNear(loss, expected_loss, 1e-4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py new file mode 100644 index 000000000..4914fea57 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -0,0 +1,241 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for keras models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union + +# Dependency imports + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.utils import quantization + +DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 +ESTIMITED_STEPS_PER_EPOCH = 1000 + + +def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, + batch_size: Optional[int] = None, + train_data: Optional[dataset.Dataset] = None) -> int: + """Gets the estimated training steps per epoch. + + 1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly. + 2. Else if we can get the length of training data successfully, returns + `train_data_length // batch_size`. + + Args: + steps_per_epoch: int, training steps per epoch. + batch_size: int, batch size. + train_data: training data. + + Returns: + Estimated training steps per epoch. + + Raises: + ValueError: if both steps_per_epoch and train_data are not set. + """ + if steps_per_epoch is not None: + # steps_per_epoch is set by users manually. + return steps_per_epoch + else: + if train_data is None: + raise ValueError('Input train_data cannot be None.') + # Gets the steps by the length of the training data. + return len(train_data) // batch_size + + +def export_tflite( + model: tf.keras.Model, + tflite_filepath: str, + quantization_config: Optional[quantization.QuantizationConfig] = None, + supported_ops: Tuple[tf.lite.OpsSet, + ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)): + """Converts the model to tflite format and saves it. + + Args: + model: model to be converted to tflite. 
+ tflite_filepath: File path to save tflite model. + quantization_config: Configuration for post-training quantization. + supported_ops: A list of supported ops in the converted TFLite file. + """ + if tflite_filepath is None: + raise ValueError( + "TFLite filepath couldn't be None when exporting to tflite.") + + with tempfile.TemporaryDirectory() as temp_dir: + save_path = os.path.join(temp_dir, 'saved_model') + model.save(save_path, include_optimizer=False, save_format='tf') + converter = tf.lite.TFLiteConverter.from_saved_model(save_path) + + if quantization_config: + converter = quantization_config.set_converter_with_quantization(converter) + + converter.target_spec.supported_ops = supported_ops + tflite_model = converter.convert() + + with tf.io.gfile.GFile(tflite_filepath, 'wb') as f: + f.write(tflite_model) + + +class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): + """Applies a warmup schedule on a given learning rate decay schedule.""" + + def __init__(self, + initial_learning_rate: float, + decay_schedule_fn: Callable[[Any], Any], + warmup_steps: int, + name: Optional[str] = None): + """Initializes a new instance of the `WarmUp` class. + + Args: + initial_learning_rate: learning rate after the warmup. + decay_schedule_fn: A function maps step to learning rate. Will be applied + for values of step larger than 'warmup_steps'. + warmup_steps: Number of steps to do warmup for. + name: TF namescope under which to perform the learning rate calculation. + """ + super(WarmUp, self).__init__() + self.initial_learning_rate = initial_learning_rate + self.warmup_steps = warmup_steps + self.decay_schedule_fn = decay_schedule_fn + self.name = name + + def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor: + with tf.name_scope(self.name or 'WarmUp') as name: + # Implements linear warmup. i.e., if global_step < warmup_steps, the + # learning rate will be `global_step/num_warmup_steps * init_lr`. + global_step_float = tf.cast(step, tf.float32) + warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) + warmup_percent_done = global_step_float / warmup_steps_float + warmup_learning_rate = self.initial_learning_rate * warmup_percent_done + return tf.cond( + global_step_float < warmup_steps_float, + lambda: warmup_learning_rate, + lambda: self.decay_schedule_fn(step), + name=name) + + def get_config(self) -> Dict[Text, Any]: + return { + 'initial_learning_rate': self.initial_learning_rate, + 'decay_schedule_fn': self.decay_schedule_fn, + 'warmup_steps': self.warmup_steps, + 'name': self.name + } + + +class LiteRunner(object): + """A runner to do inference with the TFLite model.""" + + def __init__(self, tflite_filepath: str): + """Initializes Lite runner with tflite model file. + + Args: + tflite_filepath: File path to the TFLite model. + """ + with tf.io.gfile.GFile(tflite_filepath, 'rb') as f: + tflite_model = f.read() + self.interpreter = tf.lite.Interpreter(model_content=tflite_model) + self.interpreter.allocate_tensors() + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + def run( + self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]] + ) -> Union[List[tf.Tensor], tf.Tensor]: + """Runs inference with the TFLite model. + + Args: + input_tensors: List / Dict of the input tensors of the TFLite model. The + order should be the same as the keras model if it's a list. It also + accepts tensor directly if the model has only 1 input. 
+ + Returns: + List of the output tensors for multi-output models, otherwise just + the output tensor. The order should be the same as the keras model. + """ + + if not isinstance(input_tensors, list) and not isinstance( + input_tensors, dict): + input_tensors = [input_tensors] + + interpreter = self.interpreter + + # Reshape inputs + for i, input_detail in enumerate(self.input_details): + input_tensor = _get_input_tensor( + input_tensors=input_tensors, + input_details=self.input_details, + index=i) + interpreter.resize_tensor_input( + input_index=input_detail['index'], tensor_size=input_tensor.shape) + interpreter.allocate_tensors() + + # Feed input to the interpreter + for i, input_detail in enumerate(self.input_details): + input_tensor = _get_input_tensor( + input_tensors=input_tensors, + input_details=self.input_details, + index=i) + if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): + # Quantize the input + scale, zero_point = input_detail['quantization'] + input_tensor = input_tensor / scale + zero_point + input_tensor = np.array(input_tensor, dtype=input_detail['dtype']) + interpreter.set_tensor(input_detail['index'], input_tensor) + + interpreter.invoke() + + output_tensors = [] + for output_detail in self.output_details: + output_tensor = interpreter.get_tensor(output_detail['index']) + if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): + # Dequantize the output + scale, zero_point = output_detail['quantization'] + output_tensor = output_tensor.astype(np.float32) + output_tensor = (output_tensor - zero_point) * scale + output_tensors.append(output_tensor) + + if len(output_tensors) == 1: + return output_tensors[0] + return output_tensors + + +def get_lite_runner(tflite_filepath: str) -> 'LiteRunner': + """Returns a `LiteRunner` from file path to TFLite model.""" + lite_runner = LiteRunner(tflite_filepath) + return lite_runner + + +def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str, + tf.Tensor]], + input_details: Dict[str, Any], index: int) -> tf.Tensor: + """Returns input tensor in `input_tensors` that maps `input_detail[i]`.""" + if isinstance(input_tensors, dict): + # Gets the mapped input tensor. + input_detail = input_details + for input_tensor_name, input_tensor in input_tensors.items(): + if input_tensor_name in input_detail['name']: + return input_tensor + raise ValueError('Input tensors don\'t contains a tensor that mapped the ' + 'input detail %s' % str(input_detail)) + else: + return input_tensors[index] diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py new file mode 100644 index 000000000..9c3908841 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -0,0 +1,137 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
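Taken together, the utilities above form a small export-and-verify pipeline: build a Keras model, write it out with `export_tflite`, and check the TFLite output against the Keras output through a `LiteRunner`. A minimal sketch of that flow — assuming the `mediapipe.model_maker` package is importable; the inline model and the `/tmp` path are illustrative only:

```python
import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

# A tiny Keras classifier: 4 input features, 2 classes.
inputs = tf.keras.layers.Input(shape=[4])
outputs = tf.keras.layers.Dense(2, activation='softmax')(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Export to TFLite (no quantization) and load it back with a LiteRunner.
tflite_path = '/tmp/tiny_model.tflite'
model_util.export_tflite(model, tflite_path)
lite_runner = model_util.get_lite_runner(tflite_path)

# For a single-input, single-output model, run() accepts a tensor directly
# and returns a single output tensor.
x = np.random.uniform(size=(1, 4)).astype(np.float32)
print(lite_runner.run(x))
print(model.predict_on_batch(x))  # should match to within a small tolerance
```

The tests that follow exercise this same flow through the `test_util` helpers, and additionally cover `get_steps_per_epoch` and the `WarmUp` learning-rate schedule.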
+ +import os + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.core.utils import test_util + + +class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='input_only_steps_per_epoch', + steps_per_epoch=1000, + batch_size=None, + train_data=None, + expected_steps_per_epoch=1000), + dict( + testcase_name='input_steps_per_epoch_and_batch_size', + steps_per_epoch=1000, + batch_size=32, + train_data=None, + expected_steps_per_epoch=1000), + dict( + testcase_name='input_steps_per_epoch_batch_size_and_train_data', + steps_per_epoch=1000, + batch_size=32, + train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]), + expected_steps_per_epoch=1000), + dict( + testcase_name='input_batch_size_and_train_data', + steps_per_epoch=None, + batch_size=2, + train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]), + expected_steps_per_epoch=2)) + def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, + expected_steps_per_epoch): + estimated_steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=steps_per_epoch, + batch_size=batch_size, + train_data=train_data) + self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch) + + def test_get_steps_per_epoch_raise_value_error(self): + with self.assertRaises(ValueError): + model_util.get_steps_per_epoch( + steps_per_epoch=None, batch_size=16, train_data=None) + + def test_warmup(self): + init_lr = 0.1 + warmup_steps = 1000 + num_decay_steps = 100 + learning_rate_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate=init_lr, decay_steps=num_decay_steps) + warmup_object = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate_fn, + warmup_steps=1000, + name='test') + self.assertEqual( + warmup_object.get_config(), { + 'initial_learning_rate': init_lr, + 'decay_schedule_fn': learning_rate_fn, + 'warmup_steps': warmup_steps, + 'name': 'test' + }) + + def test_export_tflite(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') + model_util.export_tflite(model, tflite_file) + self._test_tflite(model, tflite_file, input_dim) + + @parameterized.named_parameters( + dict( + testcase_name='dynamic_quantize', + config=quantization.QuantizationConfig.for_dynamic(), + model_size=1288), + dict( + testcase_name='int8_quantize', + config=quantization.QuantizationConfig.for_int8( + representative_data=test_util.create_dataset( + data_size=10, input_shape=[16], num_classes=3)), + model_size=1832), + dict( + testcase_name='float16_quantize', + config=quantization.QuantizationConfig.for_float16(), + model_size=1468)) + def test_export_tflite_quantized(self, config, model_size): + input_dim = 16 + num_classes = 2 + max_input_value = 5 + model = test_util.build_model([input_dim], num_classes) + tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite') + + model_util.export_tflite(model, tflite_file, config) + self._test_tflite( + model, tflite_file, input_dim, max_input_value, atol=1e-00) + self.assertNear(os.path.getsize(tflite_file), model_size, 300) + + def _test_tflite(self, + keras_model: tf.keras.Model, + tflite_model_file: str, + input_dim: int, + 
max_input_value: int = 1000, + atol: float = 1e-04): + np.random.seed(0) + random_input = np.random.uniform( + low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32) + + self.assertTrue( + test_util.is_same_output( + tflite_model_file, keras_model, random_input, atol=atol)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/quantization.py b/mediapipe/model_maker/python/core/utils/quantization.py new file mode 100644 index 000000000..a1a38cc64 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/quantization.py @@ -0,0 +1,213 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Libraries for post-training quantization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import Any, Callable, List, Optional, Union + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds + +DEFAULT_QUANTIZATION_STEPS = 500 + + +def _get_representative_dataset_generator(dataset: tf.data.Dataset, + num_steps: int) -> Callable[[], Any]: + """Gets a representative dataset generator for post-training quantization. + + The generator is to provide a small dataset to calibrate or estimate the + range, i.e, (min, max) of all floating-point arrays in the model for + quantization. Usually, this is a small subset of a few hundred samples + randomly chosen, in no particular order, from the training or evaluation + dataset. See tf.lite.RepresentativeDataset for more details. + + Args: + dataset: Input dataset for extracting representative sub dataset. + num_steps: The number of quantization steps which also reflects the size of + the representative dataset. + + Returns: + A representative dataset generator. + """ + + def representative_dataset_gen(): + """Generates representative dataset for quantization.""" + for data, _ in dataset.take(num_steps): + yield [data] + + return representative_dataset_gen + + +class QuantizationConfig(object): + """Configuration for post-training quantization. + + Refer to + https://www.tensorflow.org/lite/performance/post_training_quantization + for different post-training quantization options. + """ + + def __init__( + self, + optimizations: Optional[Union[tf.lite.Optimize, + List[tf.lite.Optimize]]] = None, + representative_data: Optional[ds.Dataset] = None, + quantization_steps: Optional[int] = None, + inference_input_type: Optional[tf.dtypes.DType] = None, + inference_output_type: Optional[tf.dtypes.DType] = None, + supported_ops: Optional[Union[tf.lite.OpsSet, + List[tf.lite.OpsSet]]] = None, + supported_types: Optional[Union[tf.dtypes.DType, + List[tf.dtypes.DType]]] = None, + experimental_new_quantizer: bool = False, + ): + """Constructs QuantizationConfig. + + Args: + optimizations: A list of optimizations to apply when converting the model. + If not set, use `[Optimize.DEFAULT]` by default. 
+ representative_data: A representative ds.Dataset for post-training + quantization. + quantization_steps: Number of post-training quantization calibration steps + to run (default to DEFAULT_QUANTIZATION_STEPS). + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays. Defaults to None. If set, must be + be `{tf.float32, tf.uint8, tf.int8}`. + inference_output_type: Target data type of real-number output arrays. + Allows for a different type for output arrays. Defaults to None. If set, + must be `{tf.float32, tf.uint8, tf.int8}`. + supported_ops: Set of OpsSet options supported by the device. Used to Set + converter.target_spec.supported_ops. + supported_types: List of types for constant values on the target device. + Supported values are types exported by lite.constants. Frequently, an + optimization choice is driven by the most compact (i.e. smallest) type + in this list (default [constants.FLOAT]). + experimental_new_quantizer: Whether to enable experimental new quantizer. + + Raises: + ValueError: if inference_input_type or inference_output_type are set but + not in {tf.float32, tf.uint8, tf.int8}. + """ + if inference_input_type is not None and inference_input_type not in { + tf.float32, tf.uint8, tf.int8 + }: + raise ValueError('Unsupported inference_input_type %s' % + inference_input_type) + if inference_output_type is not None and inference_output_type not in { + tf.float32, tf.uint8, tf.int8 + }: + raise ValueError('Unsupported inference_output_type %s' % + inference_output_type) + + if optimizations is None: + optimizations = [tf.lite.Optimize.DEFAULT] + if not isinstance(optimizations, list): + optimizations = [optimizations] + self.optimizations = optimizations + + self.representative_data = representative_data + if self.representative_data is not None and quantization_steps is None: + quantization_steps = DEFAULT_QUANTIZATION_STEPS + self.quantization_steps = quantization_steps + + self.inference_input_type = inference_input_type + self.inference_output_type = inference_output_type + + if supported_ops is not None and not isinstance(supported_ops, list): + supported_ops = [supported_ops] + self.supported_ops = supported_ops + + if supported_types is not None and not isinstance(supported_types, list): + supported_types = [supported_types] + self.supported_types = supported_types + + self.experimental_new_quantizer = experimental_new_quantizer + + @classmethod + def for_dynamic(cls) -> 'QuantizationConfig': + """Creates configuration for dynamic range quantization.""" + return QuantizationConfig() + + @classmethod + def for_int8( + cls, + representative_data: ds.Dataset, + quantization_steps: int = DEFAULT_QUANTIZATION_STEPS, + inference_input_type: tf.dtypes.DType = tf.uint8, + inference_output_type: tf.dtypes.DType = tf.uint8, + supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8 + ) -> 'QuantizationConfig': + """Creates configuration for full integer quantization. + + Args: + representative_data: Representative data used for post-training + quantization. + quantization_steps: Number of post-training quantization calibration steps + to run. + inference_input_type: Target data type of real-number input arrays. + inference_output_type: Target data type of real-number output arrays. + supported_ops: Set of `tf.lite.OpsSet` options, where each option + represents a set of operators supported by the target device. + + Returns: + QuantizationConfig. 
+ """ + return QuantizationConfig( + representative_data=representative_data, + quantization_steps=quantization_steps, + inference_input_type=inference_input_type, + inference_output_type=inference_output_type, + supported_ops=supported_ops) + + @classmethod + def for_float16(cls) -> 'QuantizationConfig': + """Creates configuration for float16 quantization.""" + return QuantizationConfig(supported_types=[tf.float16]) + + def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter, + **kwargs: Any) -> tf.lite.TFLiteConverter: + """Sets input TFLite converter with quantization configurations. + + Args: + converter: input tf.lite.TFLiteConverter. + **kwargs: arguments used by ds.Dataset.gen_tf_dataset. + + Returns: + tf.lite.TFLiteConverter with quantization configurations. + """ + converter.optimizations = self.optimizations + + if self.representative_data is not None: + tf_ds = self.representative_data.gen_tf_dataset( + batch_size=1, is_training=False, **kwargs) + converter.representative_dataset = tf.lite.RepresentativeDataset( + _get_representative_dataset_generator(tf_ds, self.quantization_steps)) + + if self.inference_input_type: + converter.inference_input_type = self.inference_input_type + if self.inference_output_type: + converter.inference_output_type = self.inference_output_type + if self.supported_ops: + converter.target_spec.supported_ops = self.supported_ops + if self.supported_types: + converter.target_spec.supported_types = self.supported_types + + if self.experimental_new_quantizer is not None: + converter.experimental_new_quantizer = self.experimental_new_quantizer + return converter diff --git a/mediapipe/model_maker/python/core/utils/quantization_test.py b/mediapipe/model_maker/python/core/utils/quantization_test.py new file mode 100644 index 000000000..9d27d34ac --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/quantization_test.py @@ -0,0 +1,108 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.core.utils import test_util + + +class QuantizationTest(tf.test.TestCase, parameterized.TestCase): + + def test_create_dynamic_quantization_config(self): + config = quantization.QuantizationConfig.for_dynamic() + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertIsNone(config.representative_data) + self.assertIsNone(config.inference_input_type) + self.assertIsNone(config.inference_output_type) + self.assertIsNone(config.supported_ops) + self.assertIsNone(config.supported_types) + self.assertFalse(config.experimental_new_quantizer) + + def test_create_int8_quantization_config(self): + representative_data = test_util.create_dataset( + data_size=10, input_shape=[4], num_classes=3) + config = quantization.QuantizationConfig.for_int8( + representative_data=representative_data) + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertEqual(config.inference_input_type, tf.uint8) + self.assertEqual(config.inference_output_type, tf.uint8) + self.assertEqual(config.supported_ops, + [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]) + self.assertFalse(config.experimental_new_quantizer) + + def test_set_converter_with_quantization_from_int8_config(self): + representative_data = test_util.create_dataset( + data_size=10, input_shape=[4], num_classes=3) + config = quantization.QuantizationConfig.for_int8( + representative_data=representative_data) + model = test_util.build_model(input_shape=[4], num_classes=3) + saved_model_dir = self.get_temp_dir() + model.save(saved_model_dir) + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + converter = config.set_converter_with_quantization(converter=converter) + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertEqual(config.inference_input_type, tf.uint8) + self.assertEqual(config.inference_output_type, tf.uint8) + self.assertEqual(config.supported_ops, + [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]) + tflite_model = converter.convert() + interpreter = tf.lite.Interpreter(model_content=tflite_model) + self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8) + self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8) + + def test_create_float16_quantization_config(self): + config = quantization.QuantizationConfig.for_float16() + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertIsNone(config.representative_data) + self.assertIsNone(config.inference_input_type) + self.assertIsNone(config.inference_output_type) + self.assertIsNone(config.supported_ops) + self.assertEqual(config.supported_types, [tf.float16]) + self.assertFalse(config.experimental_new_quantizer) + + def test_set_converter_with_quantization_from_float16_config(self): + config = quantization.QuantizationConfig.for_float16() + model = test_util.build_model(input_shape=[4], num_classes=3) + saved_model_dir = self.get_temp_dir() + model.save(saved_model_dir) + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + converter = config.set_converter_with_quantization(converter=converter) + self.assertEqual(config.supported_types, [tf.float16]) + tflite_model = converter.convert() + interpreter = tf.lite.Interpreter(model_content=tflite_model) + # The input and output are expected to be set to float32 by default. 
+ self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32) + self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32) + + @parameterized.named_parameters( + dict( + testcase_name='invalid_inference_input_type', + inference_input_type=tf.uint8, + inference_output_type=tf.int64), + dict( + testcase_name='invalid_inference_output_type', + inference_input_type=tf.int64, + inference_output_type=tf.float32)) + def test_create_quantization_config_failure(self, inference_input_type, + inference_output_type): + with self.assertRaises(ValueError): + _ = quantization.QuantizationConfig( + inference_input_type=inference_input_type, + inference_output_type=inference_output_type) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py new file mode 100644 index 000000000..eb2952dd3 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -0,0 +1,76 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test utilities for model maker.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import List, Union + +# Dependency imports + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds +from mediapipe.model_maker.python.core.utils import model_util + + +def create_dataset(data_size: int, + input_shape: List[int], + num_classes: int, + max_input_value: int = 1000) -> ds.Dataset: + """Creates and returns a simple `Dataset` object for test.""" + features = tf.random.uniform( + shape=[data_size] + input_shape, + minval=0, + maxval=max_input_value, + dtype=tf.float32) + + labels = tf.random.uniform( + shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32) + + tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels)) + dataset = ds.Dataset(tf_dataset, data_size) + return dataset + + +def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model: + """Builds a simple Keras model for test.""" + inputs = tf.keras.layers.Input(shape=input_shape) + if len(input_shape) == 3: # Image inputs. + outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs) + outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs) + elif len(input_shape) == 1: # Text inputs. + outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs) + else: + raise ValueError("Model inputs should be 2D tensor or 4D tensor.") + + model = tf.keras.Model(inputs=inputs, outputs=outputs) + return model + + +def is_same_output(tflite_file: str, + keras_model: tf.keras.Model, + input_tensors: Union[List[tf.Tensor], tf.Tensor], + atol: float = 1e-04) -> bool: + """Returns if the output of TFLite model and keras model are identical.""" + # Gets output from lite model. 
+ lite_runner = model_util.get_lite_runner(tflite_file) + lite_output = lite_runner.run(input_tensors) + + # Gets output from keras model. + keras_output = keras_model.predict_on_batch(input_tensors) + + return np.allclose(lite_output, keras_output, atol=atol) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt new file mode 100644 index 000000000..5e3832b09 --- /dev/null +++ b/mediapipe/model_maker/requirements.txt @@ -0,0 +1,4 @@ +absl-py +numpy +opencv-contrib-python +tensorflow diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 649ff2c11..cd5933ee6 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -507,8 +507,11 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { } }; +// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly. +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT + ::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT +// clang-format on } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD similarity index 83% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 9e2d9bd17..c9319e946 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -41,8 +41,8 @@ cc_test( ) cc_library( - name = "hand_gesture_recognizer_subgraph", - srcs = ["hand_gesture_recognizer_subgraph.cc"], + name = "hand_gesture_recognizer_graph", + srcs = ["hand_gesture_recognizer_graph.cc"], deps = [ "//mediapipe/calculators/core:concatenate_vector_calculator", "//mediapipe/calculators/tensor:tensor_converter_calculator", @@ -62,11 +62,11 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:handedness_to_matrix_calculator", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:landmarks_to_matrix_calculator", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:hand_gesture_recognizer_subgraph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD 
b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD similarity index 84% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 4863c8682..a6de4f950 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + package(default_visibility = [ "//mediapipe/app/xeno:__subpackages__", "//mediapipe/tasks:internal", ]) +mediapipe_proto_library( + name = "landmarks_to_matrix_calculator_proto", + srcs = ["landmarks_to_matrix_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + cc_library( name = "handedness_to_matrix_calculator", srcs = ["handedness_to_matrix_calculator.cc"], @@ -25,7 +37,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer:handedness_util", + "//mediapipe/tasks/cc/vision/gesture_recognizer:handedness_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -53,11 +65,11 @@ cc_library( name = "landmarks_to_matrix_calculator", srcs = ["landmarks_to_matrix_calculator.cc"], deps = [ + ":landmarks_to_matrix_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc similarity index 90% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc index 746293d21..b6c973a1b 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc @@ -26,14 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" +// TODO Update to use API2 namespace mediapipe { -namespace tasks { -namespace vision { +namespace api2 { namespace { +using ::mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore; + constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX"; @@ -71,6 +73,8 @@ class HandednessToMatrixCalculator : public CalculatorBase { return absl::OkStatus(); } + // TODO remove this after change to API2, because Setting offset + // to 0 is the default in API2 absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); return absl::OkStatus(); @@ -95,6 +99,5 @@ absl::Status HandednessToMatrixCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -} // namespace vision -} // namespace tasks +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc similarity index 97% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc index c93c48ac5..17b16bf80 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc @@ -28,8 +28,6 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { -namespace tasks { -namespace vision { namespace { @@ -95,6 +93,4 @@ INSTANTIATE_TEST_CASE_P( } // namespace -} // namespace vision -} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc similarity index 96% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 990e99920..b70689eaf 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -27,13 +27,11 @@ limitations under the License. 
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +// TODO Update to use API2 namespace mediapipe { -namespace tasks { -namespace vision { - -using proto::LandmarksToMatrixCalculatorOptions; +namespace api2 { namespace { @@ -175,7 +173,7 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { // input_stream: "IMAGE_SIZE:image_size" // output_stream: "LANDMARKS_MATRIX:landmarks_matrix" // options { -// [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions.ext] { +// [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { // object_normalization: true // object_normalization_origin_offset: 0 // } @@ -221,6 +219,5 @@ absl::Status LandmarksToMatrixCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -} // namespace vision -} // namespace tasks +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto similarity index 97% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto index 6b004e203..10b034447 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.vision.proto; +package mediapipe; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc similarity index 96% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index 05d238f66..8a68d8dae 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -28,8 +28,6 @@ limitations under the License. 
#include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { -namespace tasks { -namespace vision { namespace { @@ -72,8 +70,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { input_stream: "IMAGE_SIZE:image_size" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { - [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions - .ext] { + [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { object_normalization: $0 object_normalization_origin_offset: $1 } @@ -145,8 +142,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { input_stream: "IMAGE_SIZE:image_size" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { - [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions - .ext] { + [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { object_normalization: $0 object_normalization_origin_offset: $1 } @@ -202,6 +198,4 @@ INSTANTIATE_TEST_CASE_P( } // namespace -} // namespace vision -} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc similarity index 80% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 247d8453d..05bc607ae 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -34,14 +34,15 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.pb.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace { @@ -50,9 +51,8 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: - HandGestureRecognizerSubgraphOptions; -using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + HandGestureRecognizerGraphOptions; constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kLandmarksTag[] = "LANDMARKS"; @@ -70,18 +70,6 @@ constexpr char kIndexTag[] = "INDEX"; constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; -absl::Status SanityCheckOptions( - const HandGestureRecognizerSubgraphOptions& options) { - if (options.min_tracking_confidence() < 0 || - options.min_tracking_confidence() > 1) { - return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, - "Invalid `min_tracking_confidence` option: " - "value must be in the 
range [0.0, 1.0]", - MediaPipeTasksStatus::kInvalidArgumentError); - } - return absl::OkStatus(); -} - Source> ConvertMatrixToTensor(Source matrix, Graph& graph) { auto& node = graph.AddNode("TensorConverterCalculator"); @@ -91,9 +79,10 @@ Source> ConvertMatrixToTensor(Source matrix, } // namespace -// A "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" performs -// single hand gesture recognition. This graph is used as a building block for -// mediapipe.tasks.vision.HandGestureRecognizerGraph. +// A +// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph" +// performs single hand gesture recognition. This graph is used as a building +// block for mediapipe.tasks.vision.GestureRecognizerGraph. // // Inputs: // HANDEDNESS - ClassificationList @@ -113,14 +102,15 @@ Source> ConvertMatrixToTensor(Source matrix, // // Example: // node { -// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph" // input_stream: "HANDEDNESS:handedness" // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" // input_stream: "IMAGE_SIZE:image_size" // output_stream: "HAND_GESTURES:hand_gestures" // options { -// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraphOptions.ext] +// [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext] // { // base_options { // model_asset { @@ -130,19 +120,19 @@ Source> ConvertMatrixToTensor(Source matrix, // } // } // } -class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { +class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN( const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( auto hand_gestures, - BuildHandGestureRecognizerGraph( - sc->Options(), - *model_resources, graph[Input(kHandednessTag)], + BuildGestureRecognizerGraph( + sc->Options(), *model_resources, + graph[Input(kHandednessTag)], graph[Input(kLandmarksTag)], graph[Input(kWorldLandmarksTag)], graph[Input>(kImageSizeTag)], graph)); @@ -151,15 +141,13 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { } private: - absl::StatusOr> BuildHandGestureRecognizerGraph( - const HandGestureRecognizerSubgraphOptions& graph_options, + absl::StatusOr> BuildGestureRecognizerGraph( + const HandGestureRecognizerGraphOptions& graph_options, const core::ModelResources& model_resources, Source handedness, Source hand_landmarks, Source hand_world_landmarks, Source> image_size, Graph& graph) { - MP_RETURN_IF_ERROR(SanityCheckOptions(graph_options)); - // Converts the ClassificationList to a matrix. auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator"); handedness >> handedness_to_matrix.In(kHandednessTag); @@ -235,12 +223,15 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::SingleHandGestureRecognizerSubgraph); + ::mediapipe::tasks::vision::gesture_recognizer::SingleHandGestureRecognizerGraph); // NOLINT +// clang-format on -// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs multi -// hand gesture recognition. This graph is used as a building block for -// mediapipe.tasks.vision.HandGestureRecognizerGraph. 
+// A +// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph" +// performs multi hand gesture recognition. This graph is used as a building +// block for mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph. // // Inputs: // HANDEDNESS - std::vector @@ -263,7 +254,8 @@ REGISTER_MEDIAPIPE_GRAPH( // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph" // input_stream: "HANDEDNESS:handedness" // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" @@ -271,7 +263,7 @@ REGISTER_MEDIAPIPE_GRAPH( // input_stream: "HAND_TRACKING_IDS:hand_tracking_ids" // output_stream: "HAND_GESTURES:hand_gestures" // options { -// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraph.ext] +// [mediapipe.tasks.vision.gesture_recognizer.proto.MultipleHandGestureRecognizerGraph.ext] // { // base_options { // model_asset { @@ -281,15 +273,15 @@ REGISTER_MEDIAPIPE_GRAPH( // } // } // } -class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { +class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( auto multi_hand_gestures, - BuildMultiHandGestureRecognizerSubraph( - sc->Options(), + BuildMultiGestureRecognizerSubraph( + sc->Options(), graph[Input>(kHandednessTag)], graph[Input>(kLandmarksTag)], graph[Input>(kWorldLandmarksTag)], @@ -302,8 +294,8 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { private: absl::StatusOr>> - BuildMultiHandGestureRecognizerSubraph( - const HandGestureRecognizerSubgraphOptions& graph_options, + BuildMultiGestureRecognizerSubraph( + const HandGestureRecognizerGraphOptions& graph_options, Source> multi_handedness, Source> multi_hand_landmarks, Source> multi_hand_world_landmarks, @@ -341,17 +333,18 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag); auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag); - auto& hand_gesture_recognizer_subgraph = graph.AddNode( - "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph"); - hand_gesture_recognizer_subgraph - .GetOptions() + auto& hand_gesture_recognizer_graph = graph.AddNode( + "mediapipe.tasks.vision.gesture_recognizer." 
+ "SingleHandGestureRecognizerGraph"); + hand_gesture_recognizer_graph + .GetOptions() .CopyFrom(graph_options); - handedness >> hand_gesture_recognizer_subgraph.In(kHandednessTag); - hand_landmarks >> hand_gesture_recognizer_subgraph.In(kLandmarksTag); + handedness >> hand_gesture_recognizer_graph.In(kHandednessTag); + hand_landmarks >> hand_gesture_recognizer_graph.In(kLandmarksTag); hand_world_landmarks >> - hand_gesture_recognizer_subgraph.In(kWorldLandmarksTag); - image_size_clone >> hand_gesture_recognizer_subgraph.In(kImageSizeTag); - auto hand_gestures = hand_gesture_recognizer_subgraph.Out(kHandGesturesTag); + hand_gesture_recognizer_graph.In(kWorldLandmarksTag); + image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag); + auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag); auto& end_loop_classification_results = graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator"); @@ -364,9 +357,12 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::HandGestureRecognizerSubgraph); + ::mediapipe::tasks::vision::gesture_recognizer::MultipleHandGestureRecognizerGraph); // NOLINT +// clang-format on +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc similarity index 93% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc index 00e19cdb5..60ccae92c 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" #include @@ -25,6 +25,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace {} // namespace @@ -58,6 +59,7 @@ absl::StatusOr GetLeftHandScore( } } +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h similarity index 79% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h index 74e04b8cc..ae4137d0f 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ #include "absl/status/statusor.h" #include "mediapipe/framework/formats/classification.pb.h" @@ -22,6 +22,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { bool IsLeftHand(const mediapipe::Classification& c); @@ -30,8 +31,9 @@ bool IsRightHand(const mediapipe::Classification& c); absl::StatusOr GetLeftHandScore( const mediapipe::ClassificationList& classification_list); +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc similarity index 94% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc index 51dfb5dea..40a201ae8 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/port/gmock.h" @@ -23,6 +23,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace { TEST(GetLeftHandScore, SingleLeftHandClassification) { @@ -72,6 +73,7 @@ TEST(GetLeftHandScore, LeftAndRightLowerCaseHandClassification) { } } // namespace +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD similarity index 81% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD index 44ec611b2..7b5c65eab 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -21,8 +21,18 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_gesture_recognizer_subgraph_options_proto", - srcs = ["hand_gesture_recognizer_subgraph_options.proto"], + name = "gesture_embedder_graph_options_proto", + srcs = ["gesture_embedder_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "hand_gesture_recognizer_graph_options_proto", + srcs = ["hand_gesture_recognizer_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -30,12 +40,3 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) - -mediapipe_proto_library( - name = "landmarks_to_matrix_calculator_proto", - srcs = ["landmarks_to_matrix_calculator.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto new file mode 100644 index 000000000..c12359eb3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -0,0 +1,30 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.gesture_recognizer.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +message GestureEmbedderGraphOptions { + extend mediapipe.CalculatorOptions { + optional GestureEmbedderGraphOptions ext = 478825422; + } + // Base options for configuring hand gesture recognition subgraph, such as + // specifying the TfLite model file with metadata, accelerator options, etc. 
+ optional core.proto.BaseOptions base_options = 1; +} diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto rename to mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index d8ee95037..ac8cda15c 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -15,15 +15,15 @@ limitations under the License. // TODO Refactor naming and class structure of hand related Tasks. syntax = "proto2"; -package mediapipe.tasks.vision.hand_gesture_recognizer.proto; +package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandGestureRecognizerSubgraphOptions { +message HandGestureRecognizerGraphOptions { extend mediapipe.CalculatorOptions { - optional HandGestureRecognizerSubgraphOptions ext = 463370452; + optional HandGestureRecognizerGraphOptions ext = 463370452; } // Base options for configuring hand gesture recognition subgraph, such as // specifying the TfLite model file with metadata, accelerator options, etc. diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index c87cc50a6..433a30471 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -51,7 +51,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 7ead21bad..8573d718f 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -40,12 +40,13 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" namespace mediapipe { namespace tasks { namespace vision { +namespace hand_detector { namespace { @@ -53,18 +54,23 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; constexpr char kImageTag[] = "IMAGE"; -constexpr char kDetectionsTag[] = "DETECTIONS"; -constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; +constexpr char kHandRectsTag[] = "HAND_RECTS"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; struct HandDetectionOuts { Source> palm_detections; Source> hand_rects; + Source> palm_rects; + Source image; }; void ConfigureTensorsToDetectionsCalculator( + const HandDetectorGraphOptions& tasks_options, mediapipe::TensorsToDetectionsCalculatorOptions* options) { // TODO use metadata to configure these fields. options->set_num_classes(1); @@ -77,7 +83,7 @@ void ConfigureTensorsToDetectionsCalculator( options->set_sigmoid_score(true); options->set_score_clipping_thresh(100.0); options->set_reverse_output_order(true); - options->set_min_score_thresh(0.5); + options->set_min_score_thresh(tasks_options.min_detection_confidence()); options->set_x_scale(192.0); options->set_y_scale(192.0); options->set_w_scale(192.0); @@ -134,9 +140,9 @@ void ConfigureRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.HandDetectorGraph" performs hand detection. The -// Hand Detection Graph is based on palm detection model, and scale the detected -// palm bounding box to enclose the detected whole hand. +// A "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" performs hand +// detection. The Hand Detection Graph is based on palm detection model, and +// scale the detected palm bounding box to enclose the detected whole hand. // Accepts CPU input images and outputs Landmark on CPU. // // Inputs: @@ -144,19 +150,27 @@ void ConfigureRectTransformationCalculator( // Image to perform detection on. // // Outputs: -// DETECTIONS - std::vector +// PALM_DETECTIONS - std::vector // Detected palms with maximum `num_hands` specified in options. -// NORM_RECTS - std::vector +// HAND_RECTS - std::vector // Detected hand bounding boxes in normalized coordinates. +// PLAM_RECTS - std::vector +// Detected palm bounding boxes in normalized coordinates. +// IMAGE - Image +// The input image that the hand detector runs on and has the pixel data +// stored on the target storage (CPU vs GPU). 
// // Example: // node { -// calculator: "mediapipe.tasks.vision.HandDetectorGraph" +// calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" // input_stream: "IMAGE:image" -// output_stream: "DETECTIONS:palm_detections" -// output_stream: "NORM_RECTS:hand_rects_from_palm_detections" +// output_stream: "PALM_DETECTIONS:palm_detections" +// output_stream: "HAND_RECTS:hand_rects_from_palm_detections" +// output_stream: "PALM_RECTS:palm_rects" +// output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.hand_detector.proto.HandDetectorOptions.ext] { +// [mediapipe.tasks.vision.hand_detector.proto.HandDetectorGraphOptions.ext] +// { // base_options { // model_asset { // file_name: "palm_detection.tflite" @@ -173,16 +187,20 @@ class HandDetectorGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto hand_detection_outs, - BuildHandDetectionSubgraph( - sc->Options(), *model_resources, - graph[Input(kImageTag)], graph)); + ASSIGN_OR_RETURN( + auto hand_detection_outs, + BuildHandDetectionSubgraph(sc->Options(), + *model_resources, + graph[Input(kImageTag)], graph)); hand_detection_outs.palm_detections >> - graph[Output>(kDetectionsTag)]; + graph[Output>(kPalmDetectionsTag)]; hand_detection_outs.hand_rects >> - graph[Output>(kNormRectsTag)]; + graph[Output>(kHandRectsTag)]; + hand_detection_outs.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_detection_outs.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -196,7 +214,7 @@ class HandDetectorGraph : public core::ModelTaskGraph { // image_in: image stream to run hand detection on. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr BuildHandDetectionSubgraph( - const HandDetectorOptions& subgraph_options, + const HandDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio @@ -235,6 +253,7 @@ class HandDetectorGraph : public core::ModelTaskGraph { auto& tensors_to_detections = graph.AddNode("TensorsToDetectionsCalculator"); ConfigureTensorsToDetectionsCalculator( + subgraph_options, &tensors_to_detections .GetOptions()); model_output_tensors >> tensors_to_detections.In("TENSORS"); @@ -281,7 +300,8 @@ class HandDetectorGraph : public core::ModelTaskGraph { .GetOptions()); palm_detections >> detections_to_rects.In("DETECTIONS"); image_size >> detections_to_rects.In("IMAGE_SIZE"); - auto palm_rects = detections_to_rects.Out("NORM_RECTS"); + auto palm_rects = + detections_to_rects[Output>("NORM_RECTS")]; // Expands and shifts the rectangle that contains the palm so that it's // likely to cover the entire hand. 
@@ -308,13 +328,18 @@ class HandDetectorGraph : public core::ModelTaskGraph { clip_normalized_rect_vector_size[Output>( "")]; - return HandDetectionOuts{.palm_detections = palm_detections, - .hand_rects = clipped_hand_rects}; + return HandDetectionOuts{ + /* palm_detections= */ palm_detections, + /* hand_rects= */ clipped_hand_rects, + /* palm_rects= */ palm_rects, + /* image= */ preprocessing[Output(kImageTag)]}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandDetectorGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_detector::HandDetectorGraph); +} // namespace hand_detector } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index 3fa97664e..11cfc3026 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -40,13 +40,14 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" namespace mediapipe { namespace tasks { namespace vision { +namespace hand_detector { namespace { using ::file::Defaults; @@ -60,7 +61,8 @@ using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::TaskRunner; using ::mediapipe::tasks::core::proto::ExternalFile; using ::mediapipe::tasks::vision::DecodeImageFromFile; -using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorResult; using ::testing::EqualsProto; using ::testing::TestParamInfo; @@ -80,9 +82,9 @@ constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageName[] = "image"; -constexpr char kPalmDetectionsTag[] = "DETECTIONS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kPalmDetectionsName[] = "palm_detections"; -constexpr char kHandNormRectsTag[] = "NORM_RECTS"; +constexpr char kHandRectsTag[] = "HAND_RECTS"; constexpr char kHandNormRectsName[] = "hand_norm_rects"; constexpr float kPalmDetectionBboxMaxDiff = 0.01; @@ -104,22 +106,22 @@ absl::StatusOr> CreateTaskRunner( Graph graph; auto& hand_detection = - graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); + graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); options->set_min_detection_confidence(0.5); options->set_num_hands(num_hands); - hand_detection.GetOptions().Swap(options.get()); + hand_detection.GetOptions().Swap(options.get()); graph[Input(kImageTag)].SetName(kImageName) >> hand_detection.In(kImageTag); hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >> graph[Output>(kPalmDetectionsTag)]; - 
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> - graph[Output>(kHandNormRectsTag)]; + hand_detection.Out(kHandRectsTag).SetName(kHandNormRectsName) >> + graph[Output>(kHandRectsTag)]; return TaskRunner::Create( graph.GetConfig(), std::make_unique()); @@ -200,6 +202,7 @@ INSTANTIATE_TEST_SUITE_P( }); } // namespace +} // namespace hand_detector } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD index 2d22aab10..77f3b2649 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_detector_options_proto", - srcs = ["hand_detector_options.proto"], + name = "hand_detector_graph_options_proto", + srcs = ["hand_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto similarity index 76% rename from mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto rename to mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index ae22c7991..be20583d0 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -21,24 +21,20 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.handdetector"; -option java_outer_classname = "HandDetectorOptionsProto"; +option java_outer_classname = "HandDetectorGraphOptionsProto"; -message HandDetectorOptions { +message HandDetectorGraphOptions { extend mediapipe.CalculatorOptions { - optional HandDetectorOptions ext = 464864288; + optional HandDetectorGraphOptions ext = 464864288; } // Base options for configuring Task library, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; - // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered // successfully detecting a hand in the image. - optional float min_detection_confidence = 3 [default = 0.5]; + optional float min_detection_confidence = 2 [default = 0.5]; // The maximum number of hands output by the detector. 
- optional int32 num_hands = 4; + optional int32 num_hands = 3; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 653976b96..a2bb458db 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -19,10 +19,10 @@ package(default_visibility = [ licenses(["notice"]) cc_library( - name = "hand_landmarker_subgraph", - srcs = ["hand_landmarker_subgraph.cc"], + name = "hand_landmarks_detector_graph", + srcs = ["hand_landmarks_detector_graph.cc"], deps = [ - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "//mediapipe/calculators/core:split_vector_calculator", @@ -51,6 +51,7 @@ cc_library( # TODO: move calculators in modules/hand_landmark/calculators to tasks dir. "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", @@ -66,3 +67,41 @@ cc_library( ) # TODO: Enable this test + +cc_library( + name = "hand_landmarker_graph", + srcs = ["hand_landmarker_graph.cc"], + deps = [ + ":hand_landmarks_detector_graph", + "//mediapipe/calculators/core:begin_loop_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc new file mode 100644 index 000000000..949c06520 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -0,0 +1,286 @@ +/* Copyright 2022 The MediaPipe Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::DisallowIf; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarksDetectorGraphOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; +constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator"; + +struct HandLandmarkerOutputs { + Source> landmark_lists; + Source> world_landmark_lists; + Source> hand_rects_next_frame; + Source> handednesses; + Source> palm_rects; + Source> palm_detections; + Source image; +}; + +} // namespace + +// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand +// landmarks detection. The HandLandmarkerGraph consists of two subgraphs: +// HandDetectorGraph and MultipleHandLandmarksDetectorGraph. +// MultipleHandLandmarksDetectorGraph detects landmarks from bounding boxes +// produced by HandDetectorGraph. 
HandLandmarkerGraph tracks the landmarks over +// time, and skips the HandDetectorGraph. If the tracking is lost or fewer +// hands than the configured maximum number are detected, HandDetectorGraph is +// triggered to detect hands. +// +// Accepts CPU input images and outputs Landmarks on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform hand landmarks detection on. +// +// Outputs: +// LANDMARKS - std::vector +// Vector of detected hand landmarks. +// WORLD_LANDMARKS - std::vector +// Vector of detected hand landmarks in world coordinates. +// HAND_RECT_NEXT_FRAME - std::vector +// Vector of the predicted rects enclosing the same hand RoI for landmark +// detection on the next frame. +// HANDEDNESS - std::vector +// Vector of classification of handedness. +// PALM_RECTS - std::vector +// Detected palm bounding boxes in normalized coordinates. +// PALM_DETECTIONS - std::vector +// Detected palms with maximum `num_hands` specified in options. +// IMAGE - Image +// The input image that the hand landmarker runs on and has the pixel data +// stored on the target storage (CPU vs GPU). +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" +// input_stream: "IMAGE:image_in" +// output_stream: "LANDMARKS:hand_landmarks" +// output_stream: "WORLD_LANDMARKS:world_hand_landmarks" +// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" +// output_stream: "HANDEDNESS:handedness" +// output_stream: "PALM_RECTS:palm_rects" +// output_stream: "PALM_DETECTIONS:palm_detections" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerGraphOptions.ext] { +// base_options { +// model_asset { +// file_name: "hand_landmarker.task" +// } +// } +// hand_detector_graph_options { +// base_options { +// model_asset { +// file_name: "palm_detection.tflite" +// } +// } +// min_detection_confidence: 0.5 +// num_hands: 2 +// } +// hand_landmarks_detector_graph_options { +// base_options { +// model_asset { +// file_name: "hand_landmark_lite.tflite" +// } +// } +// min_detection_confidence: 0.5 +// } +// } +// } +// } +class HandLandmarkerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto hand_landmarker_outputs, + BuildHandLandmarkerGraph(sc->Options(), + graph[Input(kImageTag)], graph)); + hand_landmarker_outputs.landmark_lists >> + graph[Output>(kLandmarksTag)]; + hand_landmarker_outputs.world_landmark_lists >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmarker_outputs.hand_rects_next_frame >> + graph[Output>(kHandRectNextFrameTag)]; + hand_landmarker_outputs.handednesses >> + graph[Output>(kHandednessTag)]; + hand_landmarker_outputs.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_landmarker_outputs.palm_detections >> + graph[Output>(kPalmDetectionsTag)]; + hand_landmarker_outputs.image >> graph[Output(kImageTag)]; + + // TODO remove when support is fixed. + // As mediapipe GraphBuilder currently doesn't support configuring + // InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
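+    // The LOOP input of PreviousLoopbackCalculator is fed by
+    // HAND_RECT_NEXT_FRAME, which is produced by a node added later in this
+    // graph, so that stream must be annotated as a back edge; otherwise the
+    // generated config would contain an unannotated cycle and fail validation.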
+ CalculatorGraphConfig config = graph.GetConfig(); + for (int i = 0; i < config.node_size(); ++i) { + if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) { + auto* info = config.mutable_node(i)->add_input_stream_info(); + info->set_tag_index("LOOP"); + info->set_back_edge(true); + break; + } + } + return config; + } + + private: + // Adds a mediapipe hand landmark detection graph into the provided + // builder::Graph instance. + // + // tasks_options: the mediapipe tasks module HandLandmarkerGraphOptions. + // image_in: (mediapipe::Image) stream to run hand landmark detection on. + // graph: the mediapipe graph instance to be updated. + absl::StatusOr BuildHandLandmarkerGraph( + const HandLandmarkerGraphOptions& tasks_options, Source image_in, + Graph& graph) { + const int max_num_hands = + tasks_options.hand_detector_graph_options().num_hands(); + + auto& previous_loopback = graph.AddNode(kPreviousLoopbackCalculatorName); + image_in >> previous_loopback.In("MAIN"); + auto prev_hand_rects_from_landmarks = + previous_loopback[Output>("PREV_LOOP")]; + + auto& min_size_node = + graph.AddNode("NormalizedRectVectorHasMinSizeCalculator"); + prev_hand_rects_from_landmarks >> min_size_node.In("ITERABLE"); + min_size_node.GetOptions() + .set_min_size(max_num_hands); + auto has_enough_hands = min_size_node.Out("").Cast(); + + auto image_for_hand_detector = + DisallowIf(image_in, has_enough_hands, graph); + + auto& hand_detector = + graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); + hand_detector.GetOptions().CopyFrom( + tasks_options.hand_detector_graph_options()); + image_for_hand_detector >> hand_detector.In("IMAGE"); + auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); + + auto& hand_association = graph.AddNode("HandAssociationCalculator"); + hand_association.GetOptions() + .set_min_similarity_threshold(tasks_options.min_tracking_confidence()); + prev_hand_rects_from_landmarks >> + hand_association[Input>::Multiple("")][0]; + hand_rects_from_hand_detector >> + hand_association[Input>::Multiple("")][1]; + auto hand_rects = hand_association.Out(""); + + auto& clip_hand_rects = + graph.AddNode("ClipNormalizedRectVectorSizeCalculator"); + clip_hand_rects.GetOptions() + .set_max_vec_size(max_num_hands); + hand_rects >> clip_hand_rects.In(""); + auto clipped_hand_rects = clip_hand_rects.Out(""); + + auto& hand_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." + "MultipleHandLandmarksDetectorGraph"); + hand_landmarks_detector_graph + .GetOptions() + .CopyFrom(tasks_options.hand_landmarks_detector_graph_options()); + image_in >> hand_landmarks_detector_graph.In("IMAGE"); + clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT"); + + auto hand_rects_for_next_frame = + hand_landmarks_detector_graph[Output>( + kHandRectNextFrameTag)]; + // Back edge. + hand_rects_for_next_frame >> previous_loopback.In("LOOP"); + + // TODO: Replace PassThroughCalculator with a calculator that + // converts the pixel data to be stored on the target storage (CPU vs GPU). 
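+    // Until then, the PassThroughCalculator simply forwards the input image,
+    // so the IMAGE output shares the storage of the incoming image stream.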
+ auto& pass_through = graph.AddNode("PassThroughCalculator"); + image_in >> pass_through.In(""); + + return {{ + /* landmark_lists= */ hand_landmarks_detector_graph + [Output>(kLandmarksTag)], + /* world_landmark_lists= */ + hand_landmarks_detector_graph[Output>( + kWorldLandmarksTag)], + /* hand_rects_next_frame= */ hand_rects_for_next_frame, + hand_landmarks_detector_graph[Output>( + kHandednessTag)], + /* palm_rects= */ + hand_detector[Output>(kPalmRectsTag)], + /* palm_detections */ + hand_detector[Output>(kPalmDetectionsTag)], + /* image */ + pass_through[Output("")], + }}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerGraph); + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc new file mode 100644 index 000000000..bce5613ff --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using 
::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; +using ::testing::EqualsProto; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite"; +constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite"; +constexpr char kLeftHandsImage[] = "left_hands.jpg"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kLandmarksName[] = "landmarks"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kWorldLandmarksName[] = "world_landmarks"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandRectNextFrameName[] = "hand_rect_next_frame"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessName[] = "handedness"; + +// Expected hand landmarks positions, in text proto format. +constexpr char kExpectedLeftUpHandLandmarksFilename[] = + "expected_left_up_hand_landmarks.prototxt"; +constexpr char kExpectedLeftDownHandLandmarksFilename[] = + "expected_left_down_hand_landmarks.prototxt"; + +constexpr float kFullModelFractionDiff = 0.03; // percentage +constexpr float kAbsMargin = 0.03; +constexpr int kMaxNumHands = 2; +constexpr float kMinTrackingConfidence = 0.5; + +NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { + NormalizedLandmarkList expected_landmark_list; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &expected_landmark_list, Defaults())); + return expected_landmark_list; +} + +// Helper function to create a Hand Landmarker TaskRunner. 
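+// The runner wires a single HandLandmarkerGraph node: IMAGE is exposed as the
+// graph input, the palm detection and hand landmarks models are configured
+// through HandLandmarkerGraphOptions, and LANDMARKS, WORLD_LANDMARKS,
+// HANDEDNESS and HAND_RECT_NEXT_FRAME are exposed as graph outputs.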
+absl::StatusOr> CreateTaskRunner() { + Graph graph; + auto& hand_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"); + auto& options = + hand_landmarker_graph.GetOptions(); + options.mutable_hand_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel)); + options.mutable_hand_detector_graph_options()->mutable_base_options(); + options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands); + options.mutable_hand_landmarks_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name( + JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel)); + options.set_min_tracking_confidence(kMinTrackingConfidence); + + graph[Input(kImageTag)].SetName(kImageName) >> + hand_landmarker_graph.In(kImageTag); + hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >> + graph[Output>(kLandmarksTag)]; + hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmarker_graph.Out(kHandednessTag).SetName(kHandednessName) >> + graph[Output>(kHandednessTag)]; + hand_landmarker_graph.Out(kHandRectNextFrameTag) + .SetName(kHandRectNextFrameName) >> + graph[Output>(kHandRectNextFrameTag)]; + return TaskRunner::Create( + graph.GetConfig(), absl::make_unique()); +} + +class HandLandmarkerTest : public tflite_shims::testing::Test {}; + +TEST_F(HandLandmarkerTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage))); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + const auto& landmarks = (*output_packets)[kLandmarksName] + .Get>(); + ASSERT_EQ(landmarks.size(), kMaxNumHands); + std::vector expected_landmarks = { + GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename), + GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename)}; + + EXPECT_THAT(landmarks[0], + Approximately(Partially(EqualsProto(expected_landmarks[0])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); + EXPECT_THAT(landmarks[1], + Approximately(Partially(EqualsProto(expected_landmarks[1])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); +} + +} // namespace + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc similarity index 89% rename from mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index fff4ae0d4..23521790d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -34,12 +34,13 @@ limitations under the License. 
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" @@ -48,6 +49,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace hand_landmarker { namespace { @@ -55,9 +57,10 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::AllowIf; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; using LabelItems = mediapipe::proto_ns::Map; constexpr char kImageTag[] = "IMAGE"; @@ -82,7 +85,6 @@ struct SingleHandLandmarkerOutputs { Source hand_presence; Source hand_presence_score; Source handedness; - Source> image_size; }; struct HandLandmarkerOutputs { @@ -92,10 +94,10 @@ struct HandLandmarkerOutputs { Source> presences; Source> presence_scores; Source> handednesses; - Source> image_size; }; -absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) { +absl::Status SanityCheckOptions( + const HandLandmarksDetectorGraphOptions& options) { if (options.min_detection_confidence() < 0 || options.min_detection_confidence() > 1) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -182,8 +184,8 @@ void ConfigureHandRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" performs hand -// landmark detection. +// A "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph" +// performs hand landmarks detection. // - Accepts CPU input images and outputs Landmark on CPU. // // Inputs: @@ -208,12 +210,11 @@ void ConfigureHandRectTransformationCalculator( // Float value indicates the probability that the hand is present. // HANDEDNESS - ClassificationList // Classification of handedness. -// IMAGE_SIZE - std::vector -// The size of input image. 
// // Example: // node { -// calculator: "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" +// calculator: +// "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph" // input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" @@ -221,10 +222,8 @@ void ConfigureHandRectTransformationCalculator( // output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" -// output_stream: "HANDEDNESS:handedness" -// output_stream: "IMAGE_SIZE:image_size" // options { -// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext] // { // base_options { // model_asset { @@ -235,16 +234,17 @@ void ConfigureHandRectTransformationCalculator( // } // } // } -class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { +class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, - BuildSingleHandLandmarkerSubgraph( - sc->Options(), + BuildSingleHandLandmarksDetectorGraph( + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input(kHandRectTag)], graph)); hand_landmark_detection_outs.hand_landmarks >> @@ -259,8 +259,6 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { graph[Output(kPresenceScoreTag)]; hand_landmark_detection_outs.handedness >> graph[Output(kHandednessTag)]; - hand_landmark_detection_outs.image_size >> - graph[Output>(kImageSizeTag)]; return graph.GetConfig(); } @@ -269,14 +267,16 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { // Adds a mediapipe hand landmark detection graph into the provided // builder::Graph instance. // - // subgraph_options: the mediapipe tasks module HandLandmarkerSubgraphOptions. - // model_resources: the ModelSources object initialized from a hand landmark + // subgraph_options: the mediapipe tasks module + // HandLandmarksDetectorGraphOptions. model_resources: the ModelSources object + // initialized from a hand landmark // detection model file with model metadata. // image_in: (mediapipe::Image) stream to run hand landmark detection on. // rect: (NormalizedRect) stream to run on the RoI of image. // graph: the mediapipe graph instance to be updated. - absl::StatusOr BuildSingleHandLandmarkerSubgraph( - const HandLandmarkerSubgraphOptions& subgraph_options, + absl::StatusOr + BuildSingleHandLandmarksDetectorGraph( + const HandLandmarksDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); @@ -332,18 +332,7 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { // score of hand presence. auto& tensors_to_hand_presence = graph.AddNode("TensorsToFloatsCalculator"); hand_flag_tensors >> tensors_to_hand_presence.In("TENSORS"); - - // Converts the handedness tensor into a float that represents the - // classification score of handedness. 
- auto& tensors_to_handedness = - graph.AddNode("TensorsToClassificationCalculator"); - ConfigureTensorsToHandednessCalculator( - &tensors_to_handedness.GetOptions< - mediapipe::TensorsToClassificationCalculatorOptions>()); - handedness_tensors >> tensors_to_handedness.In("TENSORS"); auto hand_presence_score = tensors_to_hand_presence[Output("FLOAT")]; - auto handedness = - tensors_to_handedness[Output("CLASSIFICATIONS")]; // Applies a threshold to the confidence score to determine whether a // hand is present. @@ -354,6 +343,18 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { hand_presence_score >> hand_presence_thresholding.In("FLOAT"); auto hand_presence = hand_presence_thresholding[Output("FLAG")]; + // Converts the handedness tensor into a float that represents the + // classification score of handedness. + auto& tensors_to_handedness = + graph.AddNode("TensorsToClassificationCalculator"); + ConfigureTensorsToHandednessCalculator( + &tensors_to_handedness.GetOptions< + mediapipe::TensorsToClassificationCalculatorOptions>()); + handedness_tensors >> tensors_to_handedness.In("TENSORS"); + auto handedness = AllowIf( + tensors_to_handedness[Output("CLASSIFICATIONS")], + hand_presence, graph); + // Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed // hand image (after image transformation with the FIT scale mode) to the // corresponding locations on the same image with the letterbox removed @@ -371,8 +372,9 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { landmark_letterbox_removal.Out("LANDMARKS") >> landmark_projection.In("NORM_LANDMARKS"); hand_rect >> landmark_projection.In("NORM_RECT"); - auto projected_landmarks = - landmark_projection[Output("NORM_LANDMARKS")]; + auto projected_landmarks = AllowIf( + landmark_projection[Output("NORM_LANDMARKS")], + hand_presence, graph); // Projects the world landmarks from the cropped hand image to the // corresponding locations on the full image before cropping (input to the @@ -383,7 +385,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { world_landmark_projection.In("LANDMARKS"); hand_rect >> world_landmark_projection.In("NORM_RECT"); auto projected_world_landmarks = - world_landmark_projection[Output("LANDMARKS")]; + AllowIf(world_landmark_projection[Output("LANDMARKS")], + hand_presence, graph); // Converts the hand landmarks into a rectangle (normalized by image size) // that encloses the hand. @@ -403,7 +406,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { hand_landmarks_to_rect.Out("NORM_RECT") >> hand_rect_transformation.In("NORM_RECT"); auto hand_rect_next_frame = - hand_rect_transformation[Output("")]; + AllowIf(hand_rect_transformation[Output("")], + hand_presence, graph); return {{ /* hand_landmarks= */ projected_landmarks, @@ -412,16 +416,17 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { /* hand_presence= */ hand_presence, /* hand_presence_score= */ hand_presence_score, /* handedness= */ handedness, - /* image_size= */ image_size, }}; } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::SingleHandLandmarkerSubgraph); + ::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarksDetectorGraph); // NOLINT +// clang-format on -// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi -// hand landmark detection. +// A "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph" +// performs multi hand landmark detection. 
// - Accepts CPU input image and a vector of hand rect RoIs to detect the // multiple hands landmarks enclosed by the RoIs. Output vectors of // hand landmarks related results, where each element in the vectors @@ -449,12 +454,11 @@ REGISTER_MEDIAPIPE_GRAPH( // Vector of float value indicates the probability that the hand is present. // HANDEDNESS - std::vector // Vector of classification of handedness. -// IMAGE_SIZE - std::vector -// The size of input image. // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandLandmarkerSubgraph" +// calculator: +// "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph" // input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" @@ -463,9 +467,8 @@ REGISTER_MEDIAPIPE_GRAPH( // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" // output_stream: "HANDEDNESS:handedness" -// output_stream: "IMAGE_SIZE:image_size" // options { -// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext] // { // base_options { // model_asset { @@ -476,15 +479,15 @@ REGISTER_MEDIAPIPE_GRAPH( // } // } // } -class HandLandmarkerSubgraph : public core::ModelTaskGraph { +class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( auto hand_landmark_detection_outputs, - BuildHandLandmarkerSubgraph( - sc->Options(), + BuildHandLandmarksDetectorGraph( + sc->Options(), graph[Input(kImageTag)], graph[Input>(kHandRectTag)], graph)); hand_landmark_detection_outputs.landmark_lists >> @@ -499,21 +502,20 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { graph[Output>(kPresenceScoreTag)]; hand_landmark_detection_outputs.handednesses >> graph[Output>(kHandednessTag)]; - hand_landmark_detection_outputs.image_size >> - graph[Output>(kImageSizeTag)]; return graph.GetConfig(); } private: - absl::StatusOr BuildHandLandmarkerSubgraph( - const HandLandmarkerSubgraphOptions& subgraph_options, + absl::StatusOr BuildHandLandmarksDetectorGraph( + const HandLandmarksDetectorGraphOptions& subgraph_options, Source image_in, Source> multi_hand_rects, Graph& graph) { - auto& hand_landmark_subgraph = - graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); - hand_landmark_subgraph.GetOptions().CopyFrom( - subgraph_options); + auto& hand_landmark_subgraph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); + hand_landmark_subgraph.GetOptions() + .CopyFrom(subgraph_options); auto& begin_loop_multi_hand_rects = graph.AddNode("BeginLoopNormalizedRectCalculator"); @@ -533,8 +535,6 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { hand_landmark_subgraph.Out("HAND_RECT_NEXT_FRAME"); auto landmarks = hand_landmark_subgraph.Out("LANDMARKS"); auto world_landmarks = hand_landmark_subgraph.Out("WORLD_LANDMARKS"); - auto image_size = - hand_landmark_subgraph[Output>("IMAGE_SIZE")]; auto& end_loop_handedness = graph.AddNode("EndLoopClassificationListCalculator"); @@ -585,13 +585,16 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { /* presences= */ presences, /* presence_scores= */ presence_scores, /* handednesses= */ handednesses, - /* image_size= */ image_size, }}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkerSubgraph); +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::MultipleHandLandmarksDetectorGraph); // NOLINT +// clang-format on +} // namespace hand_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc similarity index 96% rename from mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index 1c2bc6da7..d1e928ce7 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -39,12 +39,13 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" namespace mediapipe { namespace tasks { namespace vision { +namespace hand_landmarker { namespace { using ::file::Defaults; @@ -57,7 +58,7 @@ using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::core::TaskRunner; using ::mediapipe::tasks::vision::DecodeImageFromFile; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; using ::testing::ElementsAreArray; using ::testing::EqualsProto; using ::testing::Pointwise; @@ -112,13 +113,14 @@ absl::StatusOr> CreateSingleHandTaskRunner( absl::string_view model_name) { Graph graph; - auto& hand_landmark_detection = - graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); + auto& hand_landmark_detection = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); - hand_landmark_detection.GetOptions().Swap( + hand_landmark_detection.GetOptions().Swap( options.get()); graph[Input(kImageTag)].SetName(kImageName) >> @@ -151,13 +153,14 @@ absl::StatusOr> CreateMultiHandTaskRunner( absl::string_view model_name) { Graph graph; - auto& multi_hand_landmark_detection = - graph.AddNode("mediapipe.tasks.vision.HandLandmarkerSubgraph"); + auto& multi_hand_landmark_detection = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." + "MultipleHandLandmarksDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); - multi_hand_landmark_detection.GetOptions() + multi_hand_landmark_detection.GetOptions() .Swap(options.get()); graph[Input(kImageTag)].SetName(kImageName) >> @@ -462,6 +465,7 @@ INSTANTIATE_TEST_SUITE_P( }); } // namespace +} // namespace hand_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD index 8cc984c47..945b12f3e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_landmarker_subgraph_options_proto", - srcs = ["hand_landmarker_subgraph_options.proto"], + name = "hand_landmarks_detector_graph_options_proto", + srcs = ["hand_landmarks_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,13 +31,13 @@ mediapipe_proto_library( ) mediapipe_proto_library( - name = "hand_landmarker_options_proto", - srcs = ["hand_landmarker_options.proto"], + name = "hand_landmarker_graph_options_proto", + srcs = ["hand_landmarker_graph_options.proto"], deps = [ - ":hand_landmarker_subgraph_options_proto", + ":hand_landmarks_detector_graph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto similarity index 67% rename from mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto rename to mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index b3d82eda4..7f3536b09 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -19,22 +19,26 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto"; -import 
"mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; -message HandLandmarkerOptions { +message HandLandmarkerGraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkerOptions ext = 462713202; + optional HandLandmarkerGraphOptions ext = 462713202; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; + // Options for hand detector graph. + optional hand_detector.proto.HandDetectorGraphOptions + hand_detector_graph_options = 2; - optional hand_detector.proto.HandDetectorOptions hand_detector_options = 3; + // Options for hand landmarker subgraph. + optional HandLandmarksDetectorGraphOptions + hand_landmarks_detector_graph_options = 3; - optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 4; + // Minimum confidence for hand landmarks tracking to be considered + // successfully. + optional float min_tracking_confidence = 4 [default = 0.5]; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto similarity index 77% rename from mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto rename to mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 9e93384d6..8c0fc66f2 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -20,19 +20,15 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandLandmarkerSubgraphOptions { +message HandLandmarksDetectorGraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkerSubgraphOptions ext = 474472470; + optional HandLandmarksDetectorGraphOptions ext = 474472470; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; - // Minimum confidence value ([0.0, 1.0]) for hand presence score to be // considered successfully detecting a hand in the image. - optional float min_detection_confidence = 3 [default = 0.5]; + optional float min_detection_confidence = 2 [default = 0.5]; } diff --git a/mediapipe/tasks/examples/android/BUILD b/mediapipe/tasks/examples/android/BUILD new file mode 100644 index 000000000..c07af2d2c --- /dev/null +++ b/mediapipe/tasks/examples/android/BUILD @@ -0,0 +1,21 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), + visibility = ["//mediapipe/tasks/examples/android:__subpackages__"], +) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml new file mode 100644 index 000000000..5c53dc269 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD new file mode 100644 index 000000000..65b98d647 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -0,0 +1,48 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
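+
+# Example Android app that runs the MediaPipe Tasks object detector. The
+# android_binary target below bundles TFLite detection models from the tasks
+# test data as in-app assets.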
+ +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "objectdetector", + srcs = glob(["**/*.java"]), + assets = [ + "//mediapipe/tasks/testdata/vision:test_models", + ], + assets_dir = "", + custom_package = "com.google.mediapipe.tasks.examples.objectdetector", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.tasks.examples.objectdetector", + }, + multidex = "native", + resource_files = ["//mediapipe/tasks/examples/android:resource_files"], + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java new file mode 100644 index 000000000..7f7ec1389 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -0,0 +1,236 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.examples.objectdetector; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.media.MediaMetadataRetriever; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Task Object Detector reference app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; + + private ObjectDetector objectDetector; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + + private InputSource inputSource = InputSource.UNKNOWN; + + // Image mode demo component. + private ActivityResultLauncher imageGetter; + // Video mode demo component. + private ActivityResultLauncher videoGetter; + private ObjectDetectionResultImageView imageView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + setupImageModeDemo(); + setupVideoModeDemo(); + // TODO: Adds live camera demo. + } + + /** Sets up the image mode demo. */ + private void setupImageModeDemo() { + imageView = new ObjectDetectionResultImageView(this); + // The Intent to access gallery and read images as bitmap. + imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + downscaleBitmap( + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData())); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + bitmap = rotateBitmap(bitmap, imageData); + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + Image image = new BitmapImageBuilder(bitmap).build(); + ObjectDetectionResult detectionResult = objectDetector.detect(image); + imageView.setData(image, detectionResult); + runOnUiThread(() -> imageView.update()); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + createObjectDetector(RunningMode.IMAGE); + this.inputSource = InputSource.IMAGE; + updateLayout(); + } + // Reads images from gallery. 
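+          // The imageGetter callback registered above runs the object detector on the
+          // selected image and renders the annotated result in the image view.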
+          Intent pickImageIntent = new Intent(Intent.ACTION_PICK);
+          pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*");
+          imageGetter.launch(pickImageIntent);
+        });
+  }
+
+  /** Sets up the video mode demo. */
+  private void setupVideoModeDemo() {
+    imageView = new ObjectDetectionResultImageView(this);
+    // The Intent to access gallery and read a video file.
+    videoGetter =
+        registerForActivityResult(
+            new ActivityResultContracts.StartActivityForResult(),
+            result -> {
+              Intent resultIntent = result.getData();
+              if (resultIntent != null) {
+                if (result.getResultCode() == RESULT_OK) {
+                  MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever();
+                  metaRetriever.setDataSource(this, resultIntent.getData());
+                  long duration =
+                      Long.parseLong(
+                          metaRetriever.extractMetadata(
+                              MediaMetadataRetriever.METADATA_KEY_DURATION));
+                  int numFrames =
+                      Integer.parseInt(
+                          metaRetriever.extractMetadata(
+                              MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT));
+                  long frameIntervalMs = duration / numFrames;
+                  for (int i = 0; i < numFrames; ++i) {
+                    Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
+                    ObjectDetectionResult detectionResult =
+                        objectDetector.detectForVideo(image, frameIntervalMs * i);
+                    // Currently only annotates the detection result on the first video frame and
+                    // displays it to verify the correctness.
+                    // TODO: Annotate the detection result on every frame, save the
+                    // annotated frames as a video file, and play back the video afterwards.
+                    if (i == 0) {
+                      imageView.setData(image, detectionResult);
+                      runOnUiThread(() -> imageView.update());
+                    }
+                  }
+                }
+              }
+            });
+    Button loadVideoButton = findViewById(R.id.button_load_video);
+    loadVideoButton.setOnClickListener(
+        v -> {
+          createObjectDetector(RunningMode.VIDEO);
+          updateLayout();
+          this.inputSource = InputSource.VIDEO;
+
+          // Reads a video from gallery.
+          Intent pickVideoIntent = new Intent(Intent.ACTION_PICK);
+          pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*");
+          videoGetter.launch(pickVideoIntent);
+        });
+  }
+
+  private void createObjectDetector(RunningMode mode) {
+    if (objectDetector != null) {
+      objectDetector.close();
+    }
+    // Initializes a new MediaPipe ObjectDetector instance.
+    ObjectDetectorOptions options =
+        ObjectDetectorOptions.builder()
+            .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+            .setScoreThreshold(0.5f)
+            .setMaxResults(5)
+            .setRunningMode(mode)
+            .build();
+    objectDetector = ObjectDetector.createFromOptions(this, options);
+  }
+
+  private void updateLayout() {
+    // Updates the preview layout.
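+    // Detaches any previously shown content from the preview container and
+    // attaches the result image view so the latest annotated frame is displayed.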
+ FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + private Bitmap downscaleBitmap(Bitmap originalBitmap) { + double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight(); + int width = imageView.getWidth(); + int height = imageView.getHeight(); + if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) { + width = (int) (height * aspectRatio); + } else { + height = (int) (width / aspectRatio); + } + return Bitmap.createScaledBitmap(originalBitmap, width, height, false); + } + + private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + int orientation = + new ExifInterface(imageData) + .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation == ExifInterface.ORIENTATION_NORMAL) { + return inputBitmap; + } + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + return Bitmap.createBitmap( + inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); + } +} diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java new file mode 100644 index 000000000..94a4a90dc --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java @@ -0,0 +1,77 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.examples.objectdetector; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Detection; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; + +/** An ImageView implementation for displaying {@link ObjectDetectionResult}. 
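+ *
+ * <p>Call {@link #setData} to draw the bounding boxes of an {@link ObjectDetectionResult} over
+ * the given image, then call {@link #update} on the UI thread to display the annotated bitmap.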
*/ +public class ObjectDetectionResultImageView extends AppCompatImageView { + private static final String TAG = "ObjectDetectionResultImageView"; + + private static final int BBOX_COLOR = Color.GREEN; + private static final int BBOX_THICKNESS = 5; // Pixels + private Bitmap latest; + + public ObjectDetectionResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets an {@link Image} and an {@link ObjectDetectionResult} to render. + * + * @param image an {@link Image} object for annotation. + * @param result an {@link ObjectDetectionResult} object that contains the detection result. + */ + public void setData(Image image, ObjectDetectionResult result) { + if (image == null || result == null) { + return; + } + latest = BitmapExtractor.extract(image); + Canvas canvas = new Canvas(latest); + canvas.drawBitmap(latest, new Matrix(), null); + for (int i = 0; i < result.detections().size(); ++i) { + drawDetectionOnCanvas(result.detections().get(i), canvas); + } + } + + /** Updates the image view with the latest {@link ObjectDetectionResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawDetectionOnCanvas(Detection detection, Canvas canvas) { + // TODO: Draws the category and the score per bounding box. + // Draws bounding box. + Paint bboxPaint = new Paint(); + bboxPaint.setColor(BBOX_COLOR); + bboxPaint.setStyle(Paint.Style.STROKE); + bboxPaint.setStrokeWidth(BBOX_THICKNESS); + canvas.drawRect(detection.boundingBox(), bboxPaint); + } +} diff --git a/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 000000000..c7bd21dbd --- /dev/null +++ b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..01f0af0ad --- /dev/null +++ b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/res/layout/activity_main.xml b/mediapipe/tasks/examples/android/res/layout/activity_main.xml new file mode 100644 index 000000000..834e9a3e6 --- /dev/null +++ b/mediapipe/tasks/examples/android/res/layout/activity_main.xml @@ -0,0 +1,40 @@ + + + +