Merge branch 'image-classification-python-impl' of https://github.com/kinaryml/mediapipe into image-classification-python-impl

@@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==2.2.0
RUN pip3 install tensorflow
RUN pip3 install tf_slim

RUN ln -s /usr/bin/python3 /usr/bin/python

@@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```

## Graph Options

It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.

In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:

```
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "image"
  output_stream: "throttled_image"
  node_options: {
    [type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
      max_in_flight: 1
    }
  }
}

node {
  calculator: "FaceDetectionSubgraph"
  input_stream: "IMAGE:throttled_image"
  node_options: {
    [type.googleapis.com/mediapipe.FaceDetectionOptions] {
      tensor_width: 192
      tensor_height: 192
    }
  }
}
```

In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:

```
graph_options: {
  [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}

node: {
  calculator: "ImageToTensorCalculator"
  input_stream: "IMAGE:multi_backend_image"
  node_options: {
    [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
      keep_aspect_ratio: true
      border_mode: BORDER_ZERO
    }
  }
  option_value: "output_tensor_width:options/tensor_width"
  option_value: "output_tensor_height:options/tensor_height"
}

node {
  calculator: "InferenceCalculator"
  node_options: {
    [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
  }
  option_value: "delegate:options/delegate"
  option_value: "model_path:options/model_path"
}
```

In this example, the `FaceDetectionSubgraph` accepts the graph options protobuf
`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field
values in the calculator options `ImageToTensorCalculatorOptions` and some field
values in the subgraph options `InferenceCalculatorOptions`. The field values
are defined using the `option_value:` syntax.

In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.

The syntax of `option_value:` is similar to the syntax of `input_stream:`. The
syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option
field and the RHS identifies a graph option field. More specifically, the LHS
and RHS each consist of a series of protobuf field names, separated by '/',
identifying nested protobuf messages and fields. This is known as the
"ProtoPath" syntax. Nested messages that are referenced in the LHS or RHS must
already be defined in the enclosing protobuf in order to be traversed using
`option_value:`.

## Cycles

<!-- TODO: add discussion of PreviousLoopbackCalculator -->

@@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText;
import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;

@@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
  if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
    // EGLSync failed. Use another synchronization method.
    // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
    glFinish();
    gl_context_->Run([]() { glFinish(); });
  } else if (valid_ & kValidAHardwareBuffer) {
    CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
                            "completion function to be set";

@ -89,10 +89,6 @@ def mediapipe_aar(
|
|||
calculators = calculators,
|
||||
)
|
||||
|
||||
_mediapipe_proto(
|
||||
name = name + "_proto",
|
||||
)
|
||||
|
||||
native.genrule(
|
||||
name = name + "_aar_manifest_generator",
|
||||
outs = ["AndroidManifest.xml"],
|
||||
|
@ -115,19 +111,10 @@ EOF
|
|||
"//mediapipe/java/com/google/mediapipe/components:java_src",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:java_src",
|
||||
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
|
||||
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
"com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
"com/google/mediapipe/formats/proto/DetectionProto.java",
|
||||
"com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
"com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
"com/google/mediapipe/proto/CalculatorProto.java",
|
||||
] +
|
||||
] + mediapipe_java_proto_srcs() +
|
||||
select({
|
||||
"//conditions:default": [],
|
||||
"enable_stats_logging": [
|
||||
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
|
||||
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
|
||||
],
|
||||
"enable_stats_logging": mediapipe_logging_java_proto_srcs(),
|
||||
}),
|
||||
manifest = "AndroidManifest.xml",
|
||||
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
|
||||
|
@ -179,93 +166,6 @@ EOF
|
|||
|
||||
_aar_with_jni(name, name + "_android_lib")
|
||||
|
||||
def _mediapipe_proto(name):
|
||||
"""Generates MediaPipe java proto libraries.
|
||||
|
||||
Args:
|
||||
name: the name of the target.
|
||||
"""
|
||||
_proto_java_src_generator(
|
||||
name = "mediapipe_log_extension_proto",
|
||||
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
|
||||
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
|
||||
srcs = ["//mediapipe/util/analytics:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "mediapipe_logging_enums_proto",
|
||||
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
|
||||
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
|
||||
srcs = ["//mediapipe/util/analytics:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "calculator_proto",
|
||||
proto_src = "mediapipe/framework/calculator.proto",
|
||||
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
|
||||
srcs = ["//mediapipe/framework:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "landmark_proto",
|
||||
proto_src = "mediapipe/framework/formats/landmark.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
srcs = ["//mediapipe/framework/formats:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "rasterization_proto",
|
||||
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "location_data_proto",
|
||||
proto_src = "mediapipe/framework/formats/location_data.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
srcs = [
|
||||
"//mediapipe/framework/formats:protos_src",
|
||||
"//mediapipe/framework/formats/annotation:protos_src",
|
||||
],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "detection_proto",
|
||||
proto_src = "mediapipe/framework/formats/detection.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
|
||||
srcs = [
|
||||
"//mediapipe/framework/formats:protos_src",
|
||||
"//mediapipe/framework/formats/annotation:protos_src",
|
||||
],
|
||||
)
|
||||
|
||||
_proto_java_src_generator(
|
||||
name = "classification_proto",
|
||||
proto_src = "mediapipe/framework/formats/classification.proto",
|
||||
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
srcs = [
|
||||
"//mediapipe/framework/formats:protos_src",
|
||||
],
|
||||
)
|
||||
|
||||
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
|
||||
native.genrule(
|
||||
name = name + "_proto_java_src_generator",
|
||||
srcs = srcs + [
|
||||
"@com_google_protobuf//:lite_well_known_protos",
|
||||
],
|
||||
outs = [java_lite_out],
|
||||
cmd = "$(location @com_google_protobuf//:protoc) " +
|
||||
"--proto_path=. --proto_path=$(GENDIR) " +
|
||||
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
|
||||
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
|
||||
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
|
||||
tools = [
|
||||
"@com_google_protobuf//:protoc",
|
||||
],
|
||||
)
|
||||
|
||||
def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
|
||||
"""Generates MediaPipe jni library.
|
||||
|
||||
|
@ -345,3 +245,93 @@ cp -r lib jni
|
|||
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
|
||||
""".format(android_library, name, name, name, name),
|
||||
)
|
||||
|
||||
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
|
||||
"""Extracts the generated MediaPipe java proto source code from the target.
|
||||
|
||||
Args:
|
||||
target: The java proto lite target to be built and extracted.
|
||||
src_out: The output java proto src code path.
|
||||
name: The optional bazel target name.
|
||||
|
||||
Returns:
|
||||
The output java proto src code path.
|
||||
"""
|
||||
|
||||
if not name:
|
||||
name = target.split(":")[-1] + "_proto_java_src_extractor"
|
||||
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
|
||||
native.genrule(
|
||||
name = name + "_proto_java_src_extractor",
|
||||
srcs = [target],
|
||||
outs = [src_out],
|
||||
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
|
||||
src_out + " $$(dirname $(location " + src_out + "))",
|
||||
)
|
||||
return src_out
|
||||
|
||||
def mediapipe_java_proto_srcs(name = ""):
|
||||
"""Extracts the generated MediaPipe framework java proto source code.
|
||||
|
||||
Args:
|
||||
name: The optional bazel target name.
|
||||
|
||||
Returns:
|
||||
The list of the extracted MediaPipe java proto source code.
|
||||
"""
|
||||
|
||||
proto_src_list = []
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework:calculator_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:detection_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
|
||||
))
|
||||
return proto_src_list
|
||||
|
||||
def mediapipe_logging_java_proto_srcs(name = ""):
|
||||
"""Extracts the generated logging-related MediaPipe java proto source code.
|
||||
|
||||
Args:
|
||||
name: The optional bazel target name.
|
||||
|
||||
Returns:
|
||||
The list of the extracted MediaPipe logging-related java proto source code.
|
||||
"""
|
||||
|
||||
proto_src_list = []
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
|
||||
))
|
||||
|
||||
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
|
||||
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
|
||||
))
|
||||
return proto_src_list
|
||||
|
|
mediapipe/model_maker/BUILD
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//mediapipe/model_maker/...",
|
||||
],
|
||||
)
|
mediapipe/model_maker/__init__.py
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
mediapipe/model_maker/python/BUILD
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//mediapipe/model_maker/...",
|
||||
],
|
||||
)
|
mediapipe/model_maker/python/__init__.py
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
mediapipe/model_maker/python/core/BUILD
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
mediapipe/model_maker/python/core/__init__.py
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
mediapipe/model_maker/python/core/data/BUILD
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "data_util",
|
||||
srcs = ["data_util.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "data_util_test",
|
||||
srcs = ["data_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/data/testdata"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":data_util"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "classification_dataset",
|
||||
srcs = ["classification_dataset.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "classification_dataset_test",
|
||||
srcs = ["classification_dataset_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":classification_dataset"],
|
||||
)
|
mediapipe/model_maker/python/core/data/__init__.py
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common classification dataset library."""
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
|
||||
|
||||
class ClassificationDataset(ds.Dataset):
|
||||
"""DataLoader for classification models."""
|
||||
|
||||
def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
|
||||
super().__init__(dataset, size)
|
||||
self.index_to_label = index_to_label
|
||||
|
||||
@property
|
||||
def num_classes(self: ds._DatasetT) -> int:
|
||||
return len(self.index_to_label)
|
||||
|
||||
def split(self: ds._DatasetT,
|
||||
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
|
||||
"""Splits dataset into two sub-datasets with the given fraction.
|
||||
|
||||
Primarily used for splitting the data set into training and testing sets.
|
||||
|
||||
Args:
|
||||
fraction: float, the fraction of the original data that goes into the first
  returned sub-dataset.
|
||||
|
||||
Returns:
|
||||
The two split sub-datasets.
|
||||
"""
|
||||
return self._split(fraction, self.index_to_label)
|
|
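To make the intended use of `ClassificationDataset` above concrete, here is a small hedged sketch (not part of this PR's diff) that builds a toy dataset and splits it; the label names and toy tensors are invented for illustration:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Toy data: four feature vectors with binary labels (illustrative only).
features = tf.constant([[0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
labels = tf.constant([1, 1, 0, 0])
tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))

data = classification_dataset.ClassificationDataset(
    dataset=tf_dataset, size=4, index_to_label=['negative', 'positive'])

train_data, test_data = data.split(0.75)  # 3 training examples, 1 test example.
print(data.num_classes)   # 2
print(len(train_data))    # 3
print(len(test_data))     # 1
```

Note how `split` returns instances of the subclass: `Dataset._split` constructs `self.__class__` and forwards the extra constructor arguments (here `index_to_label`).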
@ -0,0 +1,68 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
|
||||
class ClassificationDataLoaderTest(tf.test.TestCase):
|
||||
|
||||
def test_split(self):
|
||||
|
||||
class MagicClassificationDataLoader(
|
||||
classification_dataset.ClassificationDataset):
|
||||
|
||||
def __init__(self, dataset, size, index_to_label, value):
|
||||
super(MagicClassificationDataLoader,
|
||||
self).__init__(dataset, size, index_to_label)
|
||||
self.value = value
|
||||
|
||||
def split(self, fraction):
|
||||
return self._split(fraction, self.index_to_label, self.value)
|
||||
|
||||
# Some dummy inputs.
|
||||
magic_value = 42
|
||||
num_classes = 2
|
||||
index_to_label = (False, True)
|
||||
|
||||
# Create data loader from sample data.
|
||||
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
|
||||
data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
|
||||
magic_value)
|
||||
|
||||
# Train/Test data split.
|
||||
fraction = .25
|
||||
train_data, test_data = data.split(fraction)
|
||||
|
||||
# `split` should return instances of child DataLoader.
|
||||
self.assertIsInstance(train_data, MagicClassificationDataLoader)
|
||||
self.assertIsInstance(test_data, MagicClassificationDataLoader)
|
||||
|
||||
# Make sure number of entries are right.
|
||||
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
|
||||
self.assertLen(train_data, fraction * len(ds))
|
||||
self.assertLen(test_data, len(ds) - len(train_data))
|
||||
|
||||
# Make sure attributes propagated correctly.
|
||||
self.assertEqual(train_data.num_classes, num_classes)
|
||||
self.assertEqual(test_data.index_to_label, index_to_label)
|
||||
self.assertEqual(train_data.value, magic_value)
|
||||
self.assertEqual(test_data.value, magic_value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/data/data_util.py
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Data utility library."""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def load_image(path: str) -> np.ndarray:
|
||||
"""Loads an image as an RGB numpy array.
|
||||
|
||||
Args:
|
||||
path: input image file absolute path.
|
||||
|
||||
Returns:
|
||||
An RGB image as a numpy.ndarray.
|
||||
"""
|
||||
tf.compat.v1.logging.info('Loading RGB image %s', path)
|
||||
# TODO Replace the OpenCV image load and conversion library by
|
||||
# MediaPipe image utility library once it is ready.
|
||||
image = cv2.imread(path)
|
||||
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
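A brief hedged usage sketch for `load_image` above; the file path is hypothetical:

```python
from mediapipe.model_maker.python.core.data import data_util

# Hypothetical path; load_image reads with OpenCV and converts BGR -> RGB.
rgb_image = data_util.load_image('/tmp/example.jpg')
print(rgb_image.shape, rgb_image.dtype)  # e.g. (480, 640, 3) uint8
```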
mediapipe/model_maker/python/core/data/data_util_test.py
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
# Dependency imports
|
||||
|
||||
from absl import flags
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import data_util
|
||||
|
||||
_WORKSPACE = "mediapipe"
|
||||
_TEST_DATA_DIR = os.path.join(
|
||||
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class DataUtilTest(tf.test.TestCase):
|
||||
|
||||
def test_load_rgb_image(self):
|
||||
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
|
||||
image_data = data_util.load_image(image_path)
|
||||
self.assertEqual(image_data.shape, (5184, 3456, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/data/dataset.py
|
@ -0,0 +1,164 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common dataset for model training and evaluation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
from typing import Callable, Optional, Tuple, TypeVar
|
||||
|
||||
# Dependency imports
|
||||
import tensorflow as tf
|
||||
|
||||
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
"""A generic dataset class for loading model training and evaluation dataset.
|
||||
|
||||
For each ML task, such as image classification, text classification etc., a
|
||||
subclass can be derived from this class to provide task-specific data loading
|
||||
utilities.
|
||||
"""
|
||||
|
||||
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
|
||||
"""Initializes Dataset class.
|
||||
|
||||
To build dataset from raw data, consider using the task specific utilities,
|
||||
e.g. from_folder().
|
||||
|
||||
Args:
|
||||
tf_dataset: A tf.data.Dataset object that contains a potentially large set
|
||||
of elements, where each element is a pair of (input_data, target). The
|
||||
`input_data` means the raw input data, like an image, a text etc., while
|
||||
the `target` means the ground truth of the raw input data, e.g. the
|
||||
classification label of the image etc.
|
||||
size: The size of the dataset. tf.data.Dataset doesn't support a function
|
||||
to get the length directly since it's lazy-loaded and may be infinite.
|
||||
"""
|
||||
self._dataset = tf_dataset
|
||||
self._size = size
|
||||
|
||||
@property
|
||||
def size(self) -> Optional[int]:
|
||||
"""Returns the size of the dataset.
|
||||
|
||||
Note that this function may return None because the exact size of the
|
||||
dataset isn't a necessary parameter to create an instance of this class,
|
||||
and tf.data.Dataset doesn't support a function to get the length directly
|
||||
since it's lazy-loaded and may be infinite.
|
||||
In most cases, however, when an instance of this class is created by helper
|
||||
functions like 'from_folder', the size of the dataset will be preprocessed,
|
||||
and this function can return an int representing the size of the dataset.
|
||||
"""
|
||||
return self._size
|
||||
|
||||
def gen_tf_dataset(self,
|
||||
batch_size: int = 1,
|
||||
is_training: bool = False,
|
||||
shuffle: bool = False,
|
||||
preprocess: Optional[Callable[..., bool]] = None,
|
||||
drop_remainder: bool = False) -> tf.data.Dataset:
|
||||
"""Generates a batched tf.data.Dataset for training/evaluation.
|
||||
|
||||
Args:
|
||||
batch_size: An integer, the returned dataset will be batched by this size.
|
||||
is_training: A boolean, when True, the returned dataset will be optionally
|
||||
shuffled and repeated as an endless dataset.
|
||||
shuffle: A boolean, when True, the returned dataset will be shuffled to
|
||||
create randomness during model training.
|
||||
preprocess: A function taking three arguments in order, feature, label and
|
||||
boolean is_training.
|
||||
drop_remainder: boolean, whether to drop the final batch if it is smaller than batch_size.
|
||||
|
||||
Returns:
|
||||
A TF dataset ready to be consumed by Keras model.
|
||||
"""
|
||||
dataset = self._dataset
|
||||
|
||||
if preprocess:
|
||||
preprocess = functools.partial(preprocess, is_training=is_training)
|
||||
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
|
||||
if is_training:
|
||||
if shuffle:
|
||||
# Shuffle size should be bigger than the batch_size. Otherwise it's only
|
||||
# shuffling within the batch, which is equivalent to not shuffling at all.
|
||||
buffer_size = 3 * batch_size
|
||||
# But since we are doing shuffle before repeat, it doesn't make sense to
|
||||
# shuffle more than total available entries.
|
||||
# TODO: Investigate if shuffling before / after repeat
|
||||
# dataset can get a better performance?
|
||||
# Shuffle after repeat will give a more randomized dataset and mix the
|
||||
# epoch boundary: https://www.tensorflow.org/guide/data
|
||||
if self._size:
|
||||
buffer_size = min(self._size, buffer_size)
|
||||
dataset = dataset.shuffle(buffer_size=buffer_size)
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
|
||||
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
||||
# TODO: Consider converting dataset to distributed dataset
|
||||
# here.
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of element of the dataset."""
|
||||
if self._size is not None:
|
||||
return self._size
|
||||
else:
|
||||
return len(self._dataset)
|
||||
|
||||
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
|
||||
"""Splits dataset into two sub-datasets with the given fraction.
|
||||
|
||||
Primarily used for splitting the data set into training and testing sets.
|
||||
|
||||
Args:
|
||||
fraction: A float value defining the fraction of the original data that goes
  into the first returned sub-dataset.
|
||||
|
||||
Returns:
|
||||
The two split sub-datasets.
|
||||
"""
|
||||
return self._split(fraction)
|
||||
|
||||
def _split(self: _DatasetT, fraction: float,
|
||||
*args) -> Tuple[_DatasetT, _DatasetT]:
|
||||
"""Implementation for `split` method and returns sub-class instances.
|
||||
|
||||
Child DataLoader classes, if they require additional constructor arguments,
should implement their own `split` method by calling `_split` with all the
arguments needed by their constructor.
|
||||
|
||||
Args:
|
||||
fraction: A float value defining the fraction of the original data that goes
  into the first returned sub-dataset.
|
||||
*args: additional arguments passed to the sub-class constructor.
|
||||
|
||||
Returns:
|
||||
The two split sub-datasets.
|
||||
"""
|
||||
assert (fraction > 0 and fraction < 1)
|
||||
|
||||
dataset = self._dataset
|
||||
|
||||
train_size = int(self._size * fraction)
|
||||
trainset = self.__class__(dataset.take(train_size), train_size, *args)
|
||||
|
||||
test_size = self._size - train_size
|
||||
testset = self.__class__(dataset.skip(train_size), test_size, *args)
|
||||
|
||||
return trainset, testset
|
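The `Dataset` class above wraps a `tf.data.Dataset` plus an explicit size, and `gen_tf_dataset` maps an optional `(feature, label, is_training)` preprocess function before batching. A hedged usage sketch with toy tensors (all values here are invented for illustration):

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds

# Eight toy (feature, label) pairs.
raw = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([8, 4]), tf.constant([0, 1] * 4)))
data = ds.Dataset(raw, size=8)

train_data, test_data = data.split(0.5)  # 4 examples each.

def preprocess(feature, label, is_training=False):
  # Toy preprocess; real tasks would resize/normalize here.
  del is_training  # Unused in this sketch.
  return tf.cast(feature, tf.float32), label

batched = train_data.gen_tf_dataset(
    batch_size=2, is_training=True, shuffle=True, preprocess=preprocess)
for features, labels in batched.take(2):
  print(features.shape, labels.shape)  # (2, 4) (2,)
```

With `is_training=True` the shuffle buffer is capped at the dataset size, so shuffling never needs more memory than the data itself.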
mediapipe/model_maker/python/core/data/dataset_test.py
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase):
|
||||
|
||||
def test_split(self):
|
||||
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]])
|
||||
data = ds.Dataset(dataset, 4)
|
||||
train_data, test_data = data.split(0.5)
|
||||
|
||||
self.assertLen(train_data, 2)
|
||||
self.assertIsInstance(train_data, ds.Dataset)
|
||||
self.assertIsInstance(test_data, ds.Dataset)
|
||||
for i, elem in enumerate(train_data.gen_tf_dataset()):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
|
||||
|
||||
self.assertLen(test_data, 2)
|
||||
for i, elem in enumerate(test_data.gen_tf_dataset()):
|
||||
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
|
||||
|
||||
def test_len(self):
|
||||
size = 4
|
||||
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]])
|
||||
data = ds.Dataset(dataset, size)
|
||||
self.assertLen(data, size)
|
||||
|
||||
def test_gen_tf_dataset(self):
|
||||
input_dim = 8
|
||||
data = test_util.create_dataset(
|
||||
data_size=2, input_shape=[input_dim], num_classes=2)
|
||||
|
||||
dataset = data.gen_tf_dataset()
|
||||
self.assertLen(dataset, 2)
|
||||
for (feature, label) in dataset:
|
||||
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
|
||||
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
|
||||
|
||||
dataset2 = data.gen_tf_dataset(batch_size=2)
|
||||
self.assertLen(dataset2, 1)
|
||||
for (feature, label) in dataset2:
|
||||
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
|
||||
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
|
||||
|
||||
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
|
||||
self.assertEqual(dataset3.cardinality(), 1)
|
||||
for (feature, label) in dataset3.take(10):
|
||||
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
|
||||
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
mediapipe/model_maker/python/core/data/testdata/BUILD
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//mediapipe/framework/tool:mediapipe_files.bzl",
|
||||
"mediapipe_files",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
mediapipe_files(srcs = ["test.jpg"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["test.jpg"],
|
||||
)
|
mediapipe/model_maker/python/core/utils/BUILD
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_util",
|
||||
testonly = 1,
|
||||
srcs = ["test_util.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":model_util",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_preprocessing",
|
||||
srcs = ["image_preprocessing.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_preprocessing_test",
|
||||
srcs = ["image_preprocessing_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":image_preprocessing"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_util",
|
||||
srcs = ["model_util.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":quantization",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_util_test",
|
||||
srcs = ["model_util_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":model_util",
|
||||
":quantization",
|
||||
":test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
srcs_version = "PY3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "loss_functions_test",
|
||||
srcs = ["loss_functions_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [":loss_functions"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "quantization",
|
||||
srcs = ["quantization.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "quantization_test",
|
||||
srcs = ["quantization_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":quantization",
|
||||
":test_util",
|
||||
],
|
||||
)
|
mediapipe/model_maker/python/core/utils/__init__.py
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
mediapipe/model_maker/python/core/utils/image_preprocessing.py
|
@ -0,0 +1,228 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ImageNet preprocessing."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
import tensorflow as tf
|
||||
|
||||
IMAGE_SIZE = 224
|
||||
CROP_PADDING = 32
|
||||
|
||||
|
||||
class Preprocessor(object):
|
||||
"""Preprocessor for image classification."""
|
||||
|
||||
def __init__(self,
|
||||
input_shape,
|
||||
num_classes,
|
||||
mean_rgb,
|
||||
stddev_rgb,
|
||||
use_augmentation=False):
|
||||
self.input_shape = input_shape
|
||||
self.num_classes = num_classes
|
||||
self.mean_rgb = mean_rgb
|
||||
self.stddev_rgb = stddev_rgb
|
||||
self.use_augmentation = use_augmentation
|
||||
|
||||
def __call__(self, image, label, is_training=True):
|
||||
if self.use_augmentation:
|
||||
return self._preprocess_with_augmentation(image, label, is_training)
|
||||
return self._preprocess_without_augmentation(image, label)
|
||||
|
||||
def _preprocess_with_augmentation(self, image, label, is_training):
|
||||
"""Image preprocessing method with data augmentation."""
|
||||
image_size = self.input_shape[0]
|
||||
if is_training:
|
||||
image = preprocess_for_train(image, image_size)
|
||||
else:
|
||||
image = preprocess_for_eval(image, image_size)
|
||||
|
||||
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
|
||||
label = tf.one_hot(label, depth=self.num_classes)
|
||||
return image, label
|
||||
|
||||
# TODO: Changes to preprocess to support batch input.
|
||||
def _preprocess_without_augmentation(self, image, label):
|
||||
"""Image preprocessing method without data augmentation."""
|
||||
image = tf.cast(image, tf.float32)
|
||||
|
||||
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
|
||||
|
||||
image = tf.compat.v1.image.resize(image, self.input_shape)
|
||||
label = tf.one_hot(label, depth=self.num_classes)
|
||||
return image, label
|
||||
|
||||
|
||||
def _distorted_bounding_box_crop(image,
|
||||
bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=(0.75, 1.33),
|
||||
area_range=(0.05, 1.0),
|
||||
max_attempts=100):
|
||||
"""Generates cropped_image using one of the bboxes randomly distorted.
|
||||
|
||||
See `tf.image.sample_distorted_bounding_box` for more documentation.
|
||||
|
||||
Args:
|
||||
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||
shape [height, width, channels].
|
||||
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where
|
||||
each coordinate is [0, 1) and the coordinates are arranged as `[ymin,
|
||||
xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image.
|
||||
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area
|
||||
of the image must contain at least this fraction of any bounding box
|
||||
supplied.
|
||||
aspect_ratio_range: An optional list of `float`s. The cropped area of the
|
||||
image must have an aspect ratio = width / height within this range.
|
||||
area_range: An optional list of `float`s. The cropped area of the image must
|
||||
contain a fraction of the supplied image within this range.
|
||||
max_attempts: An optional `int`. Number of attempts at generating a cropped
|
||||
region of the image of the specified constraints. After `max_attempts`
|
||||
failures, return the entire image.
|
||||
|
||||
Returns:
|
||||
A cropped image `Tensor`
|
||||
"""
|
||||
with tf.name_scope('distorted_bounding_box_crop'):
|
||||
shape = tf.shape(image)
|
||||
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
||||
shape,
|
||||
bounding_boxes=bbox,
|
||||
min_object_covered=min_object_covered,
|
||||
aspect_ratio_range=aspect_ratio_range,
|
||||
area_range=area_range,
|
||||
max_attempts=max_attempts,
|
||||
use_image_if_no_bounding_boxes=True)
|
||||
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
||||
|
||||
# Crop the image to the specified bounding box.
|
||||
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
||||
target_height, target_width, _ = tf.unstack(bbox_size)
|
||||
image = tf.image.crop_to_bounding_box(image, offset_y, offset_x,
|
||||
target_height, target_width)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _at_least_x_are_equal(a, b, x):
|
||||
"""At least `x` of `a` and `b` `Tensors` are equal."""
|
||||
match = tf.equal(a, b)
|
||||
match = tf.cast(match, tf.int32)
|
||||
return tf.greater_equal(tf.reduce_sum(match), x)
|
||||
|
||||
|
||||
def _resize_image(image, image_size, method=None):
|
||||
if method is not None:
|
||||
tf.compat.v1.logging.info('Use customized resize method {}'.format(method))
|
||||
return tf.compat.v1.image.resize([image], [image_size, image_size],
|
||||
method)[0]
|
||||
tf.compat.v1.logging.info('Use default resize_bicubic.')
|
||||
return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0]
|
||||
|
||||
|
||||
def _decode_and_random_crop(original_image, image_size, resize_method=None):
|
||||
"""Makes a random crop of image_size."""
|
||||
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
|
||||
image = _distorted_bounding_box_crop(
|
||||
original_image,
|
||||
bbox,
|
||||
min_object_covered=0.1,
|
||||
aspect_ratio_range=(3. / 4, 4. / 3.),
|
||||
area_range=(0.08, 1.0),
|
||||
max_attempts=10)
|
||||
original_shape = tf.shape(original_image)
|
||||
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
|
||||
|
||||
image = tf.cond(bad,
|
||||
lambda: _decode_and_center_crop(original_image, image_size),
|
||||
lambda: _resize_image(image, image_size, resize_method))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _decode_and_center_crop(image, image_size, resize_method=None):
|
||||
"""Crops to center of image with padding then scales image_size."""
|
||||
shape = tf.shape(image)
|
||||
image_height = shape[0]
|
||||
image_width = shape[1]
|
||||
|
||||
padded_center_crop_size = tf.cast(
|
||||
((image_size / (image_size + CROP_PADDING)) *
|
||||
tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32)
|
||||
|
||||
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
||||
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
||||
image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
|
||||
padded_center_crop_size,
|
||||
padded_center_crop_size)
|
||||
image = _resize_image(image, image_size, resize_method)
|
||||
return image
|
||||
|
||||
|
||||
def _flip(image):
|
||||
"""Random horizontal image flip."""
|
||||
image = tf.image.random_flip_left_right(image)
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_for_train(
|
||||
image: tf.Tensor,
|
||||
image_size: int = IMAGE_SIZE,
|
||||
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
|
||||
"""Preprocesses the given image for evaluation.
|
||||
|
||||
Args:
|
||||
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||
shape [height, width, channels].
|
||||
image_size: image size.
|
||||
resize_method: resize method. If none, use bicubic.
|
||||
|
||||
Returns:
|
||||
A preprocessed image `Tensor`.
|
||||
"""
|
||||
image = _decode_and_random_crop(image, image_size, resize_method)
|
||||
image = _flip(image)
|
||||
image = tf.reshape(image, [image_size, image_size, 3])
|
||||
|
||||
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_for_eval(
|
||||
image: tf.Tensor,
|
||||
image_size: int = IMAGE_SIZE,
|
||||
resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor:
|
||||
"""Preprocesses the given image for evaluation.
|
||||
|
||||
Args:
|
||||
image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of
|
||||
shape [height, width, channels].
|
||||
image_size: image size.
|
||||
resize_method: if None, use bicubic.
|
||||
|
||||
Returns:
|
||||
A preprocessed image `Tensor`.
|
||||
"""
|
||||
image = _decode_and_center_crop(image, image_size, resize_method)
|
||||
image = tf.reshape(image, [image_size, image_size, 3])
|
||||
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
||||
return image
|
|
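A hedged sketch of applying the `Preprocessor` above to one image/label pair without augmentation; the mean/stddev values and tensor shapes are illustrative, not mandated by this change:

```python
import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import image_preprocessing

preprocessor = image_preprocessing.Preprocessor(
    input_shape=[224, 224],
    num_classes=3,
    mean_rgb=[0.0, 0.0, 0.0],        # Illustrative normalization constants.
    stddev_rgb=[255.0, 255.0, 255.0],
    use_augmentation=False)

# Random uint8 image and an integer class label.
image = tf.constant(np.random.randint(0, 256, size=(480, 640, 3)), tf.uint8)
label = tf.constant(1)

out_image, one_hot = preprocessor(image, label, is_training=False)
print(out_image.shape)  # (224, 224, 3), values scaled to [0, 1]
print(one_hot.numpy())  # [0. 1. 0.]
```

The same callable is meant to be passed as the `preprocess` argument of `Dataset.gen_tf_dataset`, which binds `is_training` and maps it over the dataset.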
@ -0,0 +1,85 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Dependency imports
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import image_preprocessing
|
||||
|
||||
|
||||
def _get_preprocessed_image(preprocessor, is_training=False):
|
||||
image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3])
|
||||
label_placeholder = tf.compat.v1.placeholder(tf.int32, [1])
|
||||
image_tensor, _ = preprocessor(image_placeholder, label_placeholder,
|
||||
is_training)
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
|
||||
image = sess.run(
|
||||
image_tensor,
|
||||
feed_dict={
|
||||
image_placeholder: input_image,
|
||||
label_placeholder: [0]
|
||||
})
|
||||
return image
|
||||
|
||||
|
||||
class PreprocessorTest(tf.test.TestCase):
|
||||
|
||||
def test_preprocess_without_augmentation(self):
|
||||
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
|
||||
num_classes=2,
|
||||
mean_rgb=[0.0],
|
||||
stddev_rgb=[255.0],
|
||||
use_augmentation=False)
|
||||
actual_image = np.array([[[0., 0.00392157, 0.00784314],
|
||||
[0.14117648, 0.14509805, 0.14901961]],
|
||||
[[0.37647063, 0.3803922, 0.38431376],
|
||||
[0.5176471, 0.52156866, 0.5254902]]])
|
||||
|
||||
image = _get_preprocessed_image(preprocessor)
|
||||
self.assertTrue(np.allclose(image, actual_image, atol=1e-05))
|
||||
|
||||
def test_preprocess_with_augmentation(self):
|
||||
image_preprocessing.CROP_PADDING = 1
|
||||
preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2],
|
||||
num_classes=2,
|
||||
mean_rgb=[0.0],
|
||||
stddev_rgb=[255.0],
|
||||
use_augmentation=True)
|
||||
# Tests validation image.
|
||||
actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216],
|
||||
[0.26666668, 0.27058825, 0.27450982]],
|
||||
[[0.42352945, 0.427451, 0.43137258],
|
||||
[0.5176471, 0.52156866, 0.5254902]]])
|
||||
|
||||
image = _get_preprocessed_image(preprocessor, is_training=False)
|
||||
self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05))
|
||||
|
||||
# Tests training image.
|
||||
image1 = _get_preprocessed_image(preprocessor, is_training=True)
|
||||
image2 = _get_preprocessed_image(preprocessor, is_training=True)
|
||||
self.assertFalse(np.allclose(image1, image2, atol=1e-05))
|
||||
self.assertEqual(image1.shape, (2, 2, 3))
|
||||
self.assertEqual(image2.shape, (2, 2, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
tf.test.main()
|
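The test above drives the `Preprocessor` through TF1 placeholders and a session (note the `disable_eager_execution()` call). A minimal sketch of the same call under eager execution, assuming the call signature used in the test also works eagerly:

```python
import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import image_preprocessing

# Same configuration as the no-augmentation test case above.
preprocessor = image_preprocessing.Preprocessor(
    input_shape=[2, 2],
    num_classes=2,
    mean_rgb=[0.0],
    stddev_rgb=[255.0],
    use_augmentation=False)

# A deterministic 24x24 RGB input, as in the test.
image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3])
out_image, _ = preprocessor(tf.constant(image), tf.constant([0]),
                            is_training=False)
print(out_image.shape)  # Expected: (2, 2, 3)
```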
105
mediapipe/model_maker/python/core/utils/loss_functions.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Loss function utility library."""
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class FocalLoss(tf.keras.losses.Loss):
|
||||
"""Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf).
|
||||
|
||||
This class computes the focal loss between labels and prediction. Focal loss
|
||||
is a weighted loss function that modulates the standard cross-entropy loss
|
||||
based on how well the neural network performs on a specific example of a
|
||||
class. The labels should be provided in a `one_hot` vector representation.
|
||||
There should be `#classes` floating point values per prediction.
|
||||
The loss is reduced across all samples using 'sum_over_batch_size' reduction
|
||||
(see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction).
|
||||
|
||||
Example usage:
|
||||
>>> y_true = [[0, 1, 0], [0, 0, 1]]
|
||||
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
|
||||
>>> gamma = 2
|
||||
>>> focal_loss = FocalLoss(gamma)
|
||||
>>> focal_loss(y_true, y_pred).numpy()
|
||||
0.9326
|
||||
|
||||
>>> # Calling with 'sample_weight'.
|
||||
>>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
|
||||
0.6528
|
||||
|
||||
Usage with the `compile()` API:
|
||||
```python
|
||||
model.compile(optimizer='sgd', loss=FocalLoss(gamma))
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
gamma: Focal loss gamma, as described in class docs.
|
||||
class_weight: A weight to apply to the loss, one for each class. The
|
||||
weight is applied for each input where the ground truth label matches.
|
||||
"""
|
||||
super(tf.keras.losses.Loss, self).__init__()
|
||||
# Used for clipping min/max values of probability values in y_pred to avoid
|
||||
# NaNs and Infs in computation.
|
||||
self._epsilon = 1e-7
|
||||
# This is a tunable "focusing parameter"; should be >= 0.
|
||||
# When gamma = 0, the loss returned is the standard categorical
|
||||
# cross-entropy loss.
|
||||
self._gamma = gamma
|
||||
self._class_weight = class_weight
|
||||
# tf.keras.losses.Loss class implementation requires a Reduction specified
|
||||
# in self.reduction. To use this reduction, we should use tensorflow's
|
||||
# compute_weighted_loss function; however, it is only compatible with v1 of
|
||||
# Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long
|
||||
# So even though it is specified here, we don't use self.reduction in the
|
||||
# loss function call.
|
||||
self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
|
||||
|
||||
def __call__(self,
|
||||
y_true: tf.Tensor,
|
||||
y_pred: tf.Tensor,
|
||||
sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor:
|
||||
if self._class_weight:
|
||||
class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32)
|
||||
label = tf.argmax(y_true, axis=1)
|
||||
loss_weight = tf.gather(class_weight, label)
|
||||
else:
|
||||
loss_weight = tf.ones(tf.shape(y_true)[0])
|
||||
y_true = tf.cast(y_true, y_pred.dtype)
|
||||
y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon)
|
||||
batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype)
|
||||
if sample_weight is None:
|
||||
sample_weight = tf.constant(1.0)
|
||||
weight_shape = sample_weight.shape
|
||||
weight_rank = weight_shape.ndims
|
||||
y_pred_rank = y_pred.shape.ndims
|
||||
if y_pred_rank - weight_rank == 1:
|
||||
sample_weight = tf.expand_dims(sample_weight, [-1])
|
||||
elif weight_rank != 0:
|
||||
raise ValueError(f'Unexpected sample_weight, should be either a scalar '
|
||||
f'or a vector of batch_size:{batch_size.numpy()}')
|
||||
ce = -tf.math.log(y_pred)
|
||||
modulating_factor = tf.math.pow(1 - y_pred, self._gamma)
|
||||
losses = y_true * modulating_factor * ce * sample_weight
|
||||
losses = losses * loss_weight[:, tf.newaxis]
|
||||
# By default, this function uses "sum_over_batch_size" reduction for the
|
||||
# loss per batch.
|
||||
return tf.reduce_sum(losses) / batch_size
|
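As a quick sanity check of the behaviour described in the class docstring, the loss can be evaluated directly on a small batch; `class_weight` scales each example's loss by the weight of its ground-truth class. A minimal sketch (the numbers are illustrative only):

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import loss_functions

y_true = tf.constant([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
y_pred = tf.constant([[0.7, 0.2, 0.1], [0.2, 0.6, 0.2]])

# gamma=0 reduces to categorical cross-entropy; larger gamma down-weights
# well-classified examples.
ce_like = loss_functions.FocalLoss(gamma=0)
focal = loss_functions.FocalLoss(gamma=2, class_weight=[1.0, 2.0, 2.0])
print(float(ce_like(y_true, y_pred)), float(focal(y_true, y_pred)))
```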
103
mediapipe/model_maker/python/core/utils/loss_functions_test.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||
|
||||
|
||||
class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='no_sample_weight', sample_weight=None),
|
||||
dict(
|
||||
testcase_name='with_sample_weight',
|
||||
sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])))
|
||||
def test_focal_loss_gamma_0_is_cross_entropy(
|
||||
self, sample_weight: Optional[tf.Tensor]):
|
||||
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
|
||||
0]])
|
||||
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
|
||||
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
|
||||
|
||||
tf_cce = tf.keras.losses.CategoricalCrossentropy(
|
||||
from_logits=False,
|
||||
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
|
||||
focal_loss = loss_functions.FocalLoss(gamma=0)
|
||||
self.assertAllClose(
|
||||
tf_cce(y_true, y_pred, sample_weight=sample_weight),
|
||||
focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4)
|
||||
|
||||
def test_focal_loss_with_sample_weight(self):
|
||||
y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1,
|
||||
0]])
|
||||
y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4],
|
||||
[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
|
||||
|
||||
focal_loss = loss_functions.FocalLoss(gamma=0)
|
||||
|
||||
sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2])
|
||||
|
||||
self.assertGreater(
|
||||
focal_loss(y_true=y_true, y_pred=y_pred),
|
||||
focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])),
|
||||
dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])),
|
||||
dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])),
|
||||
dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])),
|
||||
dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])),
|
||||
)
|
||||
def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor):
|
||||
y_true = tf.constant([[1, 0]])
|
||||
|
||||
focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0)
|
||||
loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred)
|
||||
focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5)
|
||||
loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred)
|
||||
focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1)
|
||||
loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred)
|
||||
focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2)
|
||||
loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred)
|
||||
focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5)
|
||||
loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred)
|
||||
|
||||
self.assertGreater(loss_gamma_0, loss_gamma_0p5)
|
||||
self.assertGreater(loss_gamma_0p5, loss_gamma_1)
|
||||
self.assertGreater(loss_gamma_1, loss_gamma_2)
|
||||
self.assertGreater(loss_gamma_2, loss_gamma_5)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name='index_0', true_class=0),
|
||||
dict(testcase_name='index_1', true_class=1),
|
||||
dict(testcase_name='index_2', true_class=2),
|
||||
)
|
||||
def test_focal_loss_class_weight_is_applied(self, true_class: int):
|
||||
class_weight = [1.0, 3.0, 10.0]
|
||||
y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0
|
||||
y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :]
|
||||
expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class]
|
||||
|
||||
loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight)
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
self.assertNear(loss, expected_loss, 1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
241
mediapipe/model_maker/python/core/utils/model_util.py
Normal file
|
@ -0,0 +1,241 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for keras models."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
|
||||
DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
|
||||
ESTIMITED_STEPS_PER_EPOCH = 1000
|
||||
|
||||
|
||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
train_data: Optional[dataset.Dataset] = None) -> int:
|
||||
"""Gets the estimated training steps per epoch.
|
||||
|
||||
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
|
||||
2. Else if we can get the length of training data successfully, returns
|
||||
`train_data_length // batch_size`.
|
||||
|
||||
Args:
|
||||
steps_per_epoch: int, training steps per epoch.
|
||||
batch_size: int, batch size.
|
||||
train_data: training data.
|
||||
|
||||
Returns:
|
||||
Estimated training steps per epoch.
|
||||
|
||||
Raises:
|
||||
ValueError: if both steps_per_epoch and train_data are not set.
|
||||
"""
|
||||
if steps_per_epoch is not None:
|
||||
# steps_per_epoch is set by users manually.
|
||||
return steps_per_epoch
|
||||
else:
|
||||
if train_data is None:
|
||||
raise ValueError('Input train_data cannot be None.')
|
||||
# Gets the steps by the length of the training data.
|
||||
return len(train_data) // batch_size
|
||||
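A minimal sketch of the two branches described in the docstring, assuming the `Dataset` wrapper exposes its size via `len()` (here using the `test_util.create_dataset` helper added later in this change):

```python
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

train_data = test_util.create_dataset(data_size=100, input_shape=[4],
                                      num_classes=3)
# An explicit value wins; otherwise len(train_data) // batch_size is used.
print(model_util.get_steps_per_epoch(steps_per_epoch=50))         # 50
print(model_util.get_steps_per_epoch(batch_size=8,
                                     train_data=train_data))      # 12
```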
|
||||
|
||||
def export_tflite(
|
||||
model: tf.keras.Model,
|
||||
tflite_filepath: str,
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||
supported_ops: Tuple[tf.lite.OpsSet,
|
||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,)):
|
||||
"""Converts the model to tflite format and saves it.
|
||||
|
||||
Args:
|
||||
model: model to be converted to tflite.
|
||||
tflite_filepath: File path to save tflite model.
|
||||
quantization_config: Configuration for post-training quantization.
|
||||
supported_ops: A list of supported ops in the converted TFLite file.
|
||||
"""
|
||||
if tflite_filepath is None:
|
||||
raise ValueError(
|
||||
"TFLite filepath couldn't be None when exporting to tflite.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, 'saved_model')
|
||||
model.save(save_path, include_optimizer=False, save_format='tf')
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
|
||||
|
||||
if quantization_config:
|
||||
converter = quantization_config.set_converter_with_quantization(converter)
|
||||
|
||||
converter.target_spec.supported_ops = supported_ops
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
|
||||
f.write(tflite_model)
|
||||
|
||||
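A minimal end-to-end sketch of `export_tflite`, using the `test_util.build_model` helper from this change; any small Keras classifier would do:

```python
import os
import tempfile

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import test_util

model = test_util.build_model(input_shape=[4], num_classes=2)
tflite_path = os.path.join(tempfile.mkdtemp(), 'model.tflite')
# Optionally pass a quantization.QuantizationConfig as quantization_config.
model_util.export_tflite(model, tflite_path)
print(os.path.getsize(tflite_path), 'bytes written')
```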
|
||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||
"""Applies a warmup schedule on a given learning rate decay schedule."""
|
||||
|
||||
def __init__(self,
|
||||
initial_learning_rate: float,
|
||||
decay_schedule_fn: Callable[[Any], Any],
|
||||
warmup_steps: int,
|
||||
name: Optional[str] = None):
|
||||
"""Initializes a new instance of the `WarmUp` class.
|
||||
|
||||
Args:
|
||||
initial_learning_rate: learning rate after the warmup.
|
||||
decay_schedule_fn: A function that maps a step to a learning rate. Applied
|
||||
for values of step larger than 'warmup_steps'.
|
||||
warmup_steps: Number of steps to do warmup for.
|
||||
name: TF namescope under which to perform the learning rate calculation.
|
||||
"""
|
||||
super(WarmUp, self).__init__()
|
||||
self.initial_learning_rate = initial_learning_rate
|
||||
self.warmup_steps = warmup_steps
|
||||
self.decay_schedule_fn = decay_schedule_fn
|
||||
self.name = name
|
||||
|
||||
def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
|
||||
with tf.name_scope(self.name or 'WarmUp') as name:
|
||||
# Implements linear warmup. i.e., if global_step < warmup_steps, the
|
||||
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||
global_step_float = tf.cast(step, tf.float32)
|
||||
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
|
||||
warmup_percent_done = global_step_float / warmup_steps_float
|
||||
warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
|
||||
return tf.cond(
|
||||
global_step_float < warmup_steps_float,
|
||||
lambda: warmup_learning_rate,
|
||||
lambda: self.decay_schedule_fn(step),
|
||||
name=name)
|
||||
|
||||
def get_config(self) -> Dict[Text, Any]:
|
||||
return {
|
||||
'initial_learning_rate': self.initial_learning_rate,
|
||||
'decay_schedule_fn': self.decay_schedule_fn,
|
||||
'warmup_steps': self.warmup_steps,
|
||||
'name': self.name
|
||||
}
|
||||
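A sketch of wiring the schedule into an optimizer, mirroring the configuration exercised in `model_util_test` below; the names are illustrative:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

init_lr = 0.1
decay_fn = tf.keras.experimental.CosineDecay(
    initial_learning_rate=init_lr, decay_steps=100)
# Linear warmup for the first 1000 steps, cosine decay afterwards.
lr_schedule = model_util.WarmUp(
    initial_learning_rate=init_lr,
    decay_schedule_fn=decay_fn,
    warmup_steps=1000,
    name='warmup')
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
print(float(lr_schedule(10)))  # 10/1000 * 0.1 = 0.001
```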
|
||||
|
||||
class LiteRunner(object):
|
||||
"""A runner to do inference with the TFLite model."""
|
||||
|
||||
def __init__(self, tflite_filepath: str):
|
||||
"""Initializes Lite runner with tflite model file.
|
||||
|
||||
Args:
|
||||
tflite_filepath: File path to the TFLite model.
|
||||
"""
|
||||
with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
|
||||
tflite_model = f.read()
|
||||
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
self.interpreter.allocate_tensors()
|
||||
self.input_details = self.interpreter.get_input_details()
|
||||
self.output_details = self.interpreter.get_output_details()
|
||||
|
||||
def run(
|
||||
self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]]
|
||||
) -> Union[List[tf.Tensor], tf.Tensor]:
|
||||
"""Runs inference with the TFLite model.
|
||||
|
||||
Args:
|
||||
input_tensors: List / Dict of the input tensors of the TFLite model. The
|
||||
order should be the same as the keras model if it's a list. It also
|
||||
accepts a single tensor directly if the model has only one input.
|
||||
|
||||
Returns:
|
||||
List of the output tensors for multi-output models, otherwise just
|
||||
the output tensor. The order should be the same as the keras model.
|
||||
"""
|
||||
|
||||
if not isinstance(input_tensors, list) and not isinstance(
|
||||
input_tensors, dict):
|
||||
input_tensors = [input_tensors]
|
||||
|
||||
interpreter = self.interpreter
|
||||
|
||||
# Reshape inputs
|
||||
for i, input_detail in enumerate(self.input_details):
|
||||
input_tensor = _get_input_tensor(
|
||||
input_tensors=input_tensors,
|
||||
input_details=self.input_details,
|
||||
index=i)
|
||||
interpreter.resize_tensor_input(
|
||||
input_index=input_detail['index'], tensor_size=input_tensor.shape)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
# Feed input to the interpreter
|
||||
for i, input_detail in enumerate(self.input_details):
|
||||
input_tensor = _get_input_tensor(
|
||||
input_tensors=input_tensors,
|
||||
input_details=self.input_details,
|
||||
index=i)
|
||||
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||
# Quantize the input
|
||||
scale, zero_point = input_detail['quantization']
|
||||
input_tensor = input_tensor / scale + zero_point
|
||||
input_tensor = np.array(input_tensor, dtype=input_detail['dtype'])
|
||||
interpreter.set_tensor(input_detail['index'], input_tensor)
|
||||
|
||||
interpreter.invoke()
|
||||
|
||||
output_tensors = []
|
||||
for output_detail in self.output_details:
|
||||
output_tensor = interpreter.get_tensor(output_detail['index'])
|
||||
if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||
# Dequantize the output
|
||||
scale, zero_point = output_detail['quantization']
|
||||
output_tensor = output_tensor.astype(np.float32)
|
||||
output_tensor = (output_tensor - zero_point) * scale
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
if len(output_tensors) == 1:
|
||||
return output_tensors[0]
|
||||
return output_tensors
|
||||
|
||||
|
||||
def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
|
||||
"""Returns a `LiteRunner` from file path to TFLite model."""
|
||||
lite_runner = LiteRunner(tflite_filepath)
|
||||
return lite_runner
|
||||
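A sketch of running a converted model with the runner; the path is a placeholder (for example, the file written by the `export_tflite` sketch earlier in this change):

```python
import numpy as np

from mediapipe.model_maker.python.core.utils import model_util

# Placeholder path to a converted model, e.g. written by export_tflite above.
tflite_path = '/tmp/model.tflite'
lite_runner = model_util.get_lite_runner(tflite_path)

batch = np.random.uniform(size=(1, 4)).astype(np.float32)
output = lite_runner.run(batch)  # Single-output model -> a single array.
print(output.shape)
```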
|
||||
|
||||
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
|
||||
tf.Tensor]],
|
||||
input_details: Dict[str, Any], index: int) -> tf.Tensor:
|
||||
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
|
||||
if isinstance(input_tensors, dict):
|
||||
# Gets the mapped input tensor.
|
||||
input_detail = input_details
|
||||
for input_tensor_name, input_tensor in input_tensors.items():
|
||||
if input_tensor_name in input_detail['name']:
|
||||
return input_tensor
|
||||
raise ValueError('Input tensors don\'t contain a tensor that maps to the '
|
||||
'input detail %s' % str(input_detail))
|
||||
else:
|
||||
return input_tensors[index]
|
137
mediapipe/model_maker/python/core/utils/model_util_test.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='input_only_steps_per_epoch',
|
||||
steps_per_epoch=1000,
|
||||
batch_size=None,
|
||||
train_data=None,
|
||||
expected_steps_per_epoch=1000),
|
||||
dict(
|
||||
testcase_name='input_steps_per_epoch_and_batch_size',
|
||||
steps_per_epoch=1000,
|
||||
batch_size=32,
|
||||
train_data=None,
|
||||
expected_steps_per_epoch=1000),
|
||||
dict(
|
||||
testcase_name='input_steps_per_epoch_batch_size_and_train_data',
|
||||
steps_per_epoch=1000,
|
||||
batch_size=32,
|
||||
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]]),
|
||||
expected_steps_per_epoch=1000),
|
||||
dict(
|
||||
testcase_name='input_batch_size_and_train_data',
|
||||
steps_per_epoch=None,
|
||||
batch_size=2,
|
||||
train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
|
||||
[1, 0]]),
|
||||
expected_steps_per_epoch=2))
|
||||
def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data,
|
||||
expected_steps_per_epoch):
|
||||
estimated_steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
batch_size=batch_size,
|
||||
train_data=train_data)
|
||||
self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch)
|
||||
|
||||
def test_get_steps_per_epoch_raise_value_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=None, batch_size=16, train_data=None)
|
||||
|
||||
def test_warmup(self):
|
||||
init_lr = 0.1
|
||||
warmup_steps = 1000
|
||||
num_decay_steps = 100
|
||||
learning_rate_fn = tf.keras.experimental.CosineDecay(
|
||||
initial_learning_rate=init_lr, decay_steps=num_decay_steps)
|
||||
warmup_object = model_util.WarmUp(
|
||||
initial_learning_rate=init_lr,
|
||||
decay_schedule_fn=learning_rate_fn,
|
||||
warmup_steps=1000,
|
||||
name='test')
|
||||
self.assertEqual(
|
||||
warmup_object.get_config(), {
|
||||
'initial_learning_rate': init_lr,
|
||||
'decay_schedule_fn': learning_rate_fn,
|
||||
'warmup_steps': warmup_steps,
|
||||
'name': 'test'
|
||||
})
|
||||
|
||||
def test_export_tflite(self):
|
||||
input_dim = 4
|
||||
model = test_util.build_model(input_shape=[input_dim], num_classes=2)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
|
||||
model_util.export_tflite(model, tflite_file)
|
||||
self._test_tflite(model, tflite_file, input_dim)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='dynamic_quantize',
|
||||
config=quantization.QuantizationConfig.for_dynamic(),
|
||||
model_size=1288),
|
||||
dict(
|
||||
testcase_name='int8_quantize',
|
||||
config=quantization.QuantizationConfig.for_int8(
|
||||
representative_data=test_util.create_dataset(
|
||||
data_size=10, input_shape=[16], num_classes=3)),
|
||||
model_size=1832),
|
||||
dict(
|
||||
testcase_name='float16_quantize',
|
||||
config=quantization.QuantizationConfig.for_float16(),
|
||||
model_size=1468))
|
||||
def test_export_tflite_quantized(self, config, model_size):
|
||||
input_dim = 16
|
||||
num_classes = 2
|
||||
max_input_value = 5
|
||||
model = test_util.build_model([input_dim], num_classes)
|
||||
tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
|
||||
|
||||
model_util.export_tflite(model, tflite_file, config)
|
||||
self._test_tflite(
|
||||
model, tflite_file, input_dim, max_input_value, atol=1e-00)
|
||||
self.assertNear(os.path.getsize(tflite_file), model_size, 300)
|
||||
|
||||
def _test_tflite(self,
|
||||
keras_model: tf.keras.Model,
|
||||
tflite_model_file: str,
|
||||
input_dim: int,
|
||||
max_input_value: int = 1000,
|
||||
atol: float = 1e-04):
|
||||
np.random.seed(0)
|
||||
random_input = np.random.uniform(
|
||||
low=0, high=max_input_value, size=(1, input_dim)).astype(np.float32)
|
||||
|
||||
self.assertTrue(
|
||||
test_util.is_same_output(
|
||||
tflite_model_file, keras_model, random_input, atol=atol))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
213
mediapipe/model_maker/python/core/utils/quantization.py
Normal file
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Libraries for post-training quantization."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
|
||||
DEFAULT_QUANTIZATION_STEPS = 500
|
||||
|
||||
|
||||
def _get_representative_dataset_generator(dataset: tf.data.Dataset,
|
||||
num_steps: int) -> Callable[[], Any]:
|
||||
"""Gets a representative dataset generator for post-training quantization.
|
||||
|
||||
The generator provides a small dataset to calibrate or estimate the
|
||||
range, i.e., (min, max), of all floating-point arrays in the model for
|
||||
quantization. Usually, this is a small subset of a few hundred samples
|
||||
randomly chosen, in no particular order, from the training or evaluation
|
||||
dataset. See tf.lite.RepresentativeDataset for more details.
|
||||
|
||||
Args:
|
||||
dataset: Input dataset for extracting representative sub dataset.
|
||||
num_steps: The number of quantization steps which also reflects the size of
|
||||
the representative dataset.
|
||||
|
||||
Returns:
|
||||
A representative dataset generator.
|
||||
"""
|
||||
|
||||
def representative_dataset_gen():
|
||||
"""Generates representative dataset for quantization."""
|
||||
for data, _ in dataset.take(num_steps):
|
||||
yield [data]
|
||||
|
||||
return representative_dataset_gen
|
||||
|
||||
|
||||
class QuantizationConfig(object):
|
||||
"""Configuration for post-training quantization.
|
||||
|
||||
Refer to
|
||||
https://www.tensorflow.org/lite/performance/post_training_quantization
|
||||
for different post-training quantization options.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizations: Optional[Union[tf.lite.Optimize,
|
||||
List[tf.lite.Optimize]]] = None,
|
||||
representative_data: Optional[ds.Dataset] = None,
|
||||
quantization_steps: Optional[int] = None,
|
||||
inference_input_type: Optional[tf.dtypes.DType] = None,
|
||||
inference_output_type: Optional[tf.dtypes.DType] = None,
|
||||
supported_ops: Optional[Union[tf.lite.OpsSet,
|
||||
List[tf.lite.OpsSet]]] = None,
|
||||
supported_types: Optional[Union[tf.dtypes.DType,
|
||||
List[tf.dtypes.DType]]] = None,
|
||||
experimental_new_quantizer: bool = False,
|
||||
):
|
||||
"""Constructs QuantizationConfig.
|
||||
|
||||
Args:
|
||||
optimizations: A list of optimizations to apply when converting the model.
|
||||
If not set, use `[Optimize.DEFAULT]` by default.
|
||||
representative_data: A representative ds.Dataset for post-training
|
||||
quantization.
|
||||
quantization_steps: Number of post-training quantization calibration steps
|
||||
to run (defaults to DEFAULT_QUANTIZATION_STEPS).
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays. Defaults to None. If set, must be
|
||||
in `{tf.float32, tf.uint8, tf.int8}`.
|
||||
inference_output_type: Target data type of real-number output arrays.
|
||||
Allows for a different type for output arrays. Defaults to None. If set,
|
||||
must be `{tf.float32, tf.uint8, tf.int8}`.
|
||||
supported_ops: Set of OpsSet options supported by the device. Used to set
|
||||
converter.target_spec.supported_ops.
|
||||
supported_types: List of types for constant values on the target device.
|
||||
Supported values are types exported by lite.constants. Frequently, an
|
||||
optimization choice is driven by the most compact (i.e. smallest) type
|
||||
in this list (default [constants.FLOAT]).
|
||||
experimental_new_quantizer: Whether to enable experimental new quantizer.
|
||||
|
||||
Raises:
|
||||
ValueError: if inference_input_type or inference_output_type are set but
|
||||
not in {tf.float32, tf.uint8, tf.int8}.
|
||||
"""
|
||||
if inference_input_type is not None and inference_input_type not in {
|
||||
tf.float32, tf.uint8, tf.int8
|
||||
}:
|
||||
raise ValueError('Unsupported inference_input_type %s' %
|
||||
inference_input_type)
|
||||
if inference_output_type is not None and inference_output_type not in {
|
||||
tf.float32, tf.uint8, tf.int8
|
||||
}:
|
||||
raise ValueError('Unsupported inference_output_type %s' %
|
||||
inference_output_type)
|
||||
|
||||
if optimizations is None:
|
||||
optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
if not isinstance(optimizations, list):
|
||||
optimizations = [optimizations]
|
||||
self.optimizations = optimizations
|
||||
|
||||
self.representative_data = representative_data
|
||||
if self.representative_data is not None and quantization_steps is None:
|
||||
quantization_steps = DEFAULT_QUANTIZATION_STEPS
|
||||
self.quantization_steps = quantization_steps
|
||||
|
||||
self.inference_input_type = inference_input_type
|
||||
self.inference_output_type = inference_output_type
|
||||
|
||||
if supported_ops is not None and not isinstance(supported_ops, list):
|
||||
supported_ops = [supported_ops]
|
||||
self.supported_ops = supported_ops
|
||||
|
||||
if supported_types is not None and not isinstance(supported_types, list):
|
||||
supported_types = [supported_types]
|
||||
self.supported_types = supported_types
|
||||
|
||||
self.experimental_new_quantizer = experimental_new_quantizer
|
||||
|
||||
@classmethod
|
||||
def for_dynamic(cls) -> 'QuantizationConfig':
|
||||
"""Creates configuration for dynamic range quantization."""
|
||||
return QuantizationConfig()
|
||||
|
||||
@classmethod
|
||||
def for_int8(
|
||||
cls,
|
||||
representative_data: ds.Dataset,
|
||||
quantization_steps: int = DEFAULT_QUANTIZATION_STEPS,
|
||||
inference_input_type: tf.dtypes.DType = tf.uint8,
|
||||
inference_output_type: tf.dtypes.DType = tf.uint8,
|
||||
supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8
|
||||
) -> 'QuantizationConfig':
|
||||
"""Creates configuration for full integer quantization.
|
||||
|
||||
Args:
|
||||
representative_data: Representative data used for post-training
|
||||
quantization.
|
||||
quantization_steps: Number of post-training quantization calibration steps
|
||||
to run.
|
||||
inference_input_type: Target data type of real-number input arrays.
|
||||
inference_output_type: Target data type of real-number output arrays.
|
||||
supported_ops: Set of `tf.lite.OpsSet` options, where each option
|
||||
represents a set of operators supported by the target device.
|
||||
|
||||
Returns:
|
||||
QuantizationConfig.
|
||||
"""
|
||||
return QuantizationConfig(
|
||||
representative_data=representative_data,
|
||||
quantization_steps=quantization_steps,
|
||||
inference_input_type=inference_input_type,
|
||||
inference_output_type=inference_output_type,
|
||||
supported_ops=supported_ops)
|
||||
|
||||
@classmethod
|
||||
def for_float16(cls) -> 'QuantizationConfig':
|
||||
"""Creates configuration for float16 quantization."""
|
||||
return QuantizationConfig(supported_types=[tf.float16])
|
||||
|
||||
def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter,
|
||||
**kwargs: Any) -> tf.lite.TFLiteConverter:
|
||||
"""Sets input TFLite converter with quantization configurations.
|
||||
|
||||
Args:
|
||||
converter: input tf.lite.TFLiteConverter.
|
||||
**kwargs: arguments used by ds.Dataset.gen_tf_dataset.
|
||||
|
||||
Returns:
|
||||
tf.lite.TFLiteConverter with quantization configurations.
|
||||
"""
|
||||
converter.optimizations = self.optimizations
|
||||
|
||||
if self.representative_data is not None:
|
||||
tf_ds = self.representative_data.gen_tf_dataset(
|
||||
batch_size=1, is_training=False, **kwargs)
|
||||
converter.representative_dataset = tf.lite.RepresentativeDataset(
|
||||
_get_representative_dataset_generator(tf_ds, self.quantization_steps))
|
||||
|
||||
if self.inference_input_type:
|
||||
converter.inference_input_type = self.inference_input_type
|
||||
if self.inference_output_type:
|
||||
converter.inference_output_type = self.inference_output_type
|
||||
if self.supported_ops:
|
||||
converter.target_spec.supported_ops = self.supported_ops
|
||||
if self.supported_types:
|
||||
converter.target_spec.supported_types = self.supported_types
|
||||
|
||||
if self.experimental_new_quantizer is not None:
|
||||
converter.experimental_new_quantizer = self.experimental_new_quantizer
|
||||
return converter
|
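A sketch tying the configuration back to `model_util.export_tflite`, using the test helpers added in this change for the model and the representative data:

```python
import os
import tempfile

from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util

model = test_util.build_model(input_shape=[16], num_classes=3)
representative_data = test_util.create_dataset(
    data_size=10, input_shape=[16], num_classes=3)

# Full-integer quantization; for_dynamic() / for_float16() work the same way.
config = quantization.QuantizationConfig.for_int8(
    representative_data=representative_data)
tflite_path = os.path.join(tempfile.mkdtemp(), 'model_int8.tflite')
model_util.export_tflite(model, tflite_path, quantization_config=config)
```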
108
mediapipe/model_maker/python/core/utils/quantization_test.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
|
||||
|
||||
class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_create_dynamic_quantization_config(self):
|
||||
config = quantization.QuantizationConfig.for_dynamic()
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertIsNone(config.representative_data)
|
||||
self.assertIsNone(config.inference_input_type)
|
||||
self.assertIsNone(config.inference_output_type)
|
||||
self.assertIsNone(config.supported_ops)
|
||||
self.assertIsNone(config.supported_types)
|
||||
self.assertFalse(config.experimental_new_quantizer)
|
||||
|
||||
def test_create_int8_quantization_config(self):
|
||||
representative_data = test_util.create_dataset(
|
||||
data_size=10, input_shape=[4], num_classes=3)
|
||||
config = quantization.QuantizationConfig.for_int8(
|
||||
representative_data=representative_data)
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertEqual(config.inference_input_type, tf.uint8)
|
||||
self.assertEqual(config.inference_output_type, tf.uint8)
|
||||
self.assertEqual(config.supported_ops,
|
||||
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
|
||||
self.assertFalse(config.experimental_new_quantizer)
|
||||
|
||||
def test_set_converter_with_quantization_from_int8_config(self):
|
||||
representative_data = test_util.create_dataset(
|
||||
data_size=10, input_shape=[4], num_classes=3)
|
||||
config = quantization.QuantizationConfig.for_int8(
|
||||
representative_data=representative_data)
|
||||
model = test_util.build_model(input_shape=[4], num_classes=3)
|
||||
saved_model_dir = self.get_temp_dir()
|
||||
model.save(saved_model_dir)
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
converter = config.set_converter_with_quantization(converter=converter)
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertEqual(config.inference_input_type, tf.uint8)
|
||||
self.assertEqual(config.inference_output_type, tf.uint8)
|
||||
self.assertEqual(config.supported_ops,
|
||||
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
|
||||
tflite_model = converter.convert()
|
||||
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
|
||||
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)
|
||||
|
||||
def test_create_float16_quantization_config(self):
|
||||
config = quantization.QuantizationConfig.for_float16()
|
||||
self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT])
|
||||
self.assertIsNone(config.representative_data)
|
||||
self.assertIsNone(config.inference_input_type)
|
||||
self.assertIsNone(config.inference_output_type)
|
||||
self.assertIsNone(config.supported_ops)
|
||||
self.assertEqual(config.supported_types, [tf.float16])
|
||||
self.assertFalse(config.experimental_new_quantizer)
|
||||
|
||||
def test_set_converter_with_quantization_from_float16_config(self):
|
||||
config = quantization.QuantizationConfig.for_float16()
|
||||
model = test_util.build_model(input_shape=[4], num_classes=3)
|
||||
saved_model_dir = self.get_temp_dir()
|
||||
model.save(saved_model_dir)
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
converter = config.set_converter_with_quantization(converter=converter)
|
||||
self.assertEqual(config.supported_types, [tf.float16])
|
||||
tflite_model = converter.convert()
|
||||
interpreter = tf.lite.Interpreter(model_content=tflite_model)
|
||||
# The input and output are expected to be set to float32 by default.
|
||||
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
|
||||
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='invalid_inference_input_type',
|
||||
inference_input_type=tf.uint8,
|
||||
inference_output_type=tf.int64),
|
||||
dict(
|
||||
testcase_name='invalid_inference_output_type',
|
||||
inference_input_type=tf.int64,
|
||||
inference_output_type=tf.float32))
|
||||
def test_create_quantization_config_failure(self, inference_input_type,
|
||||
inference_output_type):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = quantization.QuantizationConfig(
|
||||
inference_input_type=inference_input_type,
|
||||
inference_output_type=inference_output_type)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
76
mediapipe/model_maker/python/core/utils/test_util.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test utilities for model maker."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
# Dependency imports
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
|
||||
|
||||
def create_dataset(data_size: int,
|
||||
input_shape: List[int],
|
||||
num_classes: int,
|
||||
max_input_value: int = 1000) -> ds.Dataset:
|
||||
"""Creates and returns a simple `Dataset` object for test."""
|
||||
features = tf.random.uniform(
|
||||
shape=[data_size] + input_shape,
|
||||
minval=0,
|
||||
maxval=max_input_value,
|
||||
dtype=tf.float32)
|
||||
|
||||
labels = tf.random.uniform(
|
||||
shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
|
||||
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
||||
dataset = ds.Dataset(tf_dataset, data_size)
|
||||
return dataset
|
||||
|
||||
|
||||
def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
|
||||
"""Builds a simple Keras model for test."""
|
||||
inputs = tf.keras.layers.Input(shape=input_shape)
|
||||
if len(input_shape) == 3: # Image inputs.
|
||||
outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs)
|
||||
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs)
|
||||
elif len(input_shape) == 1: # Text inputs.
|
||||
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)
|
||||
else:
|
||||
raise ValueError("Model inputs should be 2D tensor or 4D tensor.")
|
||||
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
return model
|
||||
|
||||
|
||||
def is_same_output(tflite_file: str,
|
||||
keras_model: tf.keras.Model,
|
||||
input_tensors: Union[List[tf.Tensor], tf.Tensor],
|
||||
atol: float = 1e-04) -> bool:
|
||||
"""Returns if the output of TFLite model and keras model are identical."""
|
||||
# Gets output from lite model.
|
||||
lite_runner = model_util.get_lite_runner(tflite_file)
|
||||
lite_output = lite_runner.run(input_tensors)
|
||||
|
||||
# Gets output from keras model.
|
||||
keras_output = keras_model.predict_on_batch(input_tensors)
|
||||
|
||||
return np.allclose(lite_output, keras_output, atol=atol)
|
4
mediapipe/model_maker/requirements.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
absl-py
|
||||
numpy
|
||||
opencv-contrib-python
|
||||
tensorflow
|
|
@ -507,8 +507,11 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
}
|
||||
};
|
||||
|
||||
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
|
||||
// clang-format off
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT
|
||||
// clang-format on
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
|
|
|
@ -41,8 +41,8 @@ cc_test(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "hand_gesture_recognizer_subgraph",
|
||||
srcs = ["hand_gesture_recognizer_subgraph.cc"],
|
||||
name = "hand_gesture_recognizer_graph",
|
||||
srcs = ["hand_gesture_recognizer_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
||||
"//mediapipe/calculators/tensor:tensor_converter_calculator",
|
||||
|
@ -62,11 +62,11 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:handedness_to_matrix_calculator",
|
||||
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:landmarks_to_matrix_calculator",
|
||||
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:hand_gesture_recognizer_subgraph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"@com_google_absl//absl/status",
|
|
@ -12,11 +12,23 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/app/xeno:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "landmarks_to_matrix_calculator_proto",
|
||||
srcs = ["landmarks_to_matrix_calculator.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "handedness_to_matrix_calculator",
|
||||
srcs = ["handedness_to_matrix_calculator.cc"],
|
||||
|
@ -25,7 +37,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer:handedness_util",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer:handedness_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
@ -53,11 +65,11 @@ cc_library(
|
|||
name = "landmarks_to_matrix_calculator",
|
||||
srcs = ["landmarks_to_matrix_calculator.cc"],
|
||||
deps = [
|
||||
":landmarks_to_matrix_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
|
@ -26,14 +26,16 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"
|
||||
|
||||
// TODO Update to use API2
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore;
|
||||
|
||||
constexpr char kHandednessTag[] = "HANDEDNESS";
|
||||
constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX";
|
||||
|
||||
|
@ -71,6 +73,8 @@ class HandednessToMatrixCalculator : public CalculatorBase {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// TODO remove this after change to API2, because Setting offset
|
||||
// to 0 is the default in API2
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
|
@ -95,6 +99,5 @@ absl::Status HandednessToMatrixCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -28,8 +28,6 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -95,6 +93,4 @@ INSTANTIATE_TEST_CASE_P(
|
|||
|
||||
} // namespace
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -27,13 +27,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
|
||||
|
||||
// TODO Update to use API2
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
using proto::LandmarksToMatrixCalculatorOptions;
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -175,7 +173,7 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) {
|
|||
// input_stream: "IMAGE_SIZE:image_size"
|
||||
// output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions.ext] {
|
||||
// [mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
|
||||
// object_normalization: true
|
||||
// object_normalization_origin_offset: 0
|
||||
// }
|
||||
|
@ -221,6 +219,5 @@ absl::Status LandmarksToMatrixCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.vision.proto;
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
|
@ -28,8 +28,6 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -72,8 +70,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) {
|
|||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
|
||||
options {
|
||||
[mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions
|
||||
.ext] {
|
||||
[mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
|
||||
object_normalization: $0
|
||||
object_normalization_origin_offset: $1
|
||||
}
|
||||
|
@ -145,8 +142,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) {
|
|||
input_stream: "IMAGE_SIZE:image_size"
|
||||
output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
|
||||
options {
|
||||
[mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions
|
||||
.ext] {
|
||||
[mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
|
||||
object_normalization: $0
|
||||
object_normalization_origin_offset: $1
|
||||
}
|
||||
|
@ -202,6 +198,4 @@ INSTANTIATE_TEST_CASE_P(
|
|||
|
||||
} // namespace
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -34,14 +34,15 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace gesture_recognizer {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -50,9 +51,8 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
|
||||
HandGestureRecognizerSubgraphOptions;
|
||||
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
|
||||
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
|
||||
HandGestureRecognizerGraphOptions;
|
||||
|
||||
constexpr char kHandednessTag[] = "HANDEDNESS";
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
|
@ -70,18 +70,6 @@ constexpr char kIndexTag[] = "INDEX";
|
|||
constexpr char kIterableTag[] = "ITERABLE";
|
||||
constexpr char kBatchEndTag[] = "BATCH_END";
|
||||
|
||||
absl::Status SanityCheckOptions(
|
||||
const HandGestureRecognizerSubgraphOptions& options) {
|
||||
if (options.min_tracking_confidence() < 0 ||
|
||||
options.min_tracking_confidence() > 1) {
|
||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||
"Invalid `min_tracking_confidence` option: "
|
||||
"value must be in the range [0.0, 1.0]",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
|
||||
Graph& graph) {
|
||||
auto& node = graph.AddNode("TensorConverterCalculator");
|
||||
|
@ -91,9 +79,10 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" performs
|
||||
// single hand gesture recognition. This graph is used as a building block for
|
||||
// mediapipe.tasks.vision.HandGestureRecognizerGraph.
|
||||
// A
|
||||
// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph"
|
||||
// performs single hand gesture recognition. This graph is used as a building
|
||||
// block for mediapipe.tasks.vision.GestureRecognizerGraph.
|
||||
//
|
||||
// Inputs:
|
||||
// HANDEDNESS - ClassificationList
|
||||
|
@ -113,14 +102,15 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
|
|||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph"
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph"
|
||||
// input_stream: "HANDEDNESS:handedness"
|
||||
// input_stream: "LANDMARKS:landmarks"
|
||||
// input_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||
// input_stream: "IMAGE_SIZE:image_size"
|
||||
// output_stream: "HAND_GESTURES:hand_gestures"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraphOptions.ext]
|
||||
// [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -130,19 +120,19 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
||||
class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<HandGestureRecognizerSubgraphOptions>(sc));
|
||||
CreateModelResources<HandGestureRecognizerGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_gestures,
|
||||
BuildHandGestureRecognizerGraph(
|
||||
sc->Options<HandGestureRecognizerSubgraphOptions>(),
|
||||
*model_resources, graph[Input<ClassificationList>(kHandednessTag)],
|
||||
BuildGestureRecognizerGraph(
|
||||
sc->Options<HandGestureRecognizerGraphOptions>(), *model_resources,
|
||||
graph[Input<ClassificationList>(kHandednessTag)],
|
||||
graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
|
||||
graph[Input<LandmarkList>(kWorldLandmarksTag)],
|
||||
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
|
||||
|
@ -151,15 +141,13 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<Source<ClassificationResult>> BuildHandGestureRecognizerGraph(
|
||||
const HandGestureRecognizerSubgraphOptions& graph_options,
|
||||
absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
|
||||
const HandGestureRecognizerGraphOptions& graph_options,
|
||||
const core::ModelResources& model_resources,
|
||||
Source<ClassificationList> handedness,
|
||||
Source<NormalizedLandmarkList> hand_landmarks,
|
||||
Source<LandmarkList> hand_world_landmarks,
|
||||
Source<std::pair<int, int>> image_size, Graph& graph) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckOptions(graph_options));
|
||||
|
||||
// Converts the ClassificationList to a matrix.
|
||||
auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator");
|
||||
handedness >> handedness_to_matrix.In(kHandednessTag);
|
||||
|
@ -235,12 +223,15 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::SingleHandGestureRecognizerSubgraph);
|
||||
::mediapipe::tasks::vision::gesture_recognizer::SingleHandGestureRecognizerGraph); // NOLINT
|
||||
// clang-format on
|
||||
|
||||
// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs multi
|
||||
// hand gesture recognition. This graph is used as a building block for
|
||||
// mediapipe.tasks.vision.HandGestureRecognizerGraph.
|
||||
// A
|
||||
// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph"
|
||||
// performs multi hand gesture recognition. This graph is used as a building
|
||||
// block for mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph.
|
||||
//
|
||||
// Inputs:
|
||||
// HANDEDNESS - std::vector<ClassificationList>
|
||||
|
@ -263,7 +254,8 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph"
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph"
|
||||
// input_stream: "HANDEDNESS:handedness"
|
||||
// input_stream: "LANDMARKS:landmarks"
|
||||
// input_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||
|
@ -271,7 +263,7 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// input_stream: "HAND_TRACKING_IDS:hand_tracking_ids"
|
||||
// output_stream: "HAND_GESTURES:hand_gestures"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraph.ext]
|
||||
// [mediapipe.tasks.vision.gesture_recognizer.proto.MultipleHandGestureRecognizerGraph.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -281,15 +273,15 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
||||
class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto multi_hand_gestures,
|
||||
BuildMultiHandGestureRecognizerSubraph(
|
||||
sc->Options<HandGestureRecognizerSubgraphOptions>(),
|
||||
BuildMultiGestureRecognizerSubgraph(
|
||||
sc->Options<HandGestureRecognizerGraphOptions>(),
|
||||
graph[Input<std::vector<ClassificationList>>(kHandednessTag)],
|
||||
graph[Input<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
|
||||
graph[Input<std::vector<LandmarkList>>(kWorldLandmarksTag)],
|
||||
|
@ -302,8 +294,8 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
|
||||
private:
|
||||
absl::StatusOr<Source<std::vector<ClassificationResult>>>
|
||||
BuildMultiHandGestureRecognizerSubraph(
|
||||
const HandGestureRecognizerSubgraphOptions& graph_options,
|
||||
BuildMultiGestureRecognizerSubgraph(
|
||||
const HandGestureRecognizerGraphOptions& graph_options,
|
||||
Source<std::vector<ClassificationList>> multi_handedness,
|
||||
Source<std::vector<NormalizedLandmarkList>> multi_hand_landmarks,
|
||||
Source<std::vector<LandmarkList>> multi_hand_world_landmarks,
|
||||
|
@ -341,17 +333,18 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag);
|
||||
auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag);
|
||||
|
||||
auto& hand_gesture_recognizer_subgraph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph");
|
||||
hand_gesture_recognizer_subgraph
|
||||
.GetOptions<HandGestureRecognizerSubgraphOptions>()
|
||||
auto& hand_gesture_recognizer_graph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.gesture_recognizer."
|
||||
"SingleHandGestureRecognizerGraph");
|
||||
hand_gesture_recognizer_graph
|
||||
.GetOptions<HandGestureRecognizerGraphOptions>()
|
||||
.CopyFrom(graph_options);
|
||||
handedness >> hand_gesture_recognizer_subgraph.In(kHandednessTag);
|
||||
hand_landmarks >> hand_gesture_recognizer_subgraph.In(kLandmarksTag);
|
||||
handedness >> hand_gesture_recognizer_graph.In(kHandednessTag);
|
||||
hand_landmarks >> hand_gesture_recognizer_graph.In(kLandmarksTag);
|
||||
hand_world_landmarks >>
|
||||
hand_gesture_recognizer_subgraph.In(kWorldLandmarksTag);
|
||||
image_size_clone >> hand_gesture_recognizer_subgraph.In(kImageSizeTag);
|
||||
auto hand_gestures = hand_gesture_recognizer_subgraph.Out(kHandGesturesTag);
|
||||
hand_gesture_recognizer_graph.In(kWorldLandmarksTag);
|
||||
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
|
||||
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
|
||||
|
||||
auto& end_loop_classification_results =
|
||||
graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
|
||||
|
@ -364,9 +357,12 @@ class HandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::HandGestureRecognizerSubgraph);
|
||||
::mediapipe::tasks::vision::gesture_recognizer::MultipleHandGestureRecognizerGraph); // NOLINT
|
||||
// clang-format on
|
||||
|
||||
} // namespace gesture_recognizer
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace gesture_recognizer {
|
||||
|
||||
namespace {} // namespace
|
||||
|
||||
|
@ -58,6 +59,7 @@ absl::StatusOr<float> GetLeftHandScore(
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace gesture_recognizer
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HANDEDNESS_UTIL_H_
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HANDEDNESS_UTIL_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace gesture_recognizer {
|
||||
|
||||
bool IsLeftHand(const mediapipe::Classification& c);
|
||||
|
||||
|
@ -30,8 +31,9 @@ bool IsRightHand(const mediapipe::Classification& c);
|
|||
absl::StatusOr<float> GetLeftHandScore(
|
||||
const mediapipe::ClassificationList& classification_list);
|
||||
|
||||
} // namespace gesture_recognizer
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_
|
||||
#endif  // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HANDEDNESS_UTIL_H_
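For orientation, below is a minimal, hypothetical C++ sketch of how the helpers declared in handedness_util.h might be exercised; the label string and score are illustrative assumptions, not values taken from this change.

```
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"

namespace {

// Builds a one-entry ClassificationList and queries the handedness helpers.
float LeftHandScoreOrZero() {
  mediapipe::ClassificationList classifications;
  mediapipe::Classification* entry = classifications.add_classification();
  entry->set_label("Left");  // Label string is an assumption for illustration.
  entry->set_score(0.9f);

  // GetLeftHandScore returns absl::StatusOr<float>; fall back to 0 on error.
  absl::StatusOr<float> score =
      mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore(
          classifications);
  return score.ok() ? *score : 0.0f;
}

}  // namespace
```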
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h"
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace gesture_recognizer {
|
||||
namespace {
|
||||
|
||||
TEST(GetLeftHandScore, SingleLeftHandClassification) {
|
||||
|
@ -72,6 +73,7 @@ TEST(GetLeftHandScore, LeftAndRightLowerCaseHandClassification) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gesture_recognizer
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -21,8 +21,18 @@ package(default_visibility = [
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "hand_gesture_recognizer_subgraph_options_proto",
|
||||
srcs = ["hand_gesture_recognizer_subgraph_options.proto"],
|
||||
name = "gesture_embedder_graph_options_proto",
|
||||
srcs = ["gesture_embedder_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "hand_gesture_recognizer_graph_options_proto",
|
||||
srcs = ["hand_gesture_recognizer_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
@ -30,12 +40,3 @@ mediapipe_proto_library(
|
|||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "landmarks_to_matrix_calculator_proto",
|
||||
srcs = ["landmarks_to_matrix_calculator.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.vision.gesture_recognizer.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message GestureEmbedderGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional GestureEmbedderGraphOptions ext = 478825422;
|
||||
}
|
||||
// Base options for configuring the gesture embedder graph, such as specifying
// the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
|
||||
}
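For illustration, a hypothetical node configuration that populates this new options message via `node_options`; the consumer graph name and model file name are placeholders, not part of this change.

```
node {
  calculator: "mediapipe.tasks.vision.gesture_recognizer.GestureEmbedderGraph"  # placeholder consumer
  node_options: {
    [type.googleapis.com/mediapipe.tasks.vision.gesture_recognizer.proto.GestureEmbedderGraphOptions] {
      base_options {
        model_asset {
          file_name: "gesture_embedder.tflite"  # placeholder model asset
        }
      }
    }
  }
}
```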
|
|
@ -15,15 +15,15 @@ limitations under the License.
|
|||
// TODO Refactor naming and class structure of hand related Tasks.
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
||||
package mediapipe.tasks.vision.gesture_recognizer.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message HandGestureRecognizerSubgraphOptions {
|
||||
message HandGestureRecognizerGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional HandGestureRecognizerSubgraphOptions ext = 463370452;
|
||||
optional HandGestureRecognizerGraphOptions ext = 463370452;
|
||||
}
|
||||
// Base options for configuring hand gesture recognition subgraph, such as
|
||||
// specifying the TfLite model file with metadata, accelerator options, etc.
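For reference, a sketch of the renamed extension as it would now appear in a node's options block, mirroring the example in the graph's header comment; the model asset name is a placeholder.

```
options {
  [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext] {
    base_options {
      model_asset {
        file_name: "hand_gesture_recognizer.task"  # placeholder model asset
      }
    }
  }
}
```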
|
|
@ -51,7 +51,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
@ -40,12 +40,13 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_detector {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -53,18 +54,23 @@ using ::mediapipe::api2::Input;
|
|||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions;
|
||||
using ::mediapipe::tasks::vision::hand_detector::proto::
|
||||
HandDetectorGraphOptions;
|
||||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kNormRectsTag[] = "NORM_RECTS";
|
||||
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
|
||||
constexpr char kHandRectsTag[] = "HAND_RECTS";
|
||||
constexpr char kPalmRectsTag[] = "PALM_RECTS";
|
||||
|
||||
struct HandDetectionOuts {
|
||||
Source<std::vector<Detection>> palm_detections;
|
||||
Source<std::vector<NormalizedRect>> hand_rects;
|
||||
Source<std::vector<NormalizedRect>> palm_rects;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
void ConfigureTensorsToDetectionsCalculator(
|
||||
const HandDetectorGraphOptions& tasks_options,
|
||||
mediapipe::TensorsToDetectionsCalculatorOptions* options) {
|
||||
// TODO use metadata to configure these fields.
|
||||
options->set_num_classes(1);
|
||||
|
@ -77,7 +83,7 @@ void ConfigureTensorsToDetectionsCalculator(
|
|||
options->set_sigmoid_score(true);
|
||||
options->set_score_clipping_thresh(100.0);
|
||||
options->set_reverse_output_order(true);
|
||||
options->set_min_score_thresh(0.5);
|
||||
options->set_min_score_thresh(tasks_options.min_detection_confidence());
|
||||
options->set_x_scale(192.0);
|
||||
options->set_y_scale(192.0);
|
||||
options->set_w_scale(192.0);
|
||||
|
@ -134,9 +140,9 @@ void ConfigureRectTransformationCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.HandDetectorGraph" performs hand detection. The
|
||||
// Hand Detection Graph is based on palm detection model, and scale the detected
|
||||
// palm bounding box to enclose the detected whole hand.
|
||||
// A "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" performs hand
|
||||
// detection. The Hand Detection Graph is based on palm detection model, and
|
||||
// scale the detected palm bounding box to enclose the detected whole hand.
|
||||
// Accepts CPU input images and outputs Landmark on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -144,19 +150,27 @@ void ConfigureRectTransformationCalculator(
|
|||
// Image to perform detection on.
|
||||
//
|
||||
// Outputs:
|
||||
// DETECTIONS - std::vector<Detection>
|
||||
// PALM_DETECTIONS - std::vector<Detection>
|
||||
// Detected palms with maximum `num_hands` specified in options.
|
||||
// NORM_RECTS - std::vector<NormalizedRect>
|
||||
// HAND_RECTS - std::vector<NormalizedRect>
|
||||
// Detected hand bounding boxes in normalized coordinates.
|
||||
// PALM_RECTS - std::vector<NormalizedRect>
|
||||
// Detected palm bounding boxes in normalized coordinates.
|
||||
// IMAGE - Image
|
||||
// The input image that the hand detector runs on, with the pixel data
// stored on the target storage (CPU vs GPU).
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.HandDetectorGraph"
|
||||
// calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph"
|
||||
// input_stream: "IMAGE:image"
|
||||
// output_stream: "DETECTIONS:palm_detections"
|
||||
// output_stream: "NORM_RECTS:hand_rects_from_palm_detections"
|
||||
// output_stream: "PALM_DETECTIONS:palm_detections"
|
||||
// output_stream: "HAND_RECTS:hand_rects_from_palm_detections"
|
||||
// output_stream: "PALM_RECTS:palm_rects"
|
||||
// output_stream: "IMAGE:image_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.hand_detector.proto.HandDetectorOptions.ext] {
|
||||
// [mediapipe.tasks.vision.hand_detector.proto.HandDetectorGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "palm_detection.tflite"
|
||||
|
@ -173,16 +187,20 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<HandDetectorOptions>(sc));
|
||||
CreateModelResources<HandDetectorGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(auto hand_detection_outs,
|
||||
BuildHandDetectionSubgraph(
|
||||
sc->Options<HandDetectorOptions>(), *model_resources,
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_detection_outs,
|
||||
BuildHandDetectionSubgraph(sc->Options<HandDetectorGraphOptions>(),
|
||||
*model_resources,
|
||||
graph[Input<Image>(kImageTag)], graph));
|
||||
hand_detection_outs.palm_detections >>
|
||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
||||
hand_detection_outs.hand_rects >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kNormRectsTag)];
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandRectsTag)];
|
||||
hand_detection_outs.palm_rects >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
|
||||
hand_detection_outs.image >> graph[Output<Image>(kImageTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -196,7 +214,7 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
// image_in: image stream to run hand detection on.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<HandDetectionOuts> BuildHandDetectionSubgraph(
|
||||
const HandDetectorOptions& subgraph_options,
|
||||
const HandDetectorGraphOptions& subgraph_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Graph& graph) {
|
||||
// Add image preprocessing subgraph. The model expects aspect ratio
|
||||
|
@ -235,6 +253,7 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
auto& tensors_to_detections =
|
||||
graph.AddNode("TensorsToDetectionsCalculator");
|
||||
ConfigureTensorsToDetectionsCalculator(
|
||||
subgraph_options,
|
||||
&tensors_to_detections
|
||||
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
|
||||
model_output_tensors >> tensors_to_detections.In("TENSORS");
|
||||
|
@ -281,7 +300,8 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
.GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
|
||||
palm_detections >> detections_to_rects.In("DETECTIONS");
|
||||
image_size >> detections_to_rects.In("IMAGE_SIZE");
|
||||
auto palm_rects = detections_to_rects.Out("NORM_RECTS");
|
||||
auto palm_rects =
|
||||
detections_to_rects[Output<std::vector<NormalizedRect>>("NORM_RECTS")];
|
||||
|
||||
// Expands and shifts the rectangle that contains the palm so that it's
|
||||
// likely to cover the entire hand.
|
||||
|
@ -308,13 +328,18 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
clip_normalized_rect_vector_size[Output<std::vector<NormalizedRect>>(
|
||||
"")];
|
||||
|
||||
return HandDetectionOuts{.palm_detections = palm_detections,
|
||||
.hand_rects = clipped_hand_rects};
|
||||
return HandDetectionOuts{
|
||||
/* palm_detections= */ palm_detections,
|
||||
/* hand_rects= */ clipped_hand_rects,
|
||||
/* palm_rects= */ palm_rects,
|
||||
/* image= */ preprocessing[Output<Image>(kImageTag)]};
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandDetectorGraph);
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::hand_detector::HandDetectorGraph);
|
||||
|
||||
} // namespace hand_detector
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -40,13 +40,14 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_detector {
|
||||
namespace {
|
||||
|
||||
using ::file::Defaults;
|
||||
|
@ -60,7 +61,8 @@ using ::mediapipe::tasks::core::ModelResources;
|
|||
using ::mediapipe::tasks::core::TaskRunner;
|
||||
using ::mediapipe::tasks::core::proto::ExternalFile;
|
||||
using ::mediapipe::tasks::vision::DecodeImageFromFile;
|
||||
using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions;
|
||||
using ::mediapipe::tasks::vision::hand_detector::proto::
|
||||
HandDetectorGraphOptions;
|
||||
using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorResult;
|
||||
using ::testing::EqualsProto;
|
||||
using ::testing::TestParamInfo;
|
||||
|
@ -80,9 +82,9 @@ constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt";
|
|||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kImageName[] = "image";
|
||||
constexpr char kPalmDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
|
||||
constexpr char kPalmDetectionsName[] = "palm_detections";
|
||||
constexpr char kHandNormRectsTag[] = "NORM_RECTS";
|
||||
constexpr char kHandRectsTag[] = "HAND_RECTS";
|
||||
constexpr char kHandNormRectsName[] = "hand_norm_rects";
|
||||
|
||||
constexpr float kPalmDetectionBboxMaxDiff = 0.01;
|
||||
|
@ -104,22 +106,22 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
|||
Graph graph;
|
||||
|
||||
auto& hand_detection =
|
||||
graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph");
|
||||
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
|
||||
|
||||
auto options = std::make_unique<HandDetectorOptions>();
|
||||
auto options = std::make_unique<HandDetectorGraphOptions>();
|
||||
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||
JoinPath("./", kTestDataDirectory, model_name));
|
||||
options->set_min_detection_confidence(0.5);
|
||||
options->set_num_hands(num_hands);
|
||||
hand_detection.GetOptions<HandDetectorOptions>().Swap(options.get());
|
||||
hand_detection.GetOptions<HandDetectorGraphOptions>().Swap(options.get());
|
||||
|
||||
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
||||
hand_detection.In(kImageTag);
|
||||
|
||||
hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >>
|
||||
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
||||
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)];
|
||||
hand_detection.Out(kHandRectsTag).SetName(kHandNormRectsName) >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandRectsTag)];
|
||||
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
|
@ -200,6 +202,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
});
|
||||
|
||||
} // namespace
|
||||
} // namespace hand_detector
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -21,8 +21,8 @@ package(default_visibility = [
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "hand_detector_options_proto",
|
||||
srcs = ["hand_detector_options.proto"],
|
||||
name = "hand_detector_graph_options_proto",
|
||||
srcs = ["hand_detector_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
|
|
@ -21,24 +21,20 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.handdetector";
|
||||
option java_outer_classname = "HandDetectorOptionsProto";
|
||||
option java_outer_classname = "HandDetectorGraphOptionsProto";
|
||||
|
||||
message HandDetectorOptions {
|
||||
message HandDetectorGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional HandDetectorOptions ext = 464864288;
|
||||
optional HandDetectorGraphOptions ext = 464864288;
|
||||
}
|
||||
// Base options for configuring Task library, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
||||
optional core.proto.BaseOptions base_options = 1;
|
||||
|
||||
// The locale to use for display names specified through the TFLite Model
|
||||
// Metadata, if any. Defaults to English.
|
||||
optional string display_names_locale = 2 [default = "en"];
|
||||
|
||||
// Minimum confidence value ([0.0, 1.0]) for a detection to be considered a
// successfully detected hand in the image.
|
||||
optional float min_detection_confidence = 3 [default = 0.5];
|
||||
optional float min_detection_confidence = 2 [default = 0.5];
|
||||
|
||||
// The maximum number of hands output by the detector.
|
||||
optional int32 num_hands = 4;
|
||||
optional int32 num_hands = 3;
|
||||
}
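To make the field renumbering concrete, here is a sketch of the renamed options as they might appear in a graph config after this change; the values mirror those used in the tests in this change and are illustrative only.

```
options {
  [mediapipe.tasks.vision.hand_detector.proto.HandDetectorGraphOptions.ext] {
    base_options {
      model_asset {
        file_name: "palm_detection_full.tflite"
      }
    }
    min_detection_confidence: 0.5  # renumbered to field 2 in this change
    num_hands: 2                   # renumbered to field 3 in this change
  }
}
```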
|
|
@ -19,10 +19,10 @@ package(default_visibility = [
|
|||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "hand_landmarker_subgraph",
|
||||
srcs = ["hand_landmarker_subgraph.cc"],
|
||||
name = "hand_landmarks_detector_graph",
|
||||
srcs = ["hand_landmarks_detector_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
|
@ -51,6 +51,7 @@ cc_library(
|
|||
# TODO: move calculators in modules/hand_landmark/calculators to tasks dir.
|
||||
"//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/utils:gate",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
|
@ -66,3 +67,41 @@ cc_library(
|
|||
)
|
||||
|
||||
# TODO: Enable this test
|
||||
|
||||
cc_library(
|
||||
name = "hand_landmarker_graph",
|
||||
srcs = ["hand_landmarker_graph.cc"],
|
||||
deps = [
|
||||
":hand_landmarks_detector_graph",
|
||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
|
||||
"//mediapipe/calculators/core:end_loop_calculator",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:gate_calculator_cc_proto",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/calculators/core:previous_loopback_calculator",
|
||||
"//mediapipe/calculators/util:collection_has_min_size_calculator",
|
||||
"//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/utils:gate",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO: Enable this test
@ -0,0 +1,286 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
|
||||
#include "mediapipe/calculators/core/gate_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/gate.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::utils::DisallowIf;
|
||||
using ::mediapipe::tasks::vision::hand_detector::proto::
|
||||
HandDetectorGraphOptions;
|
||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||
HandLandmarkerGraphOptions;
|
||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||
HandLandmarksDetectorGraphOptions;
|
||||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
|
||||
constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
|
||||
constexpr char kHandednessTag[] = "HANDEDNESS";
|
||||
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
|
||||
constexpr char kPalmRectsTag[] = "PALM_RECTS";
|
||||
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
|
||||
|
||||
struct HandLandmarkerOutputs {
|
||||
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
|
||||
Source<std::vector<LandmarkList>> world_landmark_lists;
|
||||
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
|
||||
Source<std::vector<ClassificationList>> handednesses;
|
||||
Source<std::vector<NormalizedRect>> palm_rects;
|
||||
Source<std::vector<Detection>> palm_detections;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
|
||||
// landmarks detection. The HandLandmarkerGraph consists of two subgraphs:
|
||||
// HandDetectorGraph and MultipleHandLandmarksDetectorGraph.
|
||||
// MultipleHandLandmarksDetectorGraph detects landmarks from bounding boxes
|
||||
// produced by HandDetectorGraph. HandLandmarkerGraph tracks the landmarks over
|
||||
// time, and skips the HandDetectorGraph. If tracking is lost or the number of
// detected hands is less than the configured maximum number of hands,
// HandDetectorGraph is triggered to detect hands.
|
||||
//
|
||||
// Accepts CPU input images and outputs Landmarks on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform hand landmarks detection on.
|
||||
//
|
||||
// Outputs:
|
||||
// LANDMARKS: - std::vector<NormalizedLandmarkList>
|
||||
// Vector of detected hand landmarks.
|
||||
// WORLD_LANDMARKS - std::vector<LandmarkList>
|
||||
// Vector of detected hand landmarks in world coordinates.
|
||||
// HAND_RECT_NEXT_FRAME - std::vector<NormalizedRect>
|
||||
// Vector of the predicted rects enclosing the same hand RoI for landmark
|
||||
// detection on the next frame.
|
||||
// HANDEDNESS - std::vector<ClassificationList>
|
||||
// Vector of classification of handedness.
|
||||
// PALM_RECTS - std::vector<NormalizedRect>
|
||||
// Detected palm bounding boxes in normalized coordinates.
|
||||
// PALM_DETECTIONS - std::vector<Detection>
|
||||
// Detected palms with maximum `num_hands` specified in options.
|
||||
// IMAGE - Image
|
||||
// The input image that the hand landmarker runs on and has the pixel data
|
||||
// stored on the target storage (CPU vs GPU).
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"
|
||||
// input_stream: "IMAGE:image_in"
|
||||
// output_stream: "LANDMARKS:hand_landmarks"
|
||||
// output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
|
||||
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
|
||||
// output_stream: "HANDEDNESS:handedness"
|
||||
// output_stream: "PALM_RECTS:palm_rects"
|
||||
// output_stream: "PALM_DETECTIONS:palm_detections"
|
||||
// output_stream: "IMAGE:image_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.hand_landmarker.proto.HandLandmarkerGraphOptions.ext] {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "hand_landmarker.task"
|
||||
// }
|
||||
// }
|
||||
// hand_detector_graph_options {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "palm_detection.tflite"
|
||||
// }
|
||||
// }
|
||||
// min_detection_confidence: 0.5
|
||||
// num_hands: 2
|
||||
// }
|
||||
// hand_landmarks_detector_graph_options {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "hand_landmark_lite.tflite"
|
||||
// }
|
||||
// }
|
||||
// min_detection_confidence: 0.5
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class HandLandmarkerGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_landmarker_outputs,
|
||||
BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)], graph));
|
||||
hand_landmarker_outputs.landmark_lists >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
||||
hand_landmarker_outputs.world_landmark_lists >>
|
||||
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
|
||||
hand_landmarker_outputs.hand_rects_next_frame >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandRectNextFrameTag)];
|
||||
hand_landmarker_outputs.handednesses >>
|
||||
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
|
||||
hand_landmarker_outputs.palm_rects >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
|
||||
hand_landmarker_outputs.palm_detections >>
|
||||
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
||||
hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];
|
||||
|
||||
// TODO remove when support is fixed.
|
||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
||||
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
|
||||
CalculatorGraphConfig config = graph.GetConfig();
|
||||
for (int i = 0; i < config.node_size(); ++i) {
|
||||
if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
|
||||
auto* info = config.mutable_node(i)->add_input_stream_info();
|
||||
info->set_tag_index("LOOP");
|
||||
info->set_back_edge(true);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
private:
|
||||
// Adds a mediapipe hand landmark detection graph into the provided
|
||||
// builder::Graph instance.
|
||||
//
|
||||
// tasks_options: the mediapipe tasks module HandLandmarkerGraphOptions.
|
||||
// image_in: (mediapipe::Image) stream to run hand landmark detection on.
|
||||
// graph: the mediapipe graph instance to be updated.
|
||||
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerGraph(
|
||||
const HandLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
||||
Graph& graph) {
|
||||
const int max_num_hands =
|
||||
tasks_options.hand_detector_graph_options().num_hands();
|
||||
|
||||
auto& previous_loopback = graph.AddNode(kPreviousLoopbackCalculatorName);
|
||||
image_in >> previous_loopback.In("MAIN");
|
||||
auto prev_hand_rects_from_landmarks =
|
||||
previous_loopback[Output<std::vector<NormalizedRect>>("PREV_LOOP")];
|
||||
|
||||
auto& min_size_node =
|
||||
graph.AddNode("NormalizedRectVectorHasMinSizeCalculator");
|
||||
prev_hand_rects_from_landmarks >> min_size_node.In("ITERABLE");
|
||||
min_size_node.GetOptions<CollectionHasMinSizeCalculatorOptions>()
|
||||
.set_min_size(max_num_hands);
|
||||
auto has_enough_hands = min_size_node.Out("").Cast<bool>();
|
||||
|
||||
auto image_for_hand_detector =
|
||||
DisallowIf(image_in, has_enough_hands, graph);
|
||||
|
||||
auto& hand_detector =
|
||||
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
|
||||
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
|
||||
tasks_options.hand_detector_graph_options());
|
||||
image_for_hand_detector >> hand_detector.In("IMAGE");
|
||||
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
||||
|
||||
auto& hand_association = graph.AddNode("HandAssociationCalculator");
|
||||
hand_association.GetOptions<HandAssociationCalculatorOptions>()
|
||||
.set_min_similarity_threshold(tasks_options.min_tracking_confidence());
|
||||
prev_hand_rects_from_landmarks >>
|
||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
|
||||
hand_rects_from_hand_detector >>
|
||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
|
||||
auto hand_rects = hand_association.Out("");
|
||||
|
||||
auto& clip_hand_rects =
|
||||
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
|
||||
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
|
||||
.set_max_vec_size(max_num_hands);
|
||||
hand_rects >> clip_hand_rects.In("");
|
||||
auto clipped_hand_rects = clip_hand_rects.Out("");
|
||||
|
||||
auto& hand_landmarks_detector_graph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.hand_landmarker."
|
||||
"MultipleHandLandmarksDetectorGraph");
|
||||
hand_landmarks_detector_graph
|
||||
.GetOptions<HandLandmarksDetectorGraphOptions>()
|
||||
.CopyFrom(tasks_options.hand_landmarks_detector_graph_options());
|
||||
image_in >> hand_landmarks_detector_graph.In("IMAGE");
|
||||
clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT");
|
||||
|
||||
auto hand_rects_for_next_frame =
|
||||
hand_landmarks_detector_graph[Output<std::vector<NormalizedRect>>(
|
||||
kHandRectNextFrameTag)];
|
||||
// Back edge.
|
||||
hand_rects_for_next_frame >> previous_loopback.In("LOOP");
|
||||
|
||||
// TODO: Replace PassThroughCalculator with a calculator that
|
||||
// converts the pixel data to be stored on the target storage (CPU vs GPU).
|
||||
auto& pass_through = graph.AddNode("PassThroughCalculator");
|
||||
image_in >> pass_through.In("");
|
||||
|
||||
return {{
|
||||
/* landmark_lists= */ hand_landmarks_detector_graph
|
||||
[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
|
||||
/* world_landmark_lists= */
|
||||
hand_landmarks_detector_graph[Output<std::vector<LandmarkList>>(
|
||||
kWorldLandmarksTag)],
|
||||
/* hand_rects_next_frame= */ hand_rects_for_next_frame,
|
||||
hand_landmarks_detector_graph[Output<std::vector<ClassificationList>>(
|
||||
kHandednessTag)],
|
||||
/* palm_rects= */
|
||||
hand_detector[Output<std::vector<NormalizedRect>>(kPalmRectsTag)],
|
||||
/* palm_detections */
|
||||
hand_detector[Output<std::vector<Detection>>(kPalmDetectionsTag)],
|
||||
/* image */
|
||||
pass_through[Output<Image>("")],
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerGraph);
|
||||
|
||||
} // namespace hand_landmarker
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,167 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::file::Defaults;
|
||||
using ::file::GetTextProto;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::core::TaskRunner;
|
||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||
HandLandmarkerGraphOptions;
|
||||
using ::testing::EqualsProto;
|
||||
using ::testing::proto::Approximately;
|
||||
using ::testing::proto::Partially;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
|
||||
constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite";
|
||||
constexpr char kLeftHandsImage[] = "left_hands.jpg";
|
||||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kImageName[] = "image_in";
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
constexpr char kLandmarksName[] = "landmarks";
|
||||
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
|
||||
constexpr char kWorldLandmarksName[] = "world_landmarks";
|
||||
constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
|
||||
constexpr char kHandRectNextFrameName[] = "hand_rect_next_frame";
|
||||
constexpr char kHandednessTag[] = "HANDEDNESS";
|
||||
constexpr char kHandednessName[] = "handedness";
|
||||
|
||||
// Expected hand landmarks positions, in text proto format.
|
||||
constexpr char kExpectedLeftUpHandLandmarksFilename[] =
|
||||
"expected_left_up_hand_landmarks.prototxt";
|
||||
constexpr char kExpectedLeftDownHandLandmarksFilename[] =
|
||||
"expected_left_down_hand_landmarks.prototxt";
|
||||
|
||||
constexpr float kFullModelFractionDiff = 0.03; // percentage
|
||||
constexpr float kAbsMargin = 0.03;
|
||||
constexpr int kMaxNumHands = 2;
|
||||
constexpr float kMinTrackingConfidence = 0.5;
|
||||
|
||||
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
||||
NormalizedLandmarkList expected_landmark_list;
|
||||
MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename),
|
||||
&expected_landmark_list, Defaults()));
|
||||
return expected_landmark_list;
|
||||
}
|
||||
|
||||
// Helper function to create a Hand Landmarker TaskRunner.
|
||||
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
||||
Graph graph;
|
||||
auto& hand_landmarker_graph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
|
||||
auto& options =
|
||||
hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
|
||||
options.mutable_hand_detector_graph_options()
|
||||
->mutable_base_options()
|
||||
->mutable_model_asset()
|
||||
->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel));
|
||||
options.mutable_hand_detector_graph_options()->mutable_base_options();
|
||||
options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
|
||||
options.mutable_hand_landmarks_detector_graph_options()
|
||||
->mutable_base_options()
|
||||
->mutable_model_asset()
|
||||
->set_file_name(
|
||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel));
|
||||
options.set_min_tracking_confidence(kMinTrackingConfidence);
|
||||
|
||||
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
||||
hand_landmarker_graph.In(kImageTag);
|
||||
hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
||||
hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >>
|
||||
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
|
||||
hand_landmarker_graph.Out(kHandednessTag).SetName(kHandednessName) >>
|
||||
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
|
||||
hand_landmarker_graph.Out(kHandRectNextFrameTag)
|
||||
.SetName(kHandRectNextFrameName) >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandRectNextFrameTag)];
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
class HandLandmarkerTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(HandLandmarkerTest, Succeeds) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
|
||||
auto output_packets =
|
||||
task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
|
||||
const auto& landmarks = (*output_packets)[kLandmarksName]
|
||||
.Get<std::vector<NormalizedLandmarkList>>();
|
||||
ASSERT_EQ(landmarks.size(), kMaxNumHands);
|
||||
std::vector<NormalizedLandmarkList> expected_landmarks = {
|
||||
GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename),
|
||||
GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename)};
|
||||
|
||||
EXPECT_THAT(landmarks[0],
|
||||
Approximately(Partially(EqualsProto(expected_landmarks[0])),
|
||||
/*margin=*/kAbsMargin,
|
||||
/*fraction=*/kFullModelFractionDiff));
|
||||
EXPECT_THAT(landmarks[1],
|
||||
Approximately(Partially(EqualsProto(expected_landmarks[1])),
|
||||
/*margin=*/kAbsMargin,
|
||||
/*fraction=*/kFullModelFractionDiff));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace hand_landmarker
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -34,12 +34,13 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/gate.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
@ -48,6 +49,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -55,9 +57,10 @@ using ::mediapipe::api2::Input;
|
|||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::utils::AllowIf;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||
HandLandmarkerSubgraphOptions;
|
||||
HandLandmarksDetectorGraphOptions;
|
||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
|
@ -82,7 +85,6 @@ struct SingleHandLandmarkerOutputs {
|
|||
Source<bool> hand_presence;
|
||||
Source<float> hand_presence_score;
|
||||
Source<ClassificationList> handedness;
|
||||
Source<std::pair<int, int>> image_size;
|
||||
};
|
||||
|
||||
struct HandLandmarkerOutputs {
|
||||
|
@ -92,10 +94,10 @@ struct HandLandmarkerOutputs {
|
|||
Source<std::vector<bool>> presences;
|
||||
Source<std::vector<float>> presence_scores;
|
||||
Source<std::vector<ClassificationList>> handednesses;
|
||||
Source<std::pair<int, int>> image_size;
|
||||
};
|
||||
|
||||
absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) {
|
||||
absl::Status SanityCheckOptions(
|
||||
const HandLandmarksDetectorGraphOptions& options) {
|
||||
if (options.min_detection_confidence() < 0 ||
|
||||
options.min_detection_confidence() > 1) {
|
||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||
|
@ -182,8 +184,8 @@ void ConfigureHandRectTransformationCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" performs hand
|
||||
// landmark detection.
|
||||
// A "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph"
|
||||
// performs hand landmarks detection.
|
||||
// - Accepts CPU input images and outputs Landmark on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -208,12 +210,11 @@ void ConfigureHandRectTransformationCalculator(
|
|||
// Float value indicates the probability that the hand is present.
|
||||
// HANDEDNESS - ClassificationList
|
||||
// Classification of handedness.
|
||||
// IMAGE_SIZE - std::vector<int, int>
|
||||
// The size of input image.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph"
|
||||
// input_stream: "IMAGE:input_image"
|
||||
// input_stream: "HAND_RECT:hand_rect"
|
||||
// output_stream: "LANDMARKS:hand_landmarks"
|
||||
|
@ -221,10 +222,8 @@ void ConfigureHandRectTransformationCalculator(
|
|||
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
|
||||
// output_stream: "PRESENCE:hand_presence"
|
||||
// output_stream: "PRESENCE_SCORE:hand_presence_score"
|
||||
// output_stream: "HANDEDNESS:handedness"
|
||||
// output_stream: "IMAGE_SIZE:image_size"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext]
|
||||
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -235,16 +234,17 @@ void ConfigureHandRectTransformationCalculator(
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
||||
class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<HandLandmarkerSubgraphOptions>(sc));
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<HandLandmarksDetectorGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(auto hand_landmark_detection_outs,
|
||||
BuildSingleHandLandmarkerSubgraph(
|
||||
sc->Options<HandLandmarkerSubgraphOptions>(),
|
||||
BuildSingleHandLandmarksDetectorGraph(
|
||||
sc->Options<HandLandmarksDetectorGraphOptions>(),
|
||||
*model_resources, graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kHandRectTag)], graph));
|
||||
hand_landmark_detection_outs.hand_landmarks >>
|
||||
|
@ -259,8 +259,6 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
graph[Output<float>(kPresenceScoreTag)];
|
||||
hand_landmark_detection_outs.handedness >>
|
||||
graph[Output<ClassificationList>(kHandednessTag)];
|
||||
hand_landmark_detection_outs.image_size >>
|
||||
graph[Output<std::pair<int, int>>(kImageSizeTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
@ -269,14 +267,16 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
// Adds a mediapipe hand landmark detection graph into the provided
|
||||
// builder::Graph instance.
|
||||
//
|
||||
// subgraph_options: the mediapipe tasks module HandLandmarkerSubgraphOptions.
|
||||
// model_resources: the ModelSources object initialized from a hand landmark
|
||||
// subgraph_options: the mediapipe tasks module
|
||||
// HandLandmarksDetectorGraphOptions. model_resources: the ModelSources object
|
||||
// initialized from a hand landmark
|
||||
// detection model file with model metadata.
|
||||
// image_in: (mediapipe::Image) stream to run hand landmark detection on.
|
||||
// rect: (NormalizedRect) stream to run on the RoI of image.
|
||||
// graph: the mediapipe graph instance to be updated.
|
||||
absl::StatusOr<SingleHandLandmarkerOutputs> BuildSingleHandLandmarkerSubgraph(
|
||||
const HandLandmarkerSubgraphOptions& subgraph_options,
|
||||
absl::StatusOr<SingleHandLandmarkerOutputs>
|
||||
BuildSingleHandLandmarksDetectorGraph(
|
||||
const HandLandmarksDetectorGraphOptions& subgraph_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> hand_rect, Graph& graph) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
|
||||
|
@ -332,18 +332,7 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
// score of hand presence.
|
||||
auto& tensors_to_hand_presence = graph.AddNode("TensorsToFloatsCalculator");
|
||||
hand_flag_tensors >> tensors_to_hand_presence.In("TENSORS");
|
||||
|
||||
// Converts the handedness tensor into a float that represents the
|
||||
// classification score of handedness.
|
||||
auto& tensors_to_handedness =
|
||||
graph.AddNode("TensorsToClassificationCalculator");
|
||||
ConfigureTensorsToHandednessCalculator(
|
||||
&tensors_to_handedness.GetOptions<
|
||||
mediapipe::TensorsToClassificationCalculatorOptions>());
|
||||
handedness_tensors >> tensors_to_handedness.In("TENSORS");
|
||||
auto hand_presence_score = tensors_to_hand_presence[Output<float>("FLOAT")];
|
||||
auto handedness =
|
||||
tensors_to_handedness[Output<ClassificationList>("CLASSIFICATIONS")];
|
||||
|
||||
// Applies a threshold to the confidence score to determine whether a
|
||||
// hand is present.
|
||||
|
@ -354,6 +343,18 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
hand_presence_score >> hand_presence_thresholding.In("FLOAT");
|
||||
auto hand_presence = hand_presence_thresholding[Output<bool>("FLAG")];
|
||||
|
||||
// Converts the handedness tensor into a float that represents the
|
||||
// classification score of handedness.
|
||||
auto& tensors_to_handedness =
|
||||
graph.AddNode("TensorsToClassificationCalculator");
|
||||
ConfigureTensorsToHandednessCalculator(
|
||||
&tensors_to_handedness.GetOptions<
|
||||
mediapipe::TensorsToClassificationCalculatorOptions>());
|
||||
handedness_tensors >> tensors_to_handedness.In("TENSORS");
|
||||
auto handedness = AllowIf(
|
||||
tensors_to_handedness[Output<ClassificationList>("CLASSIFICATIONS")],
|
||||
hand_presence, graph);
|
||||
|
||||
// Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed
|
||||
// hand image (after image transformation with the FIT scale mode) to the
|
||||
// corresponding locations on the same image with the letterbox removed
|
||||
|
@ -371,8 +372,9 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
landmark_letterbox_removal.Out("LANDMARKS") >>
|
||||
landmark_projection.In("NORM_LANDMARKS");
|
||||
hand_rect >> landmark_projection.In("NORM_RECT");
|
||||
auto projected_landmarks =
|
||||
landmark_projection[Output<NormalizedLandmarkList>("NORM_LANDMARKS")];
|
||||
auto projected_landmarks = AllowIf(
|
||||
landmark_projection[Output<NormalizedLandmarkList>("NORM_LANDMARKS")],
|
||||
hand_presence, graph);
|
||||
|
||||
// Projects the world landmarks from the cropped hand image to the
|
||||
// corresponding locations on the full image before cropping (input to the
|
||||
|
@ -383,7 +385,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
world_landmark_projection.In("LANDMARKS");
|
||||
hand_rect >> world_landmark_projection.In("NORM_RECT");
|
||||
auto projected_world_landmarks =
|
||||
world_landmark_projection[Output<LandmarkList>("LANDMARKS")];
|
||||
AllowIf(world_landmark_projection[Output<LandmarkList>("LANDMARKS")],
|
||||
hand_presence, graph);
|
||||
|
||||
// Converts the hand landmarks into a rectangle (normalized by image size)
|
||||
// that encloses the hand.
|
||||
|
@ -403,7 +406,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
hand_landmarks_to_rect.Out("NORM_RECT") >>
|
||||
hand_rect_transformation.In("NORM_RECT");
|
||||
auto hand_rect_next_frame =
|
||||
hand_rect_transformation[Output<NormalizedRect>("")];
|
||||
AllowIf(hand_rect_transformation[Output<NormalizedRect>("")],
|
||||
hand_presence, graph);
|
||||
|
||||
return {{
|
||||
/* hand_landmarks= */ projected_landmarks,
|
||||
|
@ -412,16 +416,17 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph {
        /* hand_presence= */ hand_presence,
        /* hand_presence_score= */ hand_presence_score,
        /* handedness= */ handedness,
        /* image_size= */ image_size,
    }};
  }
};

// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::SingleHandLandmarkerSubgraph);
    ::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarksDetectorGraph);  // NOLINT
// clang-format on
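For orientation, here is a minimal sketch of how a client graph can add the renamed single-hand graph with the api2 builder. It mirrors `CreateSingleHandTaskRunner` in the test further below; the wrapper function name and the model file name are illustrative placeholders, not part of this change, and the snippet assumes the includes and using-declarations at the top of this file.

```
// Sketch only: assumes Graph, Input, Output, Image, NormalizedRect,
// NormalizedLandmarkList and HandLandmarksDetectorGraphOptions are in scope.
CalculatorGraphConfig BuildSingleHandClientGraph() {
  Graph graph;
  auto& detector = graph.AddNode(
      "mediapipe.tasks.vision.hand_landmarker."
      "SingleHandLandmarksDetectorGraph");
  auto& options = detector.GetOptions<HandLandmarksDetectorGraphOptions>();
  options.mutable_base_options()->mutable_model_asset()->set_file_name(
      "hand_landmark_full.tflite");  // placeholder model path

  graph[Input<Image>("IMAGE")] >> detector.In("IMAGE");
  graph[Input<NormalizedRect>("HAND_RECT")] >> detector.In("HAND_RECT");
  detector[Output<NormalizedLandmarkList>("LANDMARKS")] >>
      graph[Output<NormalizedLandmarkList>("LANDMARKS")];
  detector[Output<bool>("PRESENCE")] >> graph[Output<bool>("PRESENCE")];
  return graph.GetConfig();
}
```

Because LANDMARKS and HANDEDNESS are now gated with `AllowIf` on the presence flag, a consumer of this graph should expect no packet on those streams for frames where no hand is detected.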
// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi
|
||||
// hand landmark detection.
|
||||
// A "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph"
|
||||
// performs multi hand landmark detection.
|
||||
// - Accepts CPU input image and a vector of hand rect RoIs to detect the
|
||||
// multiple hands landmarks enclosed by the RoIs. Output vectors of
|
||||
// hand landmarks related results, where each element in the vectors
|
||||
|
@ -449,12 +454,11 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// Vector of float value indicates the probability that the hand is present.
|
||||
// HANDEDNESS - std::vector<ClassificationList>
|
||||
// Vector of classification of handedness.
|
||||
// IMAGE_SIZE - std::vector<int, int>
|
||||
// The size of input image.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.HandLandmarkerSubgraph"
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph"
|
||||
// input_stream: "IMAGE:input_image"
|
||||
// input_stream: "HAND_RECT:hand_rect"
|
||||
// output_stream: "LANDMARKS:hand_landmarks"
|
||||
|
@ -463,9 +467,8 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// output_stream: "PRESENCE:hand_presence"
|
||||
// output_stream: "PRESENCE_SCORE:hand_presence_score"
|
||||
// output_stream: "HANDEDNESS:handedness"
|
||||
// output_stream: "IMAGE_SIZE:image_size"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext]
|
||||
// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -476,15 +479,15 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
class HandLandmarkerSubgraph : public core::ModelTaskGraph {
|
||||
class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_landmark_detection_outputs,
|
||||
BuildHandLandmarkerSubgraph(
|
||||
sc->Options<HandLandmarkerSubgraphOptions>(),
|
||||
BuildHandLandmarksDetectorGraph(
|
||||
sc->Options<HandLandmarksDetectorGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<std::vector<NormalizedRect>>(kHandRectTag)], graph));
|
||||
hand_landmark_detection_outputs.landmark_lists >>
|
||||
|
@ -499,21 +502,20 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
graph[Output<std::vector<float>>(kPresenceScoreTag)];
|
||||
hand_landmark_detection_outputs.handednesses >>
|
||||
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
|
||||
hand_landmark_detection_outputs.image_size >>
|
||||
graph[Output<std::pair<int, int>>(kImageSizeTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerSubgraph(
|
||||
const HandLandmarkerSubgraphOptions& subgraph_options,
|
||||
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarksDetectorGraph(
|
||||
const HandLandmarksDetectorGraphOptions& subgraph_options,
|
||||
Source<Image> image_in,
|
||||
Source<std::vector<NormalizedRect>> multi_hand_rects, Graph& graph) {
|
||||
auto& hand_landmark_subgraph =
|
||||
graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph");
|
||||
hand_landmark_subgraph.GetOptions<HandLandmarkerSubgraphOptions>().CopyFrom(
|
||||
subgraph_options);
|
||||
auto& hand_landmark_subgraph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.hand_landmarker."
|
||||
"SingleHandLandmarksDetectorGraph");
|
||||
hand_landmark_subgraph.GetOptions<HandLandmarksDetectorGraphOptions>()
|
||||
.CopyFrom(subgraph_options);
|
||||
|
||||
auto& begin_loop_multi_hand_rects =
|
||||
graph.AddNode("BeginLoopNormalizedRectCalculator");
|
||||
|
@ -533,8 +535,6 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
hand_landmark_subgraph.Out("HAND_RECT_NEXT_FRAME");
|
||||
auto landmarks = hand_landmark_subgraph.Out("LANDMARKS");
|
||||
auto world_landmarks = hand_landmark_subgraph.Out("WORLD_LANDMARKS");
|
||||
auto image_size =
|
||||
hand_landmark_subgraph[Output<std::pair<int, int>>("IMAGE_SIZE")];
|
||||
|
||||
auto& end_loop_handedness =
|
||||
graph.AddNode("EndLoopClassificationListCalculator");
|
||||
|
@ -585,13 +585,16 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph {
|
|||
/* presences= */ presences,
|
||||
/* presence_scores= */ presence_scores,
|
||||
/* handednesses= */ handednesses,
|
||||
/* image_size= */ image_size,
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkerSubgraph);
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::hand_landmarker::MultipleHandLandmarksDetectorGraph);  // NOLINT
// clang-format on

}  // namespace hand_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
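A parallel sketch for the multi-hand wrapper, mirroring `CreateMultiHandTaskRunner` in the test below; again the function name and model path are placeholders, and only two of the vector-valued outputs are wired for brevity.

```
// Sketch only: same assumptions as the single-hand example above.
CalculatorGraphConfig BuildMultiHandClientGraph() {
  Graph graph;
  auto& detector = graph.AddNode(
      "mediapipe.tasks.vision.hand_landmarker."
      "MultipleHandLandmarksDetectorGraph");
  auto options = std::make_unique<HandLandmarksDetectorGraphOptions>();
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      "hand_landmark_full.tflite");  // placeholder model path
  detector.GetOptions<HandLandmarksDetectorGraphOptions>().Swap(options.get());

  graph[Input<Image>("IMAGE")] >> detector.In("IMAGE");
  graph[Input<std::vector<NormalizedRect>>("HAND_RECT")] >>
      detector.In("HAND_RECT");
  // Every output is a vector with one entry per input hand RoI.
  detector[Output<std::vector<ClassificationList>>("HANDEDNESS")] >>
      graph[Output<std::vector<ClassificationList>>("HANDEDNESS")];
  detector[Output<std::vector<bool>>("PRESENCE")] >>
      graph[Output<std::vector<bool>>("PRESENCE")];
  return graph.GetConfig();
}
```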
@ -39,12 +39,13 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace hand_landmarker {
|
||||
namespace {
|
||||
|
||||
using ::file::Defaults;
|
||||
|
@ -57,7 +58,7 @@ using ::mediapipe::file::JoinPath;
|
|||
using ::mediapipe::tasks::core::TaskRunner;
|
||||
using ::mediapipe::tasks::vision::DecodeImageFromFile;
|
||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||
HandLandmarkerSubgraphOptions;
|
||||
HandLandmarksDetectorGraphOptions;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::EqualsProto;
|
||||
using ::testing::Pointwise;
|
||||
|
@ -112,13 +113,14 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner(
|
|||
absl::string_view model_name) {
|
||||
Graph graph;
|
||||
|
||||
auto& hand_landmark_detection =
|
||||
graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph");
|
||||
auto& hand_landmark_detection = graph.AddNode(
|
||||
"mediapipe.tasks.vision.hand_landmarker."
|
||||
"SingleHandLandmarksDetectorGraph");
|
||||
|
||||
auto options = std::make_unique<HandLandmarkerSubgraphOptions>();
|
||||
auto options = std::make_unique<HandLandmarksDetectorGraphOptions>();
|
||||
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||
JoinPath("./", kTestDataDirectory, model_name));
|
||||
hand_landmark_detection.GetOptions<HandLandmarkerSubgraphOptions>().Swap(
|
||||
hand_landmark_detection.GetOptions<HandLandmarksDetectorGraphOptions>().Swap(
|
||||
options.get());
|
||||
|
||||
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
||||
|
@ -151,13 +153,14 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner(
|
|||
absl::string_view model_name) {
|
||||
Graph graph;
|
||||
|
||||
auto& multi_hand_landmark_detection =
|
||||
graph.AddNode("mediapipe.tasks.vision.HandLandmarkerSubgraph");
|
||||
auto& multi_hand_landmark_detection = graph.AddNode(
|
||||
"mediapipe.tasks.vision.hand_landmarker."
|
||||
"MultipleHandLandmarksDetectorGraph");
|
||||
|
||||
auto options = std::make_unique<HandLandmarkerSubgraphOptions>();
|
||||
auto options = std::make_unique<HandLandmarksDetectorGraphOptions>();
|
||||
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||
JoinPath("./", kTestDataDirectory, model_name));
|
||||
multi_hand_landmark_detection.GetOptions<HandLandmarkerSubgraphOptions>()
|
||||
multi_hand_landmark_detection.GetOptions<HandLandmarksDetectorGraphOptions>()
|
||||
.Swap(options.get());
|
||||
|
||||
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
||||
|
@ -462,6 +465,7 @@ INSTANTIATE_TEST_SUITE_P(
    });

}  // namespace
}  // namespace hand_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
|
@ -21,8 +21,8 @@ package(default_visibility = [
licenses(["notice"])

mediapipe_proto_library(
    name = "hand_landmarker_subgraph_options_proto",
    srcs = ["hand_landmarker_subgraph_options.proto"],
    name = "hand_landmarks_detector_graph_options_proto",
    srcs = ["hand_landmarks_detector_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
@ -31,13 +31,13 @@ mediapipe_proto_library(
)

mediapipe_proto_library(
    name = "hand_landmarker_options_proto",
    srcs = ["hand_landmarker_options.proto"],
    name = "hand_landmarker_graph_options_proto",
    srcs = ["hand_landmarker_graph_options.proto"],
    deps = [
        ":hand_landmarker_subgraph_options_proto",
        ":hand_landmarks_detector_graph_options_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_proto",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
    ],
)
|
@ -19,22 +19,26 @@ package mediapipe.tasks.vision.hand_landmarker.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto";
import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";

message HandLandmarkerOptions {
message HandLandmarkerGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional HandLandmarkerOptions ext = 462713202;
    optional HandLandmarkerGraphOptions ext = 462713202;
  }
  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // The locale to use for display names specified through the TFLite Model
  // Metadata, if any. Defaults to English.
  optional string display_names_locale = 2 [default = "en"];
  // Options for the hand detector graph.
  optional hand_detector.proto.HandDetectorGraphOptions
      hand_detector_graph_options = 2;

  optional hand_detector.proto.HandDetectorOptions hand_detector_options = 3;
  // Options for the hand landmarks detector graph.
  optional HandLandmarksDetectorGraphOptions
      hand_landmarks_detector_graph_options = 3;

  optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 4;
  // Minimum confidence for hand landmarks tracking to be considered
  // successful.
  optional float min_tracking_confidence = 4 [default = 0.5];
}
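To make the reorganized options concrete, here is a short C++ sketch of filling in the new message. The helper function is not part of this change, the generated header path follows the proto target above, and the asset file name is a placeholder.

```
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"

using ::mediapipe::tasks::vision::hand_landmarker::proto::
    HandLandmarkerGraphOptions;

// Sketch: populate the top-level options and the nested sub-options.
void FillHandLandmarkerGraphOptions(HandLandmarkerGraphOptions* options) {
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      "hand_landmarker.task");  // placeholder asset name
  options->mutable_hand_detector_graph_options();  // keep detector defaults
  options->mutable_hand_landmarks_detector_graph_options()
      ->set_min_detection_confidence(0.5f);
  options->set_min_tracking_confidence(0.5f);
}
```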
@ -20,19 +20,15 @@ package mediapipe.tasks.vision.hand_landmarker.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

message HandLandmarkerSubgraphOptions {
message HandLandmarksDetectorGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional HandLandmarkerSubgraphOptions ext = 474472470;
    optional HandLandmarksDetectorGraphOptions ext = 474472470;
  }
  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // The locale to use for display names specified through the TFLite Model
  // Metadata, if any. Defaults to English.
  optional string display_names_locale = 2 [default = "en"];

  // Minimum confidence value ([0.0, 1.0]) for the hand presence score to be
  // considered a successful detection of a hand in the image.
  optional float min_detection_confidence = 3 [default = 0.5];
  optional float min_detection_confidence = 2 [default = 0.5];
}
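The [0.0, 1.0] constraint on `min_detection_confidence` is enforced at graph-construction time by `SanityCheckOptions` in the detector graph above. A standalone sketch of the same check follows; the error text is illustrative, not the exact message used by the graph.

```
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"

absl::Status CheckMinDetectionConfidence(
    const mediapipe::tasks::vision::hand_landmarker::proto::
        HandLandmarksDetectorGraphOptions& options) {
  if (options.min_detection_confidence() < 0 ||
      options.min_detection_confidence() > 1) {
    return absl::InvalidArgumentError(
        "min_detection_confidence must be in [0.0, 1.0].");
  }
  return absl::OkStatus();
}
```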
mediapipe/tasks/examples/android/BUILD (new file)
@ -0,0 +1,21 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
filegroup(
|
||||
name = "resource_files",
|
||||
srcs = glob(["res/**"]),
|
||||
visibility = ["//mediapipe/tasks/examples/android:__subpackages__"],
|
||||
)
|
|
@ -0,0 +1,37 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.examples.objectdetector">
|
||||
|
||||
<uses-sdk
|
||||
android:minSdkVersion="28"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<!-- For loading images from gallery -->
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
|
||||
<!-- For using the camera -->
|
||||
<uses-permission android:name="android.permission.CAMERA" />
|
||||
<uses-feature android:name="android.hardware.camera" />
|
||||
<!-- For logging solution events -->
|
||||
<uses-permission android:name="android.permission.INTERNET" />
|
||||
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="MediaPipe Tasks Object Detector"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/AppTheme"
|
||||
android:exported="false">
|
||||
<activity android:name=".MainActivity"
|
||||
android:screenOrientation="portrait"
|
||||
android:exported="true">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
android_binary(
|
||||
name = "objectdetector",
|
||||
srcs = glob(["**/*.java"]),
|
||||
assets = [
|
||||
"//mediapipe/tasks/testdata/vision:test_models",
|
||||
],
|
||||
assets_dir = "",
|
||||
custom_package = "com.google.mediapipe.tasks.examples.objectdetector",
|
||||
manifest = "AndroidManifest.xml",
|
||||
manifest_values = {
|
||||
"applicationId": "com.google.mediapipe.tasks.examples.objectdetector",
|
||||
},
|
||||
multidex = "native",
|
||||
resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
|
||||
deps = [
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector",
|
||||
"//third_party:androidx_appcompat",
|
||||
"//third_party:androidx_constraint_layout",
|
||||
"//third_party:opencv",
|
||||
"@maven//:androidx_activity_activity",
|
||||
"@maven//:androidx_concurrent_concurrent_futures",
|
||||
"@maven//:androidx_exifinterface_exifinterface",
|
||||
"@maven//:androidx_fragment_fragment",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,236 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.examples.objectdetector;
|
||||
|
||||
import android.content.Intent;
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.Matrix;
|
||||
import android.media.MediaMetadataRetriever;
|
||||
import android.os.Bundle;
|
||||
import android.provider.MediaStore;
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
import android.util.Log;
|
||||
import android.view.View;
|
||||
import android.widget.Button;
|
||||
import android.widget.FrameLayout;
|
||||
import androidx.activity.result.ActivityResultLauncher;
|
||||
import androidx.activity.result.contract.ActivityResultContracts;
|
||||
import androidx.exifinterface.media.ExifInterface;
|
||||
// ContentResolver dependency
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.Image;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
|
||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
|
||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
/** Main activity of MediaPipe Task Object Detector reference app. */
|
||||
public class MainActivity extends AppCompatActivity {
|
||||
private static final String TAG = "MainActivity";
|
||||
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
||||
|
||||
private ObjectDetector objectDetector;
|
||||
|
||||
private enum InputSource {
|
||||
UNKNOWN,
|
||||
IMAGE,
|
||||
VIDEO,
|
||||
CAMERA,
|
||||
}
|
||||
|
||||
private InputSource inputSource = InputSource.UNKNOWN;
|
||||
|
||||
// Image mode demo component.
|
||||
private ActivityResultLauncher<Intent> imageGetter;
|
||||
// Video mode demo component.
|
||||
private ActivityResultLauncher<Intent> videoGetter;
|
||||
private ObjectDetectionResultImageView imageView;
|
||||
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
setContentView(R.layout.activity_main);
|
||||
setupImageModeDemo();
|
||||
setupVideoModeDemo();
|
||||
// TODO: Adds live camera demo.
|
||||
}
|
||||
|
||||
/** Sets up the image mode demo. */
|
||||
private void setupImageModeDemo() {
|
||||
imageView = new ObjectDetectionResultImageView(this);
|
||||
// The Intent to access gallery and read images as bitmap.
|
||||
imageGetter =
|
||||
registerForActivityResult(
|
||||
new ActivityResultContracts.StartActivityForResult(),
|
||||
result -> {
|
||||
Intent resultIntent = result.getData();
|
||||
if (resultIntent != null) {
|
||||
if (result.getResultCode() == RESULT_OK) {
|
||||
Bitmap bitmap = null;
|
||||
try {
|
||||
bitmap =
|
||||
downscaleBitmap(
|
||||
MediaStore.Images.Media.getBitmap(
|
||||
this.getContentResolver(), resultIntent.getData()));
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap reading error:" + e);
|
||||
}
|
||||
try {
|
||||
InputStream imageData =
|
||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||
bitmap = rotateBitmap(bitmap, imageData);
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||
}
|
||||
if (bitmap != null) {
|
||||
Image image = new BitmapImageBuilder(bitmap).build();
|
||||
ObjectDetectionResult detectionResult = objectDetector.detect(image);
|
||||
imageView.setData(image, detectionResult);
|
||||
runOnUiThread(() -> imageView.update());
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
Button loadImageButton = findViewById(R.id.button_load_picture);
|
||||
loadImageButton.setOnClickListener(
|
||||
v -> {
|
||||
if (inputSource != InputSource.IMAGE) {
|
||||
createObjectDetector(RunningMode.IMAGE);
|
||||
this.inputSource = InputSource.IMAGE;
|
||||
updateLayout();
|
||||
}
|
||||
// Reads images from gallery.
|
||||
Intent pickImageIntent = new Intent(Intent.ACTION_PICK);
|
||||
pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*");
|
||||
imageGetter.launch(pickImageIntent);
|
||||
});
|
||||
}
|
||||
|
||||
/** Sets up the video mode demo. */
|
||||
private void setupVideoModeDemo() {
|
||||
imageView = new ObjectDetectionResultImageView(this);
|
||||
// The Intent to access gallery and read a video file.
|
||||
videoGetter =
|
||||
registerForActivityResult(
|
||||
new ActivityResultContracts.StartActivityForResult(),
|
||||
result -> {
|
||||
Intent resultIntent = result.getData();
|
||||
if (resultIntent != null) {
|
||||
if (result.getResultCode() == RESULT_OK) {
|
||||
MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever();
|
||||
metaRetriever.setDataSource(this, resultIntent.getData());
|
||||
long duration =
|
||||
Long.parseLong(
|
||||
metaRetriever.extractMetadata(
|
||||
MediaMetadataRetriever.METADATA_KEY_DURATION));
|
||||
int numFrames =
|
||||
Integer.parseInt(
|
||||
metaRetriever.extractMetadata(
|
||||
MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT));
|
||||
long frameIntervalMs = duration / numFrames;
|
||||
for (int i = 0; i < numFrames; ++i) {
|
||||
Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
|
||||
ObjectDetectionResult detectionResult =
|
||||
objectDetector.detectForVideo(image, frameIntervalMs * i);
|
||||
// Currently only annotates the detection result on the first video frame and
|
||||
// display it to verify the correctness.
|
||||
// TODO: Annotates the detection result on every frame, save the
|
||||
// annotated frames as a video file, and play back the video afterwards.
|
||||
if (i == 0) {
|
||||
imageView.setData(image, detectionResult);
|
||||
runOnUiThread(() -> imageView.update());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
Button loadVideoButton = findViewById(R.id.button_load_video);
|
||||
loadVideoButton.setOnClickListener(
|
||||
v -> {
|
||||
createObjectDetector(RunningMode.VIDEO);
|
||||
updateLayout();
|
||||
this.inputSource = InputSource.VIDEO;
|
||||
|
||||
// Reads a video from gallery.
|
||||
Intent pickVideoIntent = new Intent(Intent.ACTION_PICK);
|
||||
pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*");
|
||||
videoGetter.launch(pickVideoIntent);
|
||||
});
|
||||
}
|
||||
|
||||
private void createObjectDetector(RunningMode mode) {
|
||||
if (objectDetector != null) {
|
||||
objectDetector.close();
|
||||
}
|
||||
// Initializes a new MediaPipe ObjectDetector instance
|
||||
ObjectDetectorOptions options =
|
||||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setScoreThreshold(0.5f)
|
||||
.setMaxResults(5)
|
||||
.setRunningMode(mode)
|
||||
.build();
|
||||
objectDetector = ObjectDetector.createFromOptions(this, options);
|
||||
}
|
||||
|
||||
private void updateLayout() {
|
||||
// Updates the preview layout.
|
||||
FrameLayout frameLayout = findViewById(R.id.preview_display_layout);
|
||||
frameLayout.removeAllViewsInLayout();
|
||||
imageView.setImageDrawable(null);
|
||||
frameLayout.addView(imageView);
|
||||
imageView.setVisibility(View.VISIBLE);
|
||||
}
|
||||
|
||||
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
|
||||
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
|
||||
int width = imageView.getWidth();
|
||||
int height = imageView.getHeight();
|
||||
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
|
||||
width = (int) (height * aspectRatio);
|
||||
} else {
|
||||
height = (int) (width / aspectRatio);
|
||||
}
|
||||
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||
}
|
||||
|
||||
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
||||
int orientation =
|
||||
new ExifInterface(imageData)
|
||||
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
||||
return inputBitmap;
|
||||
}
|
||||
Matrix matrix = new Matrix();
|
||||
switch (orientation) {
|
||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||
matrix.postRotate(90);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||
matrix.postRotate(180);
|
||||
break;
|
||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||
matrix.postRotate(270);
|
||||
break;
|
||||
default:
|
||||
matrix.postRotate(0);
|
||||
}
|
||||
return Bitmap.createBitmap(
|
||||
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.examples.objectdetector;
|
||||
|
||||
import android.content.Context;
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.Canvas;
|
||||
import android.graphics.Color;
|
||||
import android.graphics.Matrix;
|
||||
import android.graphics.Paint;
|
||||
import androidx.appcompat.widget.AppCompatImageView;
|
||||
import com.google.mediapipe.framework.image.BitmapExtractor;
|
||||
import com.google.mediapipe.framework.image.Image;
|
||||
import com.google.mediapipe.tasks.components.containers.Detection;
|
||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
|
||||
|
||||
/** An ImageView implementation for displaying {@link ObjectDetectionResult}. */
|
||||
public class ObjectDetectionResultImageView extends AppCompatImageView {
|
||||
private static final String TAG = "ObjectDetectionResultImageView";
|
||||
|
||||
private static final int BBOX_COLOR = Color.GREEN;
|
||||
private static final int BBOX_THICKNESS = 5; // Pixels
|
||||
private Bitmap latest;
|
||||
|
||||
public ObjectDetectionResultImageView(Context context) {
|
||||
super(context);
|
||||
setScaleType(AppCompatImageView.ScaleType.FIT_CENTER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets an {@link Image} and an {@link ObjectDetectionResult} to render.
|
||||
*
|
||||
* @param image an {@link Image} object for annotation.
|
||||
* @param result an {@link ObjectDetectionResult} object that contains the detection result.
|
||||
*/
|
||||
public void setData(Image image, ObjectDetectionResult result) {
|
||||
if (image == null || result == null) {
|
||||
return;
|
||||
}
|
||||
latest = BitmapExtractor.extract(image);
|
||||
Canvas canvas = new Canvas(latest);
|
||||
canvas.drawBitmap(latest, new Matrix(), null);
|
||||
for (int i = 0; i < result.detections().size(); ++i) {
|
||||
drawDetectionOnCanvas(result.detections().get(i), canvas);
|
||||
}
|
||||
}
|
||||
|
||||
/** Updates the image view with the latest {@link ObjectDetectionResult}. */
|
||||
public void update() {
|
||||
postInvalidate();
|
||||
if (latest != null) {
|
||||
setImageBitmap(latest);
|
||||
}
|
||||
}
|
||||
|
||||
private void drawDetectionOnCanvas(Detection detection, Canvas canvas) {
|
||||
// TODO: Draws the category and the score per bounding box.
|
||||
// Draws bounding box.
|
||||
Paint bboxPaint = new Paint();
|
||||
bboxPaint.setColor(BBOX_COLOR);
|
||||
bboxPaint.setStyle(Paint.Style.STROKE);
|
||||
bboxPaint.setStrokeWidth(BBOX_THICKNESS);
|
||||
canvas.drawRect(detection.boundingBox(), bboxPaint);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportHeight="108"
|
||||
android:viewportWidth="108">
|
||||
<path
|
||||
android:fillType="evenOdd"
|
||||
android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
|
||||
android:strokeColor="#00000000"
|
||||
android:strokeWidth="1">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="78.5885"
|
||||
android:endY="90.9159"
|
||||
android:startX="48.7653"
|
||||
android:startY="61.0927"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
|
||||
android:strokeColor="#00000000"
|
||||
android:strokeWidth="1" />
|
||||
</vector>
|
|
@ -0,0 +1,74 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector
|
||||
android:height="108dp"
|
||||
android:width="108dp"
|
||||
android:viewportHeight="108"
|
||||
android:viewportWidth="108"
|
||||
xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<path android:fillColor="#26A69A"
|
||||
android:pathData="M0,0h108v108h-108z"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M9,0L9,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,0L19,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M29,0L29,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M39,0L39,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M49,0L49,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M59,0L59,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M69,0L69,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M79,0L79,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M89,0L89,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M99,0L99,108"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,9L108,9"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,19L108,19"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,29L108,29"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,39L108,39"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,49L108,49"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,59L108,59"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,69L108,69"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,79L108,79"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,89L108,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M0,99L108,99"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,29L89,29"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,39L89,39"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,49L89,49"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,59L89,59"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,69L89,69"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M19,79L89,79"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M29,19L29,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M39,19L39,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M49,19L49,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M59,19L59,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M69,19L69,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
<path android:fillColor="#00000000" android:pathData="M79,19L79,89"
|
||||
android:strokeColor="#33FFFFFF" android:strokeWidth="0.8"/>
|
||||
</vector>
|
|
@ -0,0 +1,40 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<LinearLayout
|
||||
xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:orientation="vertical">
|
||||
<LinearLayout
|
||||
android:id="@+id/buttons"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
style="?android:attr/buttonBarStyle" android:gravity="center"
|
||||
android:orientation="horizontal">
|
||||
<Button
|
||||
android:id="@+id/button_load_picture"
|
||||
android:layout_width="wrap_content"
|
||||
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
|
||||
android:text="@string/load_picture" />
|
||||
<Button
|
||||
android:id="@+id/button_load_video"
|
||||
android:layout_width="wrap_content"
|
||||
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
|
||||
android:text="@string/load_video" />
|
||||
<Button
|
||||
android:id="@+id/button_start_camera"
|
||||
android:layout_width="wrap_content"
|
||||
style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
|
||||
android:text="@string/start_camera" />
|
||||
</LinearLayout>
|
||||
<FrameLayout
|
||||
android:id="@+id/preview_display_layout"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent">
|
||||
<TextView
|
||||
android:id="@+id/no_view"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:gravity="center"
|
||||
android:text="@string/instruction" />
|
||||
</FrameLayout>
|
||||
</LinearLayout>
|
|
@ -0,0 +1,5 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background"/>
|
||||
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
|
||||
</adaptive-icon>
|
|
@ -0,0 +1,5 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background"/>
|
||||
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
|
||||
</adaptive-icon>
|
BIN  mediapipe/tasks/examples/android/res/mipmap-hdpi/ic_launcher.png (new binary file)
BIN  mediapipe/tasks/examples/android/res/mipmap-mdpi/ic_launcher.png (new binary file)
(Additional binary launcher-icon images were added; the diff viewer's width/height/size annotations are omitted here.)
mediapipe/tasks/examples/android/res/values/colors.xml (new file)
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="colorPrimary">#008577</color>
|
||||
<color name="colorPrimaryDark">#00574B</color>
|
||||
<color name="colorAccent">#D81B60</color>
|
||||
</resources>
|
mediapipe/tasks/examples/android/res/values/strings.xml (new file)
@ -0,0 +1,6 @@
|
|||
<resources>
|
||||
<string name="load_picture" translatable="false">Load Picture</string>
|
||||
<string name="load_video" translatable="false">Load Video</string>
|
||||
<string name="start_camera" translatable="false">Start Camera</string>
|
||||
<string name="instruction" translatable="false">Please press any button above to start</string>
|
||||
</resources>
|
mediapipe/tasks/examples/android/res/values/styles.xml (new file)
@ -0,0 +1,11 @@
|
|||
<resources>
|
||||
|
||||
<!-- Base application theme. -->
|
||||
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
|
||||
<!-- Customize your theme here. -->
|
||||
<item name="colorPrimary">@color/colorPrimary</item>
|
||||
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
|
||||
<item name="colorAccent">@color/colorAccent</item>
|
||||
</style>
|
||||
|
||||
</resources>
|
third_party/external_files.bzl (vendored)

@ -550,6 +550,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1.tflite?generation=1661875947436302"],
    )

    http_file(
        name = "com_google_mediapipe_test_jpg",
        sha256 = "798a12a466933842528d8438f553320eebe5137f02650f12dd68706a2f94fb4f",
        urls = ["https://storage.googleapis.com/mediapipe-assets/test.jpg?generation=1664672140191116"],
    )

    http_file(
        name = "com_google_mediapipe_test_model_add_op_tflite",
        sha256 = "298300ca8a9193b80ada1dca39d36f20bffeebde09e85385049b3bfe7be2272f",