Merge branch 'google:master' into text-embedder-python
|
@ -155,6 +155,7 @@ http_archive(
|
||||||
name = "com_google_audio_tools",
|
name = "com_google_audio_tools",
|
||||||
strip_prefix = "multichannel-audio-tools-master",
|
strip_prefix = "multichannel-audio-tools-master",
|
||||||
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
|
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
|
||||||
|
repo_mapping = {"@com_github_glog_glog" : "@com_github_glog_glog_no_gflags"},
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
|
|
18
docs/BUILD
|
@ -12,3 +12,21 @@ py_binary(
|
||||||
"//third_party/py/tensorflow_docs/api_generator:public_api",
|
"//third_party/py/tensorflow_docs/api_generator:public_api",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "build_java_api_docs",
|
||||||
|
srcs = ["build_java_api_docs.py"],
|
||||||
|
data = [
|
||||||
|
"//third_party/java/doclava/current:doclava.jar",
|
||||||
|
"//third_party/java/jsilver:jsilver_jar",
|
||||||
|
],
|
||||||
|
env = {
|
||||||
|
"DOCLAVA_JAR": "$(location //third_party/java/doclava/current:doclava.jar)",
|
||||||
|
"JSILVER_JAR": "$(location //third_party/java/jsilver:jsilver_jar)",
|
||||||
|
},
|
||||||
|
deps = [
|
||||||
|
"//third_party/py/absl:app",
|
||||||
|
"//third_party/py/absl/flags",
|
||||||
|
"//third_party/py/tensorflow_docs/api_generator/gen_java",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
58
docs/build_java_api_docs.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Generate Java reference docs for MediaPipe."""
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
from tensorflow_docs.api_generator import gen_java
|
||||||
|
|
||||||
|
_JAVA_ROOT = flags.DEFINE_string('java_src', None,
|
||||||
|
'Override the Java source path.',
|
||||||
|
required=False)
|
||||||
|
|
||||||
|
_OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/',
|
||||||
|
'Write docs here.')
|
||||||
|
|
||||||
|
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/java',
|
||||||
|
'Path prefix in the _toc.yaml')
|
||||||
|
|
||||||
|
_ = flags.DEFINE_string('code_url_prefix', None,
|
||||||
|
'[UNUSED] The url prefix for links to code.')
|
||||||
|
|
||||||
|
_ = flags.DEFINE_bool(
|
||||||
|
'search_hints', True,
|
||||||
|
'[UNUSED] Include metadata search hints in the generated files')
|
||||||
|
|
||||||
|
|
||||||
|
def main(_) -> None:
|
||||||
|
if not (java_root := _JAVA_ROOT.value):
|
||||||
|
# Default to using a relative path to find the Java source.
|
||||||
|
mp_root = pathlib.Path(__file__)
|
||||||
|
while (mp_root := mp_root.parent).name != 'mediapipe':
|
||||||
|
# Find the nearest `mediapipe` dir.
|
||||||
|
pass
|
||||||
|
java_root = mp_root / 'tasks/java'
|
||||||
|
|
||||||
|
gen_java.gen_java_docs(
|
||||||
|
package='com.google.mediapipe',
|
||||||
|
source_path=pathlib.Path(java_root),
|
||||||
|
output_dir=pathlib.Path(_OUT_DIR.value),
|
||||||
|
site_path=pathlib.Path(_SITE_PATH.value))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(main)
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
"""Copyright 2019 - 2020 The MediaPipe Authors.
|
# Copyright 2019 - 2022 The MediaPipe Authors.
|
||||||
|
#
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
Unless required by applicable law or agreed to in writing, software
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
|
||||||
|
|
|
@ -277,6 +277,7 @@ cc_test(
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/util:time_series_test_util",
|
"//mediapipe/util:time_series_test_util",
|
||||||
|
"@com_google_audio_tools//audio/dsp/mfcc",
|
||||||
"@eigen_archive//:eigen3",
|
"@eigen_archive//:eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -352,6 +353,8 @@ cc_test(
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/tool:validate_type",
|
"//mediapipe/framework/tool:validate_type",
|
||||||
"//mediapipe/util:time_series_test_util",
|
"//mediapipe/util:time_series_test_util",
|
||||||
|
"@com_google_audio_tools//audio/dsp:resampler",
|
||||||
|
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||||
"@com_google_audio_tools//audio/dsp:signal_vector_util",
|
"@com_google_audio_tools//audio/dsp:signal_vector_util",
|
||||||
"@eigen_archive//:eigen3",
|
"@eigen_archive//:eigen3",
|
||||||
],
|
],
|
||||||
|
|
|
@ -130,7 +130,7 @@ class BypassCalculator : public Node {
|
||||||
pass_out.insert(entry.second);
|
pass_out.insert(entry.second);
|
||||||
auto& packet = cc->Inputs().Get(entry.first).Value();
|
auto& packet = cc->Inputs().Get(entry.first).Value();
|
||||||
if (packet.Timestamp() == cc->InputTimestamp()) {
|
if (packet.Timestamp() == cc->InputTimestamp()) {
|
||||||
cc->Outputs().Get(entry.first).AddPacket(packet);
|
cc->Outputs().Get(entry.second).AddPacket(packet);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
|
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
|
||||||
|
|
|
@ -42,10 +42,10 @@ constexpr char kTestGraphConfig1[] = R"pb(
|
||||||
node {
|
node {
|
||||||
calculator: "BypassCalculator"
|
calculator: "BypassCalculator"
|
||||||
input_stream: "PASS:appearances"
|
input_stream: "PASS:appearances"
|
||||||
input_stream: "TRUNCATE:0:video_frame"
|
input_stream: "IGNORE:0:video_frame"
|
||||||
input_stream: "TRUNCATE:1:feature_config"
|
input_stream: "IGNORE:1:feature_config"
|
||||||
output_stream: "PASS:passthrough_appearances"
|
output_stream: "PASS:passthrough_appearances"
|
||||||
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
|
output_stream: "IGNORE:passthrough_federated_gaze_output"
|
||||||
node_options: {
|
node_options: {
|
||||||
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
|
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
|
||||||
pass_input_stream: "PASS"
|
pass_input_stream: "PASS"
|
||||||
|
|
|
@ -156,6 +156,7 @@ cc_test(
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -750,19 +750,12 @@ objc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
proto_library(
|
mediapipe_proto_library(
|
||||||
name = "scale_mode_proto",
|
name = "scale_mode_proto",
|
||||||
srcs = ["scale_mode.proto"],
|
srcs = ["scale_mode.proto"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_cc_proto_library(
|
|
||||||
name = "scale_mode_cc_proto",
|
|
||||||
srcs = ["scale_mode.proto"],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [":scale_mode_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gl_quad_renderer",
|
name = "gl_quad_renderer",
|
||||||
srcs = ["gl_quad_renderer.cc"],
|
srcs = ["gl_quad_renderer.cc"],
|
||||||
|
|
51
mediapipe/model_maker/models/gesture_recognizer/BUILD
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//mediapipe/framework/tool:mediapipe_files.bzl",
|
||||||
|
"mediapipe_files",
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_files(
|
||||||
|
srcs = [
|
||||||
|
"canned_gesture_classifier.tflite",
|
||||||
|
"gesture_embedder.tflite",
|
||||||
|
"gesture_embedder/keras_metadata.pb",
|
||||||
|
"gesture_embedder/saved_model.pb",
|
||||||
|
"gesture_embedder/variables/variables.data-00000-of-00001",
|
||||||
|
"gesture_embedder/variables/variables.index",
|
||||||
|
"hand_landmark_full.tflite",
|
||||||
|
"palm_detection_full.tflite",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "models",
|
||||||
|
srcs = [
|
||||||
|
"canned_gesture_classifier.tflite",
|
||||||
|
"gesture_embedder.tflite",
|
||||||
|
"gesture_embedder/keras_metadata.pb",
|
||||||
|
"gesture_embedder/saved_model.pb",
|
||||||
|
"gesture_embedder/variables/variables.data-00000-of-00001",
|
||||||
|
"gesture_embedder/variables/variables.index",
|
||||||
|
"hand_landmark_full.tflite",
|
||||||
|
"palm_detection_full.tflite",
|
||||||
|
],
|
||||||
|
)
|
35
mediapipe/model_maker/python/text/core/BUILD
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "bert_model_options",
|
||||||
|
srcs = ["bert_model_options.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "bert_model_spec",
|
||||||
|
srcs = ["bert_model_spec.py"],
|
||||||
|
deps = [
|
||||||
|
":bert_model_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/model_maker/python/text/core/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
33
mediapipe/model_maker/python/text/core/bert_model_options.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for a BERT model."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertModelOptions:
|
||||||
|
"""Configurable model options for a BERT model.
|
||||||
|
|
||||||
|
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
|
||||||
|
Transformers for Language Understanding) for more details.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the sequence to feed into the model.
|
||||||
|
do_fine_tuning: If true, then the BERT model is not frozen for training.
|
||||||
|
dropout_rate: The rate for dropout.
|
||||||
|
"""
|
||||||
|
seq_len: int = 128
|
||||||
|
do_fine_tuning: bool = True
|
||||||
|
dropout_rate: float = 0.1
|
58
mediapipe/model_maker/python/text/core/bert_model_spec.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Specification for a BERT model."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
|
_DEFAULT_TFLITE_INPUT_NAME = {
|
||||||
|
'ids': 'serving_default_input_word_ids:0',
|
||||||
|
'mask': 'serving_default_input_mask:0',
|
||||||
|
'segment_ids': 'serving_default_input_type_ids:0'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertModelSpec:
|
||||||
|
"""Specification for a BERT model.
|
||||||
|
|
||||||
|
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
|
||||||
|
Transformers for Language Understanding) for more details.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hparams: Hyperparameters used for training.
|
||||||
|
model_options: Configurable options for a BERT model.
|
||||||
|
do_lower_case: boolean, whether to lower case the input text. Should be
|
||||||
|
True / False for uncased / cased models respectively, where the models
|
||||||
|
are specified by the `uri`.
|
||||||
|
tflite_input_name: Dict, input names for the TFLite model.
|
||||||
|
uri: URI for the BERT module.
|
||||||
|
name: The name of the object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
|
epochs=3,
|
||||||
|
batch_size=32,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='mirrored')
|
||||||
|
model_options: bert_model_options.BertModelOptions = (
|
||||||
|
bert_model_options.BertModelOptions())
|
||||||
|
do_lower_case: bool = True
|
||||||
|
tflite_input_name: Dict[str, str] = dataclasses.field(
|
||||||
|
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
||||||
|
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
|
||||||
|
name: str = 'Bert'
|
146
mediapipe/model_maker/python/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_options",
|
||||||
|
srcs = ["model_options.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_spec",
|
||||||
|
srcs = ["model_spec.py"],
|
||||||
|
deps = [
|
||||||
|
":model_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "model_spec_test",
|
||||||
|
srcs = ["model_spec_test.py"],
|
||||||
|
deps = [
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
deps = [":dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "preprocessor",
|
||||||
|
srcs = ["preprocessor.py"],
|
||||||
|
deps = [":dataset"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "preprocessor_test",
|
||||||
|
srcs = ["preprocessor_test.py"],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_spec",
|
||||||
|
":preprocessor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier_options",
|
||||||
|
srcs = ["text_classifier_options.py"],
|
||||||
|
deps = [
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier",
|
||||||
|
srcs = ["text_classifier.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
":preprocessor",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "text_classifier_test",
|
||||||
|
size = "large",
|
||||||
|
srcs = ["text_classifier_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||||
|
],
|
||||||
|
tags = ["requires-net:external"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
":text_classifier",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier_demo_lib",
|
||||||
|
srcs = ["text_classifier_demo.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_spec",
|
||||||
|
":text_classifier",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "text_classifier_demo",
|
||||||
|
srcs = ["text_classifier_demo.py"],
|
||||||
|
deps = [
|
||||||
|
":text_classifier_demo_lib",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
88
mediapipe/model_maker/python/text/text_classifier/dataset.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Text classifier dataset library."""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CSVParameters:
|
||||||
|
"""Parameters used when reading a CSV file.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
text_column: Column name for the input text.
|
||||||
|
label_column: Column name for the labels.
|
||||||
|
fieldnames: Sequence of keys for the CSV columns. If None, the first row of
|
||||||
|
the CSV file is used as the keys.
|
||||||
|
delimiter: Character that separates fields.
|
||||||
|
quotechar: Character used to quote fields that contain special characters
|
||||||
|
like the `delimiter`.
|
||||||
|
"""
|
||||||
|
text_column: str
|
||||||
|
label_column: str
|
||||||
|
fieldnames: Optional[Sequence[str]] = None
|
||||||
|
delimiter: str = ","
|
||||||
|
quotechar: str = '"'
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for text classifier."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_csv(cls,
|
||||||
|
filename: str,
|
||||||
|
csv_params: CSVParameters,
|
||||||
|
shuffle: bool = True) -> "Dataset":
|
||||||
|
"""Loads text with labels from a CSV file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Name of the CSV file.
|
||||||
|
csv_params: Parameters used for reading the CSV file.
|
||||||
|
shuffle: If True, randomly shuffle the data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing (text, label) pairs and other related info.
|
||||||
|
"""
|
||||||
|
with tf.io.gfile.GFile(filename, "r") as f:
|
||||||
|
reader = csv.DictReader(
|
||||||
|
f,
|
||||||
|
fieldnames=csv_params.fieldnames,
|
||||||
|
delimiter=csv_params.delimiter,
|
||||||
|
quotechar=csv_params.quotechar)
|
||||||
|
|
||||||
|
lines = list(reader)
|
||||||
|
if shuffle:
|
||||||
|
random.shuffle(lines)
|
||||||
|
|
||||||
|
label_names = sorted(set([line[csv_params.label_column] for line in lines]))
|
||||||
|
index_by_label = {label: index for index, label in enumerate(label_names)}
|
||||||
|
|
||||||
|
texts = [line[csv_params.text_column] for line in lines]
|
||||||
|
text_ds = tf.data.Dataset.from_tensor_slices(tf.cast(texts, tf.string))
|
||||||
|
label_indices = [
|
||||||
|
index_by_label[line[csv_params.label_column]] for line in lines
|
||||||
|
]
|
||||||
|
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(label_indices, tf.int64))
|
||||||
|
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
||||||
|
|
||||||
|
return Dataset(
|
||||||
|
dataset=text_label_ds, size=len(texts), label_names=label_names)
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _get_csv_file(self):
|
||||||
|
labels_and_text = (('neutral', 'indifferent'), ('pos', 'extremely great'),
|
||||||
|
('neg', 'totally awful'), ('pos', 'super good'),
|
||||||
|
('neg', 'really bad'))
|
||||||
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
if os.path.exists(csv_file):
|
||||||
|
return csv_file
|
||||||
|
fieldnames = ['text', 'label']
|
||||||
|
with open(csv_file, 'w') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for label, text in labels_and_text:
|
||||||
|
writer.writerow({'text': text, 'label': label})
|
||||||
|
return csv_file
|
||||||
|
|
||||||
|
def test_from_csv(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||||
|
data = dataset.Dataset.from_csv(filename=csv_file, csv_params=csv_params)
|
||||||
|
self.assertLen(data, 5)
|
||||||
|
self.assertEqual(data.num_classes, 3)
|
||||||
|
self.assertEqual(data.label_names, ['neg', 'neutral', 'pos'])
|
||||||
|
data_values = set([(text.numpy()[0], label.numpy()[0])
|
||||||
|
for text, label in data.gen_tf_dataset()])
|
||||||
|
expected_data_values = set([(b'indifferent', 1), (b'extremely great', 2),
|
||||||
|
(b'totally awful', 0), (b'super good', 2),
|
||||||
|
(b'really bad', 0)])
|
||||||
|
self.assertEqual(data_values, expected_data_values)
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
|
||||||
|
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
||||||
|
train_data, test_data = data.split(0.5)
|
||||||
|
expected_train_data = [b'good', b'bad']
|
||||||
|
expected_test_data = [b'neutral', b'odd']
|
||||||
|
|
||||||
|
self.assertLen(train_data, 2)
|
||||||
|
train_data_values = [elem.numpy() for elem in train_data._dataset]
|
||||||
|
self.assertEqual(train_data_values, expected_train_data)
|
||||||
|
self.assertEqual(train_data.num_classes, 2)
|
||||||
|
self.assertEqual(train_data.label_names, ['pos', 'neg'])
|
||||||
|
|
||||||
|
self.assertLen(test_data, 2)
|
||||||
|
test_data_values = [elem.numpy() for elem in test_data._dataset]
|
||||||
|
self.assertEqual(test_data_values, expected_test_data)
|
||||||
|
self.assertEqual(test_data.num_classes, 2)
|
||||||
|
self.assertEqual(test_data.label_names, ['pos', 'neg'])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for text classifier models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
|
# BERT text classifier options inherited from BertModelOptions.
|
||||||
|
BertClassifierOptions = bert_model_options.BertModelOptions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AverageWordEmbeddingClassifierOptions:
|
||||||
|
"""Configurable model options for an Average Word Embedding classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the sequence to feed into the model.
|
||||||
|
wordvec_dim: Dimension of the word embedding.
|
||||||
|
do_lower_case: Whether to convert all uppercase characters to lowercase
|
||||||
|
during preprocessing.
|
||||||
|
vocab_size: Number of words to generate the vocabulary from data.
|
||||||
|
dropout_rate: The rate for dropout.
|
||||||
|
"""
|
||||||
|
seq_len: int = 256
|
||||||
|
wordvec_dim: int = 16
|
||||||
|
do_lower_case: bool = True
|
||||||
|
vocab_size: int = 10000
|
||||||
|
dropout_rate: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
|
||||||
|
BertClassifierOptions]
|
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Specifications for text classifier models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import enum
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
|
||||||
|
# BERT-based text classifier spec inherited from BertModelSpec
|
||||||
|
BertClassifierSpec = bert_model_spec.BertModelSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AverageWordEmbeddingClassifierSpec:
|
||||||
|
"""Specification for an average word embedding classifier model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hparams: Configurable hyperparameters for training.
|
||||||
|
model_options: Configurable options for the average word embedding model.
|
||||||
|
name: The name of the object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# `learning_rate` is unused for the average word embedding model
|
||||||
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
|
epochs=10, batch_size=32, learning_rate=0)
|
||||||
|
model_options: mo.AverageWordEmbeddingClassifierOptions = (
|
||||||
|
mo.AverageWordEmbeddingClassifierOptions())
|
||||||
|
name: str = 'AverageWordEmbedding'
|
||||||
|
|
||||||
|
|
||||||
|
average_word_embedding_classifier_spec = functools.partial(
|
||||||
|
AverageWordEmbeddingClassifierSpec)
|
||||||
|
|
||||||
|
mobilebert_classifier_spec = functools.partial(
|
||||||
|
BertClassifierSpec,
|
||||||
|
hparams=hp.BaseHParams(
|
||||||
|
epochs=3,
|
||||||
|
batch_size=48,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='off'),
|
||||||
|
name='MobileBert',
|
||||||
|
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
|
||||||
|
tflite_input_name={
|
||||||
|
'ids': 'serving_default_input_1:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
|
'segment_ids': 'serving_default_input_2:0'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class SupportedModels(enum.Enum):
|
||||||
|
"""Predefined text classifier model specs supported by Model Maker."""
|
||||||
|
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||||
|
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
|
@ -0,0 +1,118 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for model_spec."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_predefined_bert_spec(self):
|
||||||
|
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
||||||
|
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
|
||||||
|
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
|
||||||
|
self.assertTrue(model_spec_obj.do_lower_case)
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.tflite_input_name, {
|
||||||
|
'ids': 'serving_default_input_1:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
|
'segment_ids': 'serving_default_input_2:0'
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.model_options,
|
||||||
|
classifier_model_options.BertClassifierOptions(
|
||||||
|
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.hparams,
|
||||||
|
hp.BaseHParams(
|
||||||
|
epochs=3,
|
||||||
|
batch_size=48,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='off'))
|
||||||
|
|
||||||
|
def test_predefined_average_word_embedding_spec(self):
|
||||||
|
model_spec_obj = (
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value())
|
||||||
|
self.assertIsInstance(model_spec_obj, ms.AverageWordEmbeddingClassifierSpec)
|
||||||
|
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.model_options,
|
||||||
|
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||||
|
seq_len=256,
|
||||||
|
wordvec_dim=16,
|
||||||
|
do_lower_case=True,
|
||||||
|
vocab_size=10000,
|
||||||
|
dropout_rate=0.2))
|
||||||
|
self.assertEqual(
|
||||||
|
model_spec_obj.hparams,
|
||||||
|
hp.BaseHParams(
|
||||||
|
epochs=10,
|
||||||
|
batch_size=32,
|
||||||
|
learning_rate=0,
|
||||||
|
steps_per_epoch=None,
|
||||||
|
shuffle=False,
|
||||||
|
distribution_strategy='off',
|
||||||
|
num_gpus=-1,
|
||||||
|
tpu=''))
|
||||||
|
|
||||||
|
def test_custom_bert_spec(self):
|
||||||
|
custom_bert_classifier_options = (
|
||||||
|
classifier_model_options.BertClassifierOptions(
|
||||||
|
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||||
|
model_spec_obj = (
|
||||||
|
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||||
|
model_options=custom_bert_classifier_options))
|
||||||
|
self.assertEqual(model_spec_obj.model_options,
|
||||||
|
custom_bert_classifier_options)
|
||||||
|
|
||||||
|
def test_custom_average_word_embedding_spec(self):
|
||||||
|
custom_hparams = hp.BaseHParams(
|
||||||
|
learning_rate=0.4,
|
||||||
|
batch_size=64,
|
||||||
|
epochs=10,
|
||||||
|
steps_per_epoch=10,
|
||||||
|
shuffle=True,
|
||||||
|
export_dir='foo/bar',
|
||||||
|
distribution_strategy='mirrored',
|
||||||
|
num_gpus=3,
|
||||||
|
tpu='tpu/address')
|
||||||
|
custom_average_word_embedding_model_options = (
|
||||||
|
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||||
|
seq_len=512,
|
||||||
|
wordvec_dim=32,
|
||||||
|
do_lower_case=False,
|
||||||
|
vocab_size=5000,
|
||||||
|
dropout_rate=0.5))
|
||||||
|
model_spec_obj = (
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value(
|
||||||
|
model_options=custom_average_word_embedding_model_options,
|
||||||
|
hparams=custom_hparams))
|
||||||
|
self.assertEqual(model_spec_obj.model_options,
|
||||||
|
custom_average_word_embedding_model_options)
|
||||||
|
self.assertEqual(model_spec_obj.hparams, custom_hparams)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,285 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Preprocessors for text classification."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from typing import Mapping, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
|
from official.nlp.data import classifier_data_lib
|
||||||
|
from official.nlp.tools import tokenization
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
|
||||||
|
"""Validates the shape and type of `text` and `label`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Stores text data. Should have shape [1] and dtype tf.string.
|
||||||
|
label: Stores the label for the corresponding `text`. Should have shape [1]
|
||||||
|
and dtype tf.int64.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If either tensor has the wrong shape or type.
|
||||||
|
"""
|
||||||
|
if text.shape != [1]:
|
||||||
|
raise ValueError(f"`text` should have shape [1], got {text.shape}")
|
||||||
|
if text.dtype != tf.string:
|
||||||
|
raise ValueError(f"Expected dtype string for `text`, got {text.dtype}")
|
||||||
|
if label.shape != [1]:
|
||||||
|
raise ValueError(f"`label` should have shape [1], got {text.shape}")
|
||||||
|
if label.dtype != tf.int64:
|
||||||
|
raise ValueError(f"Expected dtype int64 for `label`, got {label.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_record(
|
||||||
|
record: tf.Tensor, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||||
|
) -> Tuple[Mapping[str, tf.Tensor], tf.Tensor]:
|
||||||
|
"""Decodes a record into input for a BERT model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: Stores serialized example.
|
||||||
|
name_to_features: Maps record keys to feature types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BERT model input features and label for the record.
|
||||||
|
"""
|
||||||
|
example = tf.io.parse_single_example(record, name_to_features)
|
||||||
|
|
||||||
|
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||||
|
for name in list(example.keys()):
|
||||||
|
example[name] = tf.cast(example[name], tf.int32)
|
||||||
|
|
||||||
|
bert_features = {
|
||||||
|
"input_word_ids": example["input_ids"],
|
||||||
|
"input_mask": example["input_mask"],
|
||||||
|
"input_type_ids": example["segment_ids"]
|
||||||
|
}
|
||||||
|
return bert_features, example["label_ids"]
|
||||||
|
|
||||||
|
|
||||||
|
def _single_file_dataset(
|
||||||
|
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||||
|
) -> tf.data.TFRecordDataset:
|
||||||
|
"""Creates a single-file dataset to be passed for BERT custom training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_file: Filepath for the dataset.
|
||||||
|
name_to_features: Maps record keys to feature types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing BERT model input features and labels.
|
||||||
|
"""
|
||||||
|
d = tf.data.TFRecordDataset(input_file)
|
||||||
|
d = d.map(
|
||||||
|
lambda record: _decode_record(record, name_to_features),
|
||||||
|
num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class AverageWordEmbeddingClassifierPreprocessor:
|
||||||
|
"""Preprocessor for an Average Word Embedding model.
|
||||||
|
|
||||||
|
Takes (text, label) data and applies regex tokenization and padding to the
|
||||||
|
text to generate (token IDs, label) data.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the input sequence to the model.
|
||||||
|
do_lower_case: Whether text inputs should be converted to lower-case.
|
||||||
|
vocab: Vocabulary of tokens used by the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PAD: str = "<PAD>" # Index: 0
|
||||||
|
START: str = "<START>" # Index: 1
|
||||||
|
UNKNOWN: str = "<UNKNOWN>" # Index: 2
|
||||||
|
|
||||||
|
def __init__(self, seq_len: int, do_lower_case: bool, texts: Sequence[str],
|
||||||
|
vocab_size: int):
|
||||||
|
self._seq_len = seq_len
|
||||||
|
self._do_lower_case = do_lower_case
|
||||||
|
self._vocab = self._gen_vocab(texts, vocab_size)
|
||||||
|
|
||||||
|
def _gen_vocab(self, texts: Sequence[str],
|
||||||
|
vocab_size: int) -> Mapping[str, int]:
|
||||||
|
"""Generates vocabulary list in `texts` with size `vocab_size`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: All texts (across training and validation data) that will be
|
||||||
|
preprocessed by the model.
|
||||||
|
vocab_size: Size of the vocab.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The vocab mapping tokens to IDs.
|
||||||
|
"""
|
||||||
|
vocab_counter = collections.Counter()
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
tokens = self._regex_tokenize(text)
|
||||||
|
for token in tokens:
|
||||||
|
vocab_counter[token] += 1
|
||||||
|
|
||||||
|
vocab_freq = vocab_counter.most_common(vocab_size)
|
||||||
|
vocab_list = [self.PAD, self.START, self.UNKNOWN
|
||||||
|
] + [word for word, _ in vocab_freq]
|
||||||
|
return collections.OrderedDict(((v, i) for i, v in enumerate(vocab_list)))
|
||||||
|
|
||||||
|
def get_vocab(self) -> Mapping[str, int]:
|
||||||
|
"""Returns the vocab of the AverageWordEmbeddingClassifierPreprocessor."""
|
||||||
|
return self._vocab
|
||||||
|
|
||||||
|
# TODO: Align with MediaPipe's RegexTokenizer.
|
||||||
|
def _regex_tokenize(self, text: str) -> Sequence[str]:
|
||||||
|
"""Splits `text` by words but does not split on single quotes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to be tokenized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tokens.
|
||||||
|
"""
|
||||||
|
text = tf.compat.as_text(text)
|
||||||
|
if self._do_lower_case:
|
||||||
|
text = text.lower()
|
||||||
|
tokens = re.compile(r"[^\w\']+").split(text.strip())
|
||||||
|
# Filters out any empty strings in `tokens`.
|
||||||
|
return list(filter(None, tokens))
|
||||||
|
|
||||||
|
def _tokenize_and_pad(self, text: str) -> Sequence[int]:
|
||||||
|
"""Tokenizes `text` and pads the tokens to `seq_len`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to be tokenized and padded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token IDs padded to have length `seq_len`.
|
||||||
|
"""
|
||||||
|
tokens = self._regex_tokenize(text)
|
||||||
|
|
||||||
|
# Gets ids for START, PAD and UNKNOWN tokens.
|
||||||
|
start_id = self._vocab[self.START]
|
||||||
|
pad_id = self._vocab[self.PAD]
|
||||||
|
unknown_id = self._vocab[self.UNKNOWN]
|
||||||
|
|
||||||
|
token_ids = [self._vocab.get(token, unknown_id) for token in tokens]
|
||||||
|
token_ids = [start_id] + token_ids
|
||||||
|
|
||||||
|
if len(token_ids) < self._seq_len:
|
||||||
|
pad_length = self._seq_len - len(token_ids)
|
||||||
|
token_ids = token_ids + pad_length * [pad_id]
|
||||||
|
else:
|
||||||
|
token_ids = token_ids[:self._seq_len]
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||||
|
"""Preprocesses data into input for an Average Word Embedding model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Stores (text, label) data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing (token IDs, label) data.
|
||||||
|
"""
|
||||||
|
token_ids_list = []
|
||||||
|
labels_list = []
|
||||||
|
for text, label in dataset.gen_tf_dataset():
|
||||||
|
_validate_text_and_label(text, label)
|
||||||
|
token_ids = self._tokenize_and_pad(text.numpy()[0].decode("utf-8"))
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
labels_list.append(label.numpy()[0])
|
||||||
|
|
||||||
|
token_ids_ds = tf.data.Dataset.from_tensor_slices(token_ids_list)
|
||||||
|
labels_ds = tf.data.Dataset.from_tensor_slices(labels_list)
|
||||||
|
preprocessed_ds = tf.data.Dataset.zip((token_ids_ds, labels_ds))
|
||||||
|
return text_classifier_ds.Dataset(
|
||||||
|
dataset=preprocessed_ds,
|
||||||
|
size=dataset.size,
|
||||||
|
label_names=dataset.label_names)
|
||||||
|
|
||||||
|
|
||||||
|
class BertClassifierPreprocessor:
|
||||||
|
"""Preprocessor for a BERT-based classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
seq_len: Length of the input sequence to the model.
|
||||||
|
vocab_file: File containing the BERT vocab.
|
||||||
|
tokenizer: BERT tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
|
||||||
|
self._seq_len = seq_len
|
||||||
|
# Vocab filepath is tied to the BERT module's URI.
|
||||||
|
self._vocab_file = os.path.join(
|
||||||
|
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
|
||||||
|
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
|
||||||
|
do_lower_case)
|
||||||
|
|
||||||
|
def _get_name_to_features(self):
|
||||||
|
"""Gets the dictionary mapping record keys to feature types."""
|
||||||
|
return {
|
||||||
|
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
|
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
|
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||||
|
"label_ids": tf.io.FixedLenFeature([], tf.int64),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_vocab_file(self) -> str:
|
||||||
|
"""Returns the vocab file of the BertClassifierPreprocessor."""
|
||||||
|
return self._vocab_file
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||||
|
"""Preprocesses data into input for a BERT-based classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Stores (text, label) data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing (bert_features, label) data.
|
||||||
|
"""
|
||||||
|
examples = []
|
||||||
|
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||||
|
_validate_text_and_label(text, label)
|
||||||
|
examples.append(
|
||||||
|
classifier_data_lib.InputExample(
|
||||||
|
guid=str(index),
|
||||||
|
text_a=text.numpy()[0].decode("utf-8"),
|
||||||
|
text_b=None,
|
||||||
|
# InputExample expects the label name rather than the int ID
|
||||||
|
label=dataset.label_names[label.numpy()[0]]))
|
||||||
|
|
||||||
|
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
|
||||||
|
classifier_data_lib.file_based_convert_examples_to_features(
|
||||||
|
examples=examples,
|
||||||
|
label_list=dataset.label_names,
|
||||||
|
max_seq_length=self._seq_len,
|
||||||
|
tokenizer=self._tokenizer,
|
||||||
|
output_file=tfrecord_file)
|
||||||
|
preprocessed_ds = _single_file_dataset(tfrecord_file,
|
||||||
|
self._get_name_to_features())
|
||||||
|
return text_classifier_ds.Dataset(
|
||||||
|
dataset=preprocessed_ds,
|
||||||
|
size=dataset.size,
|
||||||
|
label_names=dataset.label_names)
|
||||||
|
|
||||||
|
|
||||||
|
TextClassifierPreprocessor = (
|
||||||
|
Union[BertClassifierPreprocessor,
|
||||||
|
AverageWordEmbeddingClassifierPreprocessor])
|
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.testing as npt
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessorTest(tf.test.TestCase):
|
||||||
|
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
||||||
|
text_column='text', label_column='label')
|
||||||
|
|
||||||
|
def _get_csv_file(self):
|
||||||
|
labels_and_text = (('pos', 'super super super super good'),
|
||||||
|
(('neg', 'really bad')))
|
||||||
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
if os.path.exists(csv_file):
|
||||||
|
return csv_file
|
||||||
|
fieldnames = ['text', 'label']
|
||||||
|
with open(csv_file, 'w') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for label, text in labels_and_text:
|
||||||
|
writer.writerow({'text': text, 'label': label})
|
||||||
|
return csv_file
|
||||||
|
|
||||||
|
def test_average_word_embedding_preprocessor(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
|
average_word_embedding_preprocessor = (
|
||||||
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
|
||||||
|
seq_len=5,
|
||||||
|
do_lower_case=True,
|
||||||
|
texts=['super super super super good', 'really bad'],
|
||||||
|
vocab_size=7))
|
||||||
|
preprocessed_dataset = (
|
||||||
|
average_word_embedding_preprocessor.preprocess(dataset))
|
||||||
|
labels = []
|
||||||
|
features_list = []
|
||||||
|
for features, label in preprocessed_dataset.gen_tf_dataset():
|
||||||
|
self.assertEqual(label.shape, [1])
|
||||||
|
labels.append(label.numpy()[0])
|
||||||
|
self.assertEqual(features.shape, [1, 5])
|
||||||
|
features_list.append(features.numpy()[0])
|
||||||
|
self.assertEqual(labels, [1, 0])
|
||||||
|
npt.assert_array_equal(
|
||||||
|
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
|
||||||
|
|
||||||
|
def test_bert_preprocessor(self):
|
||||||
|
csv_file = self._get_csv_file()
|
||||||
|
dataset = text_classifier_ds.Dataset.from_csv(
|
||||||
|
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||||
|
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||||
|
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
|
||||||
|
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||||
|
labels = []
|
||||||
|
input_masks = []
|
||||||
|
for features, label in preprocessed_dataset.gen_tf_dataset():
|
||||||
|
self.assertEqual(label.shape, [1])
|
||||||
|
labels.append(label.numpy()[0])
|
||||||
|
self.assertSameElements(
|
||||||
|
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
|
||||||
|
for feature in features.values():
|
||||||
|
self.assertEqual(feature.shape, [1, 5])
|
||||||
|
input_masks.append(features['input_mask'].numpy()[0])
|
||||||
|
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
|
||||||
|
[0, 0, 0, 0, 0])
|
||||||
|
npt.assert_array_equal(
|
||||||
|
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
|
||||||
|
self.assertEqual(labels, [1, 0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
23
mediapipe/model_maker/python/text/text_classifier/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = ["average_word_embedding_metadata.json"],
|
||||||
|
)
|
63
mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json
vendored
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
{
|
||||||
|
"name": "TextClassifier",
|
||||||
|
"description": "Classify the input text into a set of known categories.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "input_text",
|
||||||
|
"description": "Embedding vectors representing the input text to be processed.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "RegexTokenizerOptions",
|
||||||
|
"options": {
|
||||||
|
"delim_regex_pattern": "[^\\w\\']+",
|
||||||
|
"vocab_file": [
|
||||||
|
{
|
||||||
|
"name": "vocab.txt",
|
||||||
|
"description": "Vocabulary file to convert natural language words to embedding vectors.",
|
||||||
|
"type": "VOCABULARY"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "score",
|
||||||
|
"description": "Score of the labels respectively.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.2.1"
|
||||||
|
}
|
|
@ -0,0 +1,437 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""API for text classification."""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Any, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
|
||||||
|
from official.nlp import optimization
|
||||||
|
|
||||||
|
|
||||||
|
def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||||
|
"""Validates that `model_options` and `supported_model` are compatible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: Options for creating and training a text classifier.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if there is a mismatch between `model_options` and
|
||||||
|
`supported_model`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if options.model_options is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (isinstance(options.model_options,
|
||||||
|
mo.AverageWordEmbeddingClassifierOptions) and
|
||||||
|
(options.supported_model !=
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
|
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||||
|
f" got {options.supported_model}")
|
||||||
|
if (isinstance(options.model_options, mo.BertClassifierOptions) and
|
||||||
|
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||||
|
|
||||||
|
|
||||||
|
class TextClassifier(classifier.Classifier):
|
||||||
|
"""API for creating and training a text classification model."""
|
||||||
|
|
||||||
|
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
|
||||||
|
label_names: Sequence[str]):
|
||||||
|
super().__init__(
|
||||||
|
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
|
self._model_spec = model_spec
|
||||||
|
self._hparams = hparams
|
||||||
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
|
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||||
|
options: text_classifier_options.TextClassifierOptions
|
||||||
|
) -> "TextClassifier":
|
||||||
|
"""Factory function that creates and trains a text classifier.
|
||||||
|
|
||||||
|
Note that `train_data` and `validation_data` are expected to share the same
|
||||||
|
`label_names` since they should be split from the same dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Options for creating and training the text classifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A text classifier.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if `train_data` and `validation_data` do not have the
|
||||||
|
same label_names or `options` contains an unknown `supported_model`
|
||||||
|
"""
|
||||||
|
if train_data.label_names != validation_data.label_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Training data label names {train_data.label_names} not equal to "
|
||||||
|
f"validation data label names {validation_data.label_names}")
|
||||||
|
|
||||||
|
_validate(options)
|
||||||
|
if options.model_options is None:
|
||||||
|
options.model_options = options.supported_model.value().model_options
|
||||||
|
|
||||||
|
if options.hparams is None:
|
||||||
|
options.hparams = options.supported_model.value().hparams
|
||||||
|
|
||||||
|
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||||
|
text_classifier = (
|
||||||
|
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||||
|
options,
|
||||||
|
train_data.label_names))
|
||||||
|
elif (options.supported_model ==
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
|
text_classifier = (
|
||||||
|
_AverageWordEmbeddingClassifier
|
||||||
|
.create_average_word_embedding_classifier(train_data, validation_data,
|
||||||
|
options,
|
||||||
|
train_data.label_names))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model {options.supported_model}")
|
||||||
|
|
||||||
|
return text_classifier
|
||||||
|
|
||||||
|
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
|
||||||
|
"""Overrides Classifier.evaluate().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
||||||
|
batch_size: Number of samples per evaluation step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loss value and accuracy.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if `data` is not a TextClassifier Dataset.
|
||||||
|
"""
|
||||||
|
# This override is needed because TextClassifier preprocesses its data
|
||||||
|
# outside of the `gen_tf_dataset()` method. The preprocess call also
|
||||||
|
# requires a TextClassifier Dataset instead of a core Dataset.
|
||||||
|
if not isinstance(data, text_ds.Dataset):
|
||||||
|
raise ValueError("Need a TextClassifier Dataset.")
|
||||||
|
|
||||||
|
processed_data = self._text_preprocessor.preprocess(data)
|
||||||
|
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
||||||
|
return self._model.evaluate(dataset)
|
||||||
|
|
||||||
|
def export_model(
|
||||||
|
self,
|
||||||
|
model_name: str = "model.tflite",
|
||||||
|
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
||||||
|
"""Converts and saves the model to a TFLite file with metadata included.
|
||||||
|
|
||||||
|
Note that only the TFLite file is needed for deployment. This function also
|
||||||
|
saves a metadata.json file to the same directory as the TFLite file which
|
||||||
|
can be used to interpret the metadata content in the TFLite file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save TFLite model with metadata. The full export
|
||||||
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
|
quantization_config: The configuration for model quantization.
|
||||||
|
"""
|
||||||
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
||||||
|
|
||||||
|
tflite_model = model_util.convert_to_tflite(
|
||||||
|
model=self._model, quantization_config=quantization_config)
|
||||||
|
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
|
||||||
|
self._save_vocab(vocab_filepath)
|
||||||
|
|
||||||
|
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
|
||||||
|
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||||
|
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||||
|
with open(metadata_file, "w") as f:
|
||||||
|
f.write(metadata_json)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
"""Saves the preprocessor's vocab to `vocab_filepath`."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||||
|
"""Gets the metadata writer for the text classifier TFLite model."""
|
||||||
|
|
||||||
|
|
||||||
|
class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
|
"""APIs to help create and train an Average Word Embedding text classifier."""
|
||||||
|
|
||||||
|
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||||
|
|
||||||
|
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||||
|
model_options: mo.AverageWordEmbeddingClassifierOptions,
|
||||||
|
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||||
|
super().__init__(model_spec, hparams, label_names)
|
||||||
|
self._model_options = model_options
|
||||||
|
self._loss_function = "sparse_categorical_crossentropy"
|
||||||
|
self._metric_function = "accuracy"
|
||||||
|
self._text_preprocessor: (
|
||||||
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_average_word_embedding_classifier(
|
||||||
|
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||||
|
options: text_classifier_options.TextClassifierOptions,
|
||||||
|
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
|
||||||
|
"""Creates, trains, and returns an Average Word Embedding classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Options for creating and training the text classifier.
|
||||||
|
label_names: Label names used in the data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An Average Word Embedding classifier.
|
||||||
|
"""
|
||||||
|
average_word_embedding_classifier = _AverageWordEmbeddingClassifier(
|
||||||
|
model_spec=options.supported_model.value(),
|
||||||
|
model_options=options.model_options,
|
||||||
|
hparams=options.hparams,
|
||||||
|
label_names=train_data.label_names)
|
||||||
|
average_word_embedding_classifier._create_and_train_model(
|
||||||
|
train_data, validation_data)
|
||||||
|
return average_word_embedding_classifier
|
||||||
|
|
||||||
|
def _create_and_train_model(self, train_data: text_ds.Dataset,
|
||||||
|
validation_data: text_ds.Dataset):
|
||||||
|
"""Creates the Average Word Embedding classifier keras model and trains it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
(processed_train_data, processed_validation_data) = (
|
||||||
|
self._load_and_run_preprocessor(train_data, validation_data))
|
||||||
|
self._create_model()
|
||||||
|
self._optimizer = "rmsprop"
|
||||||
|
self._train_model(processed_train_data, processed_validation_data)
|
||||||
|
|
||||||
|
def _load_and_run_preprocessor(
|
||||||
|
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
|
||||||
|
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
|
||||||
|
"""Runs an AverageWordEmbeddingClassifierPreprocessor on the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed training data and preprocessed validation data.
|
||||||
|
"""
|
||||||
|
train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()]
|
||||||
|
validation_texts = [
|
||||||
|
text.numpy()[0] for text, _ in validation_data.gen_tf_dataset()
|
||||||
|
]
|
||||||
|
self._text_preprocessor = (
|
||||||
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
|
||||||
|
seq_len=self._model_options.seq_len,
|
||||||
|
do_lower_case=self._model_options.do_lower_case,
|
||||||
|
texts=train_texts + validation_texts,
|
||||||
|
vocab_size=self._model_options.vocab_size))
|
||||||
|
return self._text_preprocessor.preprocess(
|
||||||
|
train_data), self._text_preprocessor.preprocess(validation_data)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates an Average Word Embedding model."""
|
||||||
|
self._model = tf.keras.Sequential([
|
||||||
|
tf.keras.layers.InputLayer(
|
||||||
|
input_shape=[self._model_options.seq_len], dtype=tf.int32),
|
||||||
|
tf.keras.layers.Embedding(
|
||||||
|
len(self._text_preprocessor.get_vocab()),
|
||||||
|
self._model_options.wordvec_dim,
|
||||||
|
input_length=self._model_options.seq_len),
|
||||||
|
tf.keras.layers.GlobalAveragePooling1D(),
|
||||||
|
tf.keras.layers.Dense(
|
||||||
|
self._model_options.wordvec_dim, activation=tf.nn.relu),
|
||||||
|
tf.keras.layers.Dropout(self._model_options.dropout_rate),
|
||||||
|
tf.keras.layers.Dense(self._num_classes, activation="softmax")
|
||||||
|
])
|
||||||
|
|
||||||
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
with tf.io.gfile.GFile(vocab_filepath, "w") as f:
|
||||||
|
for token, index in self._text_preprocessor.get_vocab().items():
|
||||||
|
f.write(f"{token} {index}\n")
|
||||||
|
|
||||||
|
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||||
|
return text_classifier_writer.MetadataWriter.create_for_regex_model(
|
||||||
|
model_buffer=tflite_model,
|
||||||
|
regex_tokenizer=metadata_writer.RegexTokenizer(
|
||||||
|
# TODO: Align with MediaPipe's RegexTokenizer.
|
||||||
|
delim_regex_pattern=self._DELIM_REGEX_PATTERN,
|
||||||
|
vocab_file_path=vocab_filepath),
|
||||||
|
labels=metadata_writer.Labels().add(list(self._label_names)))
|
||||||
|
|
||||||
|
|
||||||
|
class _BertClassifier(TextClassifier):
|
||||||
|
"""APIs to help create and train a BERT-based text classifier."""
|
||||||
|
|
||||||
|
_INITIALIZER_RANGE = 0.02
|
||||||
|
|
||||||
|
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||||
|
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
|
||||||
|
label_names: Sequence[str]):
|
||||||
|
super().__init__(model_spec, hparams, label_names)
|
||||||
|
self._model_options = model_options
|
||||||
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
|
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||||
|
"test_accuracy", dtype=tf.float32)
|
||||||
|
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_bert_classifier(
|
||||||
|
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||||
|
options: text_classifier_options.TextClassifierOptions,
|
||||||
|
label_names: Sequence[str]) -> "_BertClassifier":
|
||||||
|
"""Creates, trains, and returns a BERT-based classifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
options: Options for creating and training the text classifier.
|
||||||
|
label_names: Label names used in the data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A BERT-based classifier.
|
||||||
|
"""
|
||||||
|
bert_classifier = _BertClassifier(
|
||||||
|
model_spec=options.supported_model.value(),
|
||||||
|
model_options=options.model_options,
|
||||||
|
hparams=options.hparams,
|
||||||
|
label_names=train_data.label_names)
|
||||||
|
bert_classifier._create_and_train_model(train_data, validation_data)
|
||||||
|
return bert_classifier
|
||||||
|
|
||||||
|
def _create_and_train_model(self, train_data: text_ds.Dataset,
|
||||||
|
validation_data: text_ds.Dataset):
|
||||||
|
"""Creates the BERT-based classifier keras model and trains it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
(processed_train_data, processed_validation_data) = (
|
||||||
|
self._load_and_run_preprocessor(train_data, validation_data))
|
||||||
|
self._create_model()
|
||||||
|
self._create_optimizer(processed_train_data)
|
||||||
|
self._train_model(processed_train_data, processed_validation_data)
|
||||||
|
|
||||||
|
def _load_and_run_preprocessor(
|
||||||
|
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
|
||||||
|
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
|
||||||
|
"""Loads a BertClassifierPreprocessor and runs it on the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed training data and preprocessed validation data.
|
||||||
|
"""
|
||||||
|
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||||
|
seq_len=self._model_options.seq_len,
|
||||||
|
do_lower_case=self._model_spec.do_lower_case,
|
||||||
|
uri=self._model_spec.uri)
|
||||||
|
return (self._text_preprocessor.preprocess(train_data),
|
||||||
|
self._text_preprocessor.preprocess(validation_data))
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates a BERT-based classifier model.
|
||||||
|
|
||||||
|
The model architecture consists of stacking a dense classification layer and
|
||||||
|
dropout layer on top of the BERT encoder outputs.
|
||||||
|
"""
|
||||||
|
encoder_inputs = dict(
|
||||||
|
input_word_ids=tf.keras.layers.Input(
|
||||||
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
|
input_mask=tf.keras.layers.Input(
|
||||||
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
|
input_type_ids=tf.keras.layers.Input(
|
||||||
|
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||||
|
)
|
||||||
|
encoder = hub.KerasLayer(
|
||||||
|
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
|
||||||
|
encoder_outputs = encoder(encoder_inputs)
|
||||||
|
pooled_output = encoder_outputs["pooled_output"]
|
||||||
|
|
||||||
|
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
||||||
|
pooled_output)
|
||||||
|
initializer = tf.keras.initializers.TruncatedNormal(
|
||||||
|
stddev=self._INITIALIZER_RANGE)
|
||||||
|
output = tf.keras.layers.Dense(
|
||||||
|
self._num_classes,
|
||||||
|
kernel_initializer=initializer,
|
||||||
|
name="output",
|
||||||
|
activation="softmax",
|
||||||
|
dtype=tf.float32)(
|
||||||
|
output)
|
||||||
|
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||||
|
|
||||||
|
def _create_optimizer(self, train_data: text_ds.Dataset):
|
||||||
|
"""Loads an optimizer with a learning rate schedule.
|
||||||
|
|
||||||
|
The decay steps in the learning rate schedule depend on the
|
||||||
|
`steps_per_epoch` which may depend on the size of the training data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
"""
|
||||||
|
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||||
|
batch_size=self._hparams.batch_size,
|
||||||
|
train_data=train_data)
|
||||||
|
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
|
||||||
|
warmup_steps = int(total_steps * 0.1)
|
||||||
|
initial_lr = self._hparams.learning_rate
|
||||||
|
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
|
||||||
|
warmup_steps)
|
||||||
|
|
||||||
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
|
tf.io.gfile.copy(
|
||||||
|
self._text_preprocessor.get_vocab_file(),
|
||||||
|
vocab_filepath,
|
||||||
|
overwrite=True)
|
||||||
|
|
||||||
|
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||||
|
return text_classifier_writer.MetadataWriter.create_for_bert_model(
|
||||||
|
model_buffer=tflite_model,
|
||||||
|
tokenizer=metadata_writer.BertTokenizer(vocab_filepath),
|
||||||
|
labels=metadata_writer.Labels().add(list(self._label_names)),
|
||||||
|
ids_name=self._model_spec.tflite_input_name["ids"],
|
||||||
|
mask_name=self._model_spec.tflite_input_name["mask"],
|
||||||
|
segment_name=self._model_spec.tflite_input_name["segment_ids"])
|
|
@ -0,0 +1,108 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Demo for making a text classifier model by MediaPipe Model Maker."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def define_flags():
|
||||||
|
flags.DEFINE_string('export_dir', None,
|
||||||
|
'The directory to save exported files.')
|
||||||
|
flags.DEFINE_enum('supported_model', 'average_word_embedding',
|
||||||
|
['average_word_embedding', 'bert'],
|
||||||
|
'The text classifier to run.')
|
||||||
|
flags.mark_flag_as_required('export_dir')
|
||||||
|
|
||||||
|
|
||||||
|
def download_demo_data():
|
||||||
|
"""Downloads demo data, and returns directory path."""
|
||||||
|
data_path = tf.keras.utils.get_file(
|
||||||
|
fname='SST-2.zip',
|
||||||
|
origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
|
||||||
|
extract=True)
|
||||||
|
return os.path.join(os.path.dirname(data_path), 'SST-2') # folder name
|
||||||
|
|
||||||
|
|
||||||
|
def run(data_dir,
|
||||||
|
export_dir=tempfile.mkdtemp(),
|
||||||
|
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
|
"""Runs demo."""
|
||||||
|
|
||||||
|
# Gets training data and validation data.
|
||||||
|
csv_params = text_ds.CSVParameters(
|
||||||
|
text_column='sentence', label_column='label', delimiter='\t')
|
||||||
|
train_data = text_ds.Dataset.from_csv(
|
||||||
|
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
|
||||||
|
csv_params=csv_params)
|
||||||
|
validation_data = text_ds.Dataset.from_csv(
|
||||||
|
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
|
||||||
|
csv_params=csv_params)
|
||||||
|
|
||||||
|
quantization_config = None
|
||||||
|
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
|
||||||
|
hparams = hp.BaseHParams(
|
||||||
|
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
||||||
|
# Warning: This takes extremely long to run on CPU
|
||||||
|
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||||
|
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||||
|
hparams = hp.BaseHParams(
|
||||||
|
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
||||||
|
|
||||||
|
# Fine-tunes the model.
|
||||||
|
options = text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=supported_model, hparams=hparams)
|
||||||
|
model = text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
options)
|
||||||
|
|
||||||
|
# Gets evaluation results.
|
||||||
|
_, acc = model.evaluate(validation_data)
|
||||||
|
print('Eval accuracy: %f' % acc)
|
||||||
|
|
||||||
|
model.export_model(quantization_config=quantization_config)
|
||||||
|
model.export_labels(export_dir=options.hparams.export_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
data_dir = download_demo_data()
|
||||||
|
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||||
|
|
||||||
|
if FLAGS.supported_model == 'average_word_embedding':
|
||||||
|
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||||
|
elif FLAGS.supported_model == 'bert':
|
||||||
|
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
|
||||||
|
run(data_dir, export_dir, supported_model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
define_flags()
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""User-facing customization options to create and train a text classifier."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TextClassifierOptions:
|
||||||
|
"""User-facing options for creating the text classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
supported_model: A preconfigured model spec.
|
||||||
|
hparams: Training hyperparameters the user can set to override the ones in
|
||||||
|
`supported_model`.
|
||||||
|
model_options: Model options the user can set to override the ones in
|
||||||
|
`supported_model`. The model options type should be consistent with the
|
||||||
|
architecture of the `supported_model`.
|
||||||
|
"""
|
||||||
|
supported_model: ms.SupportedModels
|
||||||
|
hparams: Optional[hp.BaseHParams] = None
|
||||||
|
model_options: Optional[mo.TextClassifierModelOptions] = None
|
|
@ -0,0 +1,138 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import filecmp
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||||
|
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
|
||||||
|
|
||||||
|
def _get_data(self):
|
||||||
|
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
||||||
|
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||||
|
if os.path.exists(csv_file):
|
||||||
|
return csv_file
|
||||||
|
fieldnames = ['text', 'label']
|
||||||
|
with open(csv_file, 'w') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for label, text in labels_and_text:
|
||||||
|
writer.writerow({'text': text, 'label': label})
|
||||||
|
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||||
|
all_data = dataset.Dataset.from_csv(
|
||||||
|
filename=csv_file, csv_params=csv_params)
|
||||||
|
return all_data.split(0.5)
|
||||||
|
|
||||||
|
def test_create_and_train_average_word_embedding_model(self):
|
||||||
|
train_data, validation_data = self._get_data()
|
||||||
|
options = text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
|
||||||
|
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
|
||||||
|
average_word_embedding_classifier = text_classifier.TextClassifier.create(
|
||||||
|
train_data, validation_data, options)
|
||||||
|
|
||||||
|
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
||||||
|
self.assertGreaterEqual(accuracy, 0.0)
|
||||||
|
|
||||||
|
# Test export_model
|
||||||
|
average_word_embedding_classifier.export_model()
|
||||||
|
output_metadata_file = os.path.join(options.hparams.export_dir,
|
||||||
|
'metadata.json')
|
||||||
|
output_tflite_file = os.path.join(options.hparams.export_dir,
|
||||||
|
'model.tflite')
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_tflite_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||||
|
|
||||||
|
self.assertTrue(os.path.exists(output_metadata_file))
|
||||||
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
self.assertTrue(
|
||||||
|
filecmp.cmp(output_metadata_file,
|
||||||
|
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
|
||||||
|
|
||||||
|
def test_create_and_train_bert(self):
|
||||||
|
train_data, validation_data = self._get_data()
|
||||||
|
options = text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
|
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
|
||||||
|
hparams=hp.BaseHParams(
|
||||||
|
epochs=1,
|
||||||
|
batch_size=1,
|
||||||
|
learning_rate=3e-5,
|
||||||
|
distribution_strategy='off'))
|
||||||
|
bert_classifier = text_classifier.TextClassifier.create(
|
||||||
|
train_data, validation_data, options)
|
||||||
|
|
||||||
|
_, accuracy = bert_classifier.evaluate(validation_data)
|
||||||
|
self.assertGreaterEqual(accuracy, 0.0)
|
||||||
|
# TODO: Add a unit test that does not run OOM.
|
||||||
|
|
||||||
|
def test_label_mismatch(self):
|
||||||
|
options = (
|
||||||
|
text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
|
||||||
|
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||||
|
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
|
||||||
|
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||||
|
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Training data label names .* not equal to validation data label names'
|
||||||
|
):
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
options)
|
||||||
|
|
||||||
|
def test_options_mismatch(self):
|
||||||
|
train_data, validation_data = self._get_data()
|
||||||
|
|
||||||
|
avg_options = (
|
||||||
|
text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
|
model_options=mo.AverageWordEmbeddingClassifierOptions()))
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||||
|
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
avg_options)
|
||||||
|
|
||||||
|
bert_options = (
|
||||||
|
text_classifier_options.TextClassifierOptions(
|
||||||
|
supported_model=(
|
||||||
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||||
|
model_options=mo.BertClassifierOptions()))
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||||
|
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
bert_options)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load compressed models from tensorflow_hub
|
||||||
|
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||||
|
tf.test.main()
|
165
mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//mediapipe:__subpackages__"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Remove the unncessary test data once the demo data are moved to an open-sourced
|
||||||
|
# directory.
|
||||||
|
filegroup(
|
||||||
|
name = "test_data",
|
||||||
|
srcs = glob([
|
||||||
|
"test_data/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "constants",
|
||||||
|
srcs = ["constants.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Change to py_library after migrating the MediaPipe hand solution
|
||||||
|
# library to MediaPipe hand task library.
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = [
|
||||||
|
":constants",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/data:data_util",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/python/solutions:hands",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
data = [
|
||||||
|
":test_data",
|
||||||
|
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/python/solutions:hands",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "hyperparameters",
|
||||||
|
srcs = ["hyperparameters.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "model_options",
|
||||||
|
srcs = ["model_options.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "gesture_recognizer_options",
|
||||||
|
srcs = ["gesture_recognizer_options.py"],
|
||||||
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
|
":model_options",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "gesture_recognizer",
|
||||||
|
srcs = ["gesture_recognizer.py"],
|
||||||
|
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
|
||||||
|
deps = [
|
||||||
|
":constants",
|
||||||
|
":gesture_recognizer_options",
|
||||||
|
":hyperparameters",
|
||||||
|
":metadata_writer",
|
||||||
|
":model_options",
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "gesture_recognizer_import",
|
||||||
|
srcs = ["__init__.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":gesture_recognizer",
|
||||||
|
":gesture_recognizer_options",
|
||||||
|
":hyperparameters",
|
||||||
|
":model_options",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "metadata_writer",
|
||||||
|
srcs = ["metadata_writer.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "gesture_recognizer_test",
|
||||||
|
size = "large",
|
||||||
|
srcs = ["gesture_recognizer_test.py"],
|
||||||
|
data = [
|
||||||
|
":test_data",
|
||||||
|
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||||
|
],
|
||||||
|
shard_count = 2,
|
||||||
|
deps = [
|
||||||
|
":gesture_recognizer_import",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "metadata_writer_test",
|
||||||
|
srcs = ["metadata_writer_test.py"],
|
||||||
|
data = [
|
||||||
|
":test_data",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":metadata_writer",
|
||||||
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "gesture_recognizer_demo",
|
||||||
|
srcs = ["gesture_recognizer_demo.py"],
|
||||||
|
data = [
|
||||||
|
":test_data",
|
||||||
|
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||||
|
],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [":gesture_recognizer_import"],
|
||||||
|
)
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe Model Maker Python Public API For Gesture Recognizer."""
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
|
||||||
|
|
||||||
|
GestureRecognizer = gesture_recognizer.GestureRecognizer
|
||||||
|
ModelOptions = model_options.GestureRecognizerModelOptions
|
||||||
|
HParams = hyperparameters.HParams
|
||||||
|
Dataset = dataset.Dataset
|
||||||
|
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
|
||||||
|
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Gesture recognition constants."""
|
||||||
|
|
||||||
|
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
|
||||||
|
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
|
||||||
|
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
|
||||||
|
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
|
||||||
|
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'
|
|
@ -0,0 +1,238 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Gesture recognition dataset library."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
from mediapipe.model_maker.python.core.data import data_util
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||||
|
from mediapipe.python.solutions import hands as mp_hands
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class HandDataPreprocessingParams:
|
||||||
|
"""A dataclass wraps the hand data preprocessing hyperparameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
shuffle: A boolean controlling if shuffle the dataset. Default to true.
|
||||||
|
min_detection_confidence: confidence threshold for hand detection.
|
||||||
|
"""
|
||||||
|
shuffle: bool = True
|
||||||
|
min_detection_confidence: float = 0.7
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class HandData:
|
||||||
|
"""A dataclass represents hand data for training gesture recognizer model.
|
||||||
|
|
||||||
|
See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
|
||||||
|
more details of the hand gesture data API.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hand: normalized hand landmarks of shape 21x3 from the screen based
|
||||||
|
hand-landmark model.
|
||||||
|
world_hand: hand landmarks of shape 21x3 in world coordinates.
|
||||||
|
handedness: Collection of handedness confidence of the detected hands (i.e.
|
||||||
|
is it a left or right hand).
|
||||||
|
"""
|
||||||
|
hand: List[List[float]]
|
||||||
|
world_hand: List[List[float]]
|
||||||
|
handedness: List[float]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_data_sample(data: NamedTuple) -> bool:
|
||||||
|
"""Validates the input hand data sample.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: input hand data sample.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
False if the input data namedtuple does not contain the fields including
|
||||||
|
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
|
||||||
|
or any of these attributes' values are none. Otherwise, True.
|
||||||
|
"""
|
||||||
|
if (not hasattr(data, 'multi_hand_landmarks') or
|
||||||
|
data.multi_hand_landmarks is None):
|
||||||
|
return False
|
||||||
|
if (not hasattr(data, 'multi_hand_world_landmarks') or
|
||||||
|
data.multi_hand_world_landmarks is None):
|
||||||
|
return False
|
||||||
|
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_hand_data(all_image_paths: List[str],
|
||||||
|
min_detection_confidence: float) -> Optional[HandData]:
|
||||||
|
"""Computes hand data (landmarks and handedness) in the input image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
all_image_paths: all input image paths.
|
||||||
|
min_detection_confidence: hand detection confidence threshold
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A HandData object. Returns None if no hand is detected.
|
||||||
|
"""
|
||||||
|
hand_data_result = []
|
||||||
|
with mp_hands.Hands(
|
||||||
|
static_image_mode=True,
|
||||||
|
max_num_hands=1,
|
||||||
|
min_detection_confidence=min_detection_confidence) as hands:
|
||||||
|
for path in all_image_paths:
|
||||||
|
tf.compat.v1.logging.info('Loading image %s', path)
|
||||||
|
image = data_util.load_image(path)
|
||||||
|
# Flip image around y-axis for correct handedness output
|
||||||
|
image = cv2.flip(image, 1)
|
||||||
|
data = hands.process(image)
|
||||||
|
if not _validate_data_sample(data):
|
||||||
|
hand_data_result.append(None)
|
||||||
|
continue
|
||||||
|
hand_landmarks = [[
|
||||||
|
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||||
|
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
|
||||||
|
hand_world_landmarks = [[
|
||||||
|
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||||
|
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
|
||||||
|
handedness_scores = [
|
||||||
|
handedness.score
|
||||||
|
for handedness in data.multi_handedness[0].classification
|
||||||
|
]
|
||||||
|
hand_data_result.append(
|
||||||
|
HandData(
|
||||||
|
hand=hand_landmarks,
|
||||||
|
world_hand=hand_world_landmarks,
|
||||||
|
handedness=handedness_scores))
|
||||||
|
return hand_data_result
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for hand gesture recognizer."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_folder(
|
||||||
|
cls,
|
||||||
|
dirname: str,
|
||||||
|
hparams: Optional[HandDataPreprocessingParams] = None
|
||||||
|
) -> classification_dataset.ClassificationDataset:
|
||||||
|
"""Loads images and labels from the given directory.
|
||||||
|
|
||||||
|
Directory contents are expected to be in the format:
|
||||||
|
<root_dir>/<gesture_name>/*.jpg". One of the `gesture_name` must be `none`
|
||||||
|
(case insensitive). The `none` sub-directory is expected to contain images
|
||||||
|
of hands that don't belong to other gesture classes in <root_dir>. Assumes
|
||||||
|
the image data of the same label are in the same subdirectory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dirname: Name of the directory containing the data files.
|
||||||
|
hparams: Optional hyperparameters for processing input hand gesture
|
||||||
|
images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing landmarks, labels, and other related info.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input data directory is empty or the label set does not
|
||||||
|
contain label 'none' (case insensitive).
|
||||||
|
"""
|
||||||
|
data_root = os.path.abspath(dirname)
|
||||||
|
|
||||||
|
# Assumes the image data of the same label are in the same subdirectory,
|
||||||
|
# gets image path and label names.
|
||||||
|
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||||
|
if not all_image_paths:
|
||||||
|
raise ValueError('Image dataset directory is empty.')
|
||||||
|
|
||||||
|
if not hparams:
|
||||||
|
hparams = HandDataPreprocessingParams()
|
||||||
|
|
||||||
|
if hparams.shuffle:
|
||||||
|
# Random shuffle data.
|
||||||
|
random.shuffle(all_image_paths)
|
||||||
|
|
||||||
|
label_names = sorted(
|
||||||
|
name for name in os.listdir(data_root)
|
||||||
|
if os.path.isdir(os.path.join(data_root, name)))
|
||||||
|
if 'none' not in [v.lower() for v in label_names]:
|
||||||
|
raise ValueError('Label set does not contain label "None".')
|
||||||
|
# Move label 'none' to the front of label list.
|
||||||
|
none_idx = [v.lower() for v in label_names].index('none')
|
||||||
|
none_value = label_names.pop(none_idx)
|
||||||
|
label_names.insert(0, none_value)
|
||||||
|
|
||||||
|
index_by_label = dict(
|
||||||
|
(name, index) for index, name in enumerate(label_names))
|
||||||
|
all_gesture_indices = [
|
||||||
|
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||||
|
for path in all_image_paths
|
||||||
|
]
|
||||||
|
|
||||||
|
# Compute hand data (including local hand landmark, world hand landmark, and
|
||||||
|
# handedness) for all the input images.
|
||||||
|
hand_data = _get_hand_data(
|
||||||
|
all_image_paths=all_image_paths,
|
||||||
|
min_detection_confidence=hparams.min_detection_confidence)
|
||||||
|
|
||||||
|
# Get a list of the valid hand landmark sample in the hand data list.
|
||||||
|
valid_indices = [
|
||||||
|
i for i in range(len(hand_data)) if hand_data[i] is not None
|
||||||
|
]
|
||||||
|
# Remove 'None' element from the hand data and label list.
|
||||||
|
valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices]
|
||||||
|
if not valid_hand_data:
|
||||||
|
raise ValueError('No valid hand is detected.')
|
||||||
|
|
||||||
|
valid_label = [all_gesture_indices[i] for i in valid_indices]
|
||||||
|
|
||||||
|
# Convert list of dictionaries to dictionary of lists.
|
||||||
|
hand_data_dict = {
|
||||||
|
k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0]
|
||||||
|
}
|
||||||
|
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
|
||||||
|
|
||||||
|
embedder_model = model_util.load_keras_model(
|
||||||
|
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
|
||||||
|
|
||||||
|
hand_ds = hand_ds.batch(batch_size=1)
|
||||||
|
hand_embedding_ds = hand_ds.map(
|
||||||
|
map_func=lambda feature: embedder_model(dict(feature)),
|
||||||
|
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||||
|
hand_embedding_ds = hand_embedding_ds.unbatch()
|
||||||
|
|
||||||
|
# Create label dataset
|
||||||
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(valid_label, tf.int64))
|
||||||
|
|
||||||
|
label_one_hot_ds = label_ds.map(
|
||||||
|
map_func=lambda index: tf.one_hot(index, len(label_names)),
|
||||||
|
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||||
|
|
||||||
|
# Create a dataset with (hand_embedding, one_hot_label) pairs
|
||||||
|
hand_embedding_label_ds = tf.data.Dataset.zip(
|
||||||
|
(hand_embedding_ds, label_one_hot_ds))
|
||||||
|
|
||||||
|
tf.compat.v1.logging.info(
|
||||||
|
'Load valid hands with size: {}, num_label: {}, labels: {}.'.format(
|
||||||
|
len(valid_hand_data), len(label_names), ','.join(label_names)))
|
||||||
|
return Dataset(
|
||||||
|
dataset=hand_embedding_label_ds,
|
||||||
|
size=len(valid_hand_data),
|
||||||
|
label_names=label_names)
|
|
@ -0,0 +1,161 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.s
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from typing import NamedTuple
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from absl import flags
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||||
|
from mediapipe.python.solutions import hands as mp_hands
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
_TEST_DATA_DIRNAME = 'raw_data'
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_split(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
|
data = dataset.Dataset.from_folder(
|
||||||
|
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||||
|
train_data, test_data = data.split(0.5)
|
||||||
|
|
||||||
|
self.assertLen(train_data, 17)
|
||||||
|
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
|
||||||
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
|
self.assertEqual(train_data.num_classes, 4)
|
||||||
|
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||||
|
|
||||||
|
self.assertLen(test_data, 18)
|
||||||
|
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
|
||||||
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
|
self.assertEqual(test_data.num_classes, 4)
|
||||||
|
self.assertEqual(test_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||||
|
|
||||||
|
def test_from_folder(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
|
data = dataset.Dataset.from_folder(
|
||||||
|
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||||
|
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||||
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
|
self.assertLen(data, 35)
|
||||||
|
self.assertEqual(data.num_classes, 4)
|
||||||
|
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
|
||||||
|
|
||||||
|
def test_create_dataset_from_empty_folder_raise_value_error(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, 'Image dataset directory is empty'):
|
||||||
|
dataset.Dataset.from_folder(
|
||||||
|
dirname=self.get_temp_dir(),
|
||||||
|
hparams=dataset.HandDataPreprocessingParams())
|
||||||
|
|
||||||
|
def test_create_dataset_from_folder_without_none_raise_value_error(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
|
tmp_dir = self.create_tempdir()
|
||||||
|
# Copy input dataset to a temporary directory and skip 'None' directory.
|
||||||
|
for name in os.listdir(input_data_dir):
|
||||||
|
if name == 'none':
|
||||||
|
continue
|
||||||
|
src_dir = os.path.join(input_data_dir, name)
|
||||||
|
dst_dir = os.path.join(tmp_dir, name)
|
||||||
|
shutil.copytree(src_dir, dst_dir)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
'Label set does not contain label "None"'):
|
||||||
|
dataset.Dataset.from_folder(
|
||||||
|
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||||
|
|
||||||
|
def test_create_dataset_from_folder_with_capital_letter_in_folder_name(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
|
tmp_dir = self.create_tempdir()
|
||||||
|
# Copy input dataset to a temporary directory and change the base folder
|
||||||
|
# name to upper case letter, e.g. 'none' -> 'NONE'
|
||||||
|
for name in os.listdir(input_data_dir):
|
||||||
|
src_dir = os.path.join(input_data_dir, name)
|
||||||
|
dst_dir = os.path.join(tmp_dir, name.upper())
|
||||||
|
shutil.copytree(src_dir, dst_dir)
|
||||||
|
|
||||||
|
upper_base_folder_name = list(os.listdir(tmp_dir))
|
||||||
|
self.assertCountEqual(upper_base_folder_name,
|
||||||
|
['CALL', 'FOUR', 'NONE', 'ROCK'])
|
||||||
|
|
||||||
|
data = dataset.Dataset.from_folder(
|
||||||
|
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||||
|
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||||
|
self.assertEqual(elem[0].shape, (1, 128))
|
||||||
|
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||||
|
self.assertLen(data, 35)
|
||||||
|
self.assertEqual(data.num_classes, 4)
|
||||||
|
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(
|
||||||
|
testcase_name='invalid_field_name_multi_hand_landmark',
|
||||||
|
hand=collections.namedtuple('Hand', [
|
||||||
|
'multi_hand_landmark', 'multi_hand_world_landmarks',
|
||||||
|
'multi_handedness'
|
||||||
|
])(1, 2, 3)),
|
||||||
|
dict(
|
||||||
|
testcase_name='invalid_field_name_multi_hand_world_landmarks',
|
||||||
|
hand=collections.namedtuple('Hand', [
|
||||||
|
'multi_hand_landmarks', 'multi_hand_world_landmark',
|
||||||
|
'multi_handedness'
|
||||||
|
])(1, 2, 3)),
|
||||||
|
dict(
|
||||||
|
testcase_name='invalid_field_name_multi_handed',
|
||||||
|
hand=collections.namedtuple('Hand', [
|
||||||
|
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||||
|
'multi_handed'
|
||||||
|
])(1, 2, 3)),
|
||||||
|
dict(
|
||||||
|
testcase_name='multi_hand_landmarks_is_none',
|
||||||
|
hand=collections.namedtuple('Hand', [
|
||||||
|
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||||
|
'multi_handedness'
|
||||||
|
])(None, 2, 3)),
|
||||||
|
dict(
|
||||||
|
testcase_name='multi_hand_world_landmarks_is_none',
|
||||||
|
hand=collections.namedtuple('Hand', [
|
||||||
|
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||||
|
'multi_handedness'
|
||||||
|
])(1, None, 3)),
|
||||||
|
dict(
|
||||||
|
testcase_name='multi_handedness_is_none',
|
||||||
|
hand=collections.namedtuple('Hand', [
|
||||||
|
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||||
|
'multi_handedness'
|
||||||
|
])(1, 2, None)),
|
||||||
|
)
|
||||||
|
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||||
|
with unittest.mock.patch.object(
|
||||||
|
mp_hands.Hands, 'process', return_value=hand):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||||
|
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||||
|
dataset.Dataset.from_folder(
|
||||||
|
dirname=input_data_dir,
|
||||||
|
hparams=dataset.HandDataPreprocessingParams())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,239 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""APIs to train gesture recognizer model."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||||
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||||
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters as hp
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
|
||||||
|
|
||||||
|
_EMBEDDING_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
|
class GestureRecognizer(classifier.Classifier):
|
||||||
|
"""GestureRecognizer for building hand gesture recognizer model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
embedding_size: Size of the input gesture embedding vector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, label_names: List[str],
|
||||||
|
model_options: model_opt.GestureRecognizerModelOptions,
|
||||||
|
hparams: hp.HParams):
|
||||||
|
"""Initializes GestureRecognizer class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label_names: A list of label names for the classes.
|
||||||
|
model_options: options to create gesture recognizer model.
|
||||||
|
hparams: The hyperparameters for training hand gesture recognizer model.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
|
self._model_options = model_options
|
||||||
|
self._hparams = hparams
|
||||||
|
self._history = None
|
||||||
|
self.embedding_size = _EMBEDDING_SIZE
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
train_data: classification_ds.ClassificationDataset,
|
||||||
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
|
options: gesture_recognizer_options.GestureRecognizerOptions,
|
||||||
|
) -> 'GestureRecognizer':
|
||||||
|
"""Creates and trains a hand gesture recognizer with input datasets.
|
||||||
|
|
||||||
|
If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
|
||||||
|
directory, the training process will load the weight from the checkpoint
|
||||||
|
file for continual training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data. If None, skips validation process.
|
||||||
|
options: options for creating and training gesture recognizer model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of GestureRecognizer.
|
||||||
|
"""
|
||||||
|
if options.model_options is None:
|
||||||
|
options.model_options = model_opt.GestureRecognizerModelOptions()
|
||||||
|
|
||||||
|
if options.hparams is None:
|
||||||
|
options.hparams = hp.HParams()
|
||||||
|
|
||||||
|
gesture_recognizer = cls(
|
||||||
|
label_names=train_data.label_names,
|
||||||
|
model_options=options.model_options,
|
||||||
|
hparams=options.hparams)
|
||||||
|
|
||||||
|
gesture_recognizer._create_model()
|
||||||
|
|
||||||
|
train_dataset = train_data.gen_tf_dataset(
|
||||||
|
batch_size=options.hparams.batch_size,
|
||||||
|
is_training=True,
|
||||||
|
shuffle=options.hparams.shuffle)
|
||||||
|
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||||
|
steps_per_epoch=options.hparams.steps_per_epoch,
|
||||||
|
batch_size=options.hparams.batch_size,
|
||||||
|
train_data=train_data)
|
||||||
|
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
|
||||||
|
|
||||||
|
validation_dataset = validation_data.gen_tf_dataset(
|
||||||
|
batch_size=options.hparams.batch_size, is_training=False)
|
||||||
|
|
||||||
|
tf.compat.v1.logging.info('Training the gesture recognizer model...')
|
||||||
|
gesture_recognizer._train(
|
||||||
|
train_data=train_dataset, validation_data=validation_dataset)
|
||||||
|
|
||||||
|
return gesture_recognizer
|
||||||
|
|
||||||
|
def _train(self, train_data: tf.data.Dataset,
|
||||||
|
validation_data: tf.data.Dataset):
|
||||||
|
"""Trains the model with input train_data.
|
||||||
|
|
||||||
|
The training results are recorded by a self.History object returned by
|
||||||
|
tf.keras.Model.fit().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: Training data.
|
||||||
|
validation_data: Validation data.
|
||||||
|
"""
|
||||||
|
hparams = self._hparams
|
||||||
|
|
||||||
|
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
|
||||||
|
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
|
||||||
|
|
||||||
|
job_dir = hparams.export_dir
|
||||||
|
checkpoint_path = os.path.join(job_dir, 'epoch_models')
|
||||||
|
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||||
|
save_weights_only=True)
|
||||||
|
|
||||||
|
best_model_path = os.path.join(job_dir, 'best_model_weights')
|
||||||
|
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
best_model_path,
|
||||||
|
monitor='val_loss',
|
||||||
|
mode='min',
|
||||||
|
save_best_only=True,
|
||||||
|
save_weights_only=True)
|
||||||
|
|
||||||
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||||||
|
log_dir=os.path.join(job_dir, 'logs'))
|
||||||
|
|
||||||
|
self._model.compile(
|
||||||
|
optimizer='adam',
|
||||||
|
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
|
||||||
|
metrics=['categorical_accuracy'])
|
||||||
|
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
if latest_checkpoint:
|
||||||
|
print(f'Resuming from {latest_checkpoint}')
|
||||||
|
self._model.load_weights(latest_checkpoint)
|
||||||
|
|
||||||
|
self._history = self._model.fit(
|
||||||
|
x=train_data,
|
||||||
|
epochs=hparams.epochs,
|
||||||
|
validation_data=validation_data,
|
||||||
|
validation_freq=1,
|
||||||
|
callbacks=[
|
||||||
|
checkpoint_callback, best_model_callback, scheduler_callback,
|
||||||
|
tensorboard_callback
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_model(self):
|
||||||
|
"""Creates the hand gesture recognizer model.
|
||||||
|
|
||||||
|
The gesture embedding model is pretrained and loaded from a tf.saved_model.
|
||||||
|
"""
|
||||||
|
inputs = tf.keras.Input(
|
||||||
|
shape=[self.embedding_size],
|
||||||
|
batch_size=None,
|
||||||
|
dtype=tf.float32,
|
||||||
|
name='hand_embedding')
|
||||||
|
|
||||||
|
x = tf.keras.layers.BatchNormalization()(inputs)
|
||||||
|
x = tf.keras.layers.ReLU()(x)
|
||||||
|
dropout_rate = self._model_options.dropout_rate
|
||||||
|
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
|
||||||
|
outputs = tf.keras.layers.Dense(
|
||||||
|
self._num_classes,
|
||||||
|
activation='softmax',
|
||||||
|
name='custom_gesture_recognizer')(
|
||||||
|
x)
|
||||||
|
|
||||||
|
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
print(self._model.summary())
|
||||||
|
|
||||||
|
def export_model(self, model_name: str = 'gesture_recognizer.task'):
|
||||||
|
"""Converts the model to TFLite and exports as a model bundle file.
|
||||||
|
|
||||||
|
Saves a model bundle file and metadata json file to hparams.export_dir. The
|
||||||
|
resulting model bundle file will contain necessary models for hand
|
||||||
|
detection, canned gesture classification, and customized gesture
|
||||||
|
classification. Only the model bundle file is needed for the downstream
|
||||||
|
gesture recognition task. The metadata.json file is saved only to
|
||||||
|
interpret the contents of the model bundle file.
|
||||||
|
|
||||||
|
The customized gesture model is in float without quantization. The model is
|
||||||
|
lightweight and there is no need to balance performance and efficiency by
|
||||||
|
quantization. The default score_thresholding is set to 0.5 as it can be
|
||||||
|
adjusted during inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: File name to save model bundle file. The full export path is
|
||||||
|
{export_dir}/{model_name}.
|
||||||
|
"""
|
||||||
|
# TODO: Convert keras embedder model instead of using tflite
|
||||||
|
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
|
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
|
||||||
|
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
|
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||||
|
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
|
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||||
|
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
|
||||||
|
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
|
||||||
|
|
||||||
|
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||||
|
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||||
|
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||||
|
|
||||||
|
gesture_classifier_options = metadata_writer.GestureClassifierOptions(
|
||||||
|
model_buffer=model_util.convert_to_tflite(self._model),
|
||||||
|
labels=base_metadata_writer.Labels().add(list(self._label_names)),
|
||||||
|
score_thresholding=base_metadata_writer.ScoreThresholding(
|
||||||
|
global_score_threshold=0.5))
|
||||||
|
|
||||||
|
writer = metadata_writer.MetadataWriter.create(
|
||||||
|
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
|
||||||
|
gesture_embedding_model_buffer, canned_gesture_model_buffer,
|
||||||
|
gesture_classifier_options)
|
||||||
|
model_bundle_content, metadata_json = writer.populate()
|
||||||
|
with open(model_bundle_file, 'wb') as f:
|
||||||
|
f.write(model_bundle_content)
|
||||||
|
with open(metadata_file, 'w') as f:
|
||||||
|
f.write(metadata_json)
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Demo for making an gesture recognizer model by Mediapipe Model Maker."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Dependency imports
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision import gesture_recognizer
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
# TODO: Move hand gesture recognizer demo dataset to an
|
||||||
|
# open-sourced directory.
|
||||||
|
TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data'
|
||||||
|
|
||||||
|
|
||||||
|
def define_flags():
|
||||||
|
flags.DEFINE_string('export_dir', None,
|
||||||
|
'The directory to save exported files.')
|
||||||
|
flags.DEFINE_string('input_data_dir', None,
|
||||||
|
'The directory with input training data.')
|
||||||
|
flags.mark_flag_as_required('export_dir')
|
||||||
|
|
||||||
|
|
||||||
|
def run(data_dir: str, export_dir: str):
|
||||||
|
"""Runs demo."""
|
||||||
|
data = gesture_recognizer.Dataset.from_folder(dirname=data_dir)
|
||||||
|
train_data, rest_data = data.split(0.8)
|
||||||
|
validation_data, test_data = rest_data.split(0.5)
|
||||||
|
|
||||||
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=train_data,
|
||||||
|
validation_data=validation_data,
|
||||||
|
options=gesture_recognizer.GestureRecognizerOptions(
|
||||||
|
hparams=gesture_recognizer.HParams(export_dir=export_dir)))
|
||||||
|
|
||||||
|
metric = model.evaluate(test_data, batch_size=2)
|
||||||
|
print('Evaluation metric')
|
||||||
|
print(metric)
|
||||||
|
|
||||||
|
model.export_model()
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
logging.set_verbosity(logging.INFO)
|
||||||
|
|
||||||
|
if FLAGS.input_data_dir is None:
|
||||||
|
data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
|
||||||
|
else:
|
||||||
|
data_dir = FLAGS.input_data_dir
|
||||||
|
|
||||||
|
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||||
|
run(data_dir=data_dir, export_dir=export_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
define_flags()
|
||||||
|
app.run(main)
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Options for building gesture recognizer."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GestureRecognizerOptions:
|
||||||
|
"""Configurable options for building gesture recognizer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model_options: A set of options for configuring the selected model.
|
||||||
|
hparams: A set of hyperparameters used to train the gesture recognizer.
|
||||||
|
"""
|
||||||
|
model_options: Optional[model_opt.GestureRecognizerModelOptions] = None
|
||||||
|
hparams: Optional[hyperparameters.HParams] = None
|
|
@ -0,0 +1,132 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from unittest import mock as unittest_mock
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import mock
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import test_util
|
||||||
|
from mediapipe.model_maker.python.vision import gesture_recognizer
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data'
|
||||||
|
|
||||||
|
|
||||||
|
class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def _load_data(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, 'raw_data'))
|
||||||
|
|
||||||
|
data = gesture_recognizer.Dataset.from_folder(
|
||||||
|
dirname=input_data_dir,
|
||||||
|
hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
|
||||||
|
return data
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self._model_options = gesture_recognizer.ModelOptions()
|
||||||
|
self._hparams = gesture_recognizer.HParams(epochs=2)
|
||||||
|
self._gesture_recognizer_options = (
|
||||||
|
gesture_recognizer.GestureRecognizerOptions(
|
||||||
|
model_options=self._model_options, hparams=self._hparams))
|
||||||
|
all_data = self._load_data()
|
||||||
|
# Splits data, 90% data for training, 10% for testing
|
||||||
|
self._train_data, self._test_data = all_data.split(0.9)
|
||||||
|
|
||||||
|
def test_gesture_recognizer_model(self):
|
||||||
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=self._gesture_recognizer_options)
|
||||||
|
|
||||||
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
def test_export_gesture_recognizer_model(self):
|
||||||
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=self._gesture_recognizer_options)
|
||||||
|
model.export_model()
|
||||||
|
model_bundle_file = os.path.join(self._hparams.export_dir,
|
||||||
|
'gesture_recognizer.task')
|
||||||
|
with zipfile.ZipFile(model_bundle_file) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set(['hand_landmarker.task', 'hand_gesture_recognizer.task']))
|
||||||
|
zf.extractall(self.get_temp_dir())
|
||||||
|
hand_gesture_recognizer_bundle_file = os.path.join(
|
||||||
|
self.get_temp_dir(), 'hand_gesture_recognizer.task')
|
||||||
|
with zipfile.ZipFile(hand_gesture_recognizer_bundle_file) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set([
|
||||||
|
'canned_gesture_classifier.tflite',
|
||||||
|
'custom_gesture_classifier.tflite', 'gesture_embedder.tflite'
|
||||||
|
]))
|
||||||
|
zf.extractall(self.get_temp_dir())
|
||||||
|
gesture_classifier_tflite_file = os.path.join(
|
||||||
|
self.get_temp_dir(), 'custom_gesture_classifier.tflite')
|
||||||
|
test_util.test_tflite_file(
|
||||||
|
keras_model=model._model,
|
||||||
|
tflite_file=gesture_classifier_tflite_file,
|
||||||
|
size=[1, model.embedding_size])
|
||||||
|
|
||||||
|
def _test_accuracy(self, model, threshold=0.5):
|
||||||
|
_, accuracy = model.evaluate(self._test_data)
|
||||||
|
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
|
||||||
|
self.assertGreaterEqual(accuracy, threshold)
|
||||||
|
|
||||||
|
@unittest_mock.patch.object(
|
||||||
|
gesture_recognizer.hyperparameters,
|
||||||
|
'HParams',
|
||||||
|
autospec=True,
|
||||||
|
return_value=gesture_recognizer.HParams(epochs=1))
|
||||||
|
@unittest_mock.patch.object(
|
||||||
|
gesture_recognizer.model_options,
|
||||||
|
'GestureRecognizerModelOptions',
|
||||||
|
autospec=True,
|
||||||
|
return_value=gesture_recognizer.ModelOptions())
|
||||||
|
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
||||||
|
self, mock_hparams, mock_model_options):
|
||||||
|
options = gesture_recognizer.GestureRecognizerOptions()
|
||||||
|
gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=options)
|
||||||
|
mock_hparams.assert_called_once()
|
||||||
|
mock_model_options.assert_called_once()
|
||||||
|
|
||||||
|
def test_continual_training_by_loading_checkpoint(self):
|
||||||
|
mock_stdout = io.StringIO()
|
||||||
|
with mock.patch('sys.stdout', mock_stdout):
|
||||||
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=self._gesture_recognizer_options)
|
||||||
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._test_data,
|
||||||
|
options=self._gesture_recognizer_options)
|
||||||
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Hyperparameters for training customized gesture recognizer models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class HParams(hp.BaseHParams):
|
||||||
|
"""The hyperparameters for training gesture recognizer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate: Learning rate to use for gradient descent training.
|
||||||
|
batch_size: Batch size for training.
|
||||||
|
epochs: Number of training iterations over the dataset.
|
||||||
|
lr_decay: Learning rate decay to use for gradient descent training.
|
||||||
|
gamma: Gamma parameter for focal loss.
|
||||||
|
"""
|
||||||
|
# Parameters from BaseHParams class.
|
||||||
|
learning_rate: float = 0.001
|
||||||
|
batch_size: int = 2
|
||||||
|
epochs: int = 10
|
||||||
|
|
||||||
|
# Parameters about training configuration
|
||||||
|
# TODO: Move lr_decay to hp.baseHParams.
|
||||||
|
lr_decay: float = 0.99
|
||||||
|
gamma: int = 2
|
|
@ -0,0 +1,193 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Writes metadata and creates model asset bundle for gesture recognizer."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||||
|
|
||||||
|
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
|
||||||
|
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
|
||||||
|
_HAND_LANDMARKER_BUNDLE_NAME = "hand_landmarker.task"
|
||||||
|
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME = "hand_gesture_recognizer.task"
|
||||||
|
_GESTURE_EMBEDDER_TFLITE_NAME = "gesture_embedder.tflite"
|
||||||
|
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME = "canned_gesture_classifier.tflite"
|
||||||
|
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME = "custom_gesture_classifier.tflite"
|
||||||
|
|
||||||
|
_MODEL_NAME = "HandGestureRecognition"
|
||||||
|
_MODEL_DESCRIPTION = "Recognize the hand gesture in the image."
|
||||||
|
|
||||||
|
_INPUT_NAME = "embedding"
|
||||||
|
_INPUT_DESCRIPTION = "Embedding feature vector from gesture embedder."
|
||||||
|
_OUTPUT_NAME = "scores"
|
||||||
|
_OUTPUT_DESCRIPTION = "Hand gesture category scores."
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GestureClassifierOptions:
|
||||||
|
"""Options to write metadata for gesture classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model_buffer: Gesture classifier TFLite model buffer.
|
||||||
|
labels: Labels for the gesture classifier.
|
||||||
|
score_thresholding: Parameters to performs thresholding on output tensor
|
||||||
|
values [1].
|
||||||
|
[1]:
|
||||||
|
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||||
|
"""
|
||||||
|
model_buffer: bytearray
|
||||||
|
labels: metadata_writer.Labels
|
||||||
|
score_thresholding: metadata_writer.ScoreThresholding
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
|
||||||
|
with tf.io.gfile.GFile(file_path, mode) as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataWriter:
|
||||||
|
"""MetadataWriter to write the metadata and the model asset bundle."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, hand_detector_model_buffer: bytearray,
|
||||||
|
hand_landmarks_detector_model_buffer: bytearray,
|
||||||
|
gesture_embedder_model_buffer: bytearray,
|
||||||
|
canned_gesture_classifier_model_buffer: bytearray,
|
||||||
|
custom_gesture_classifier_metadata_writer: metadata_writer.MetadataWriter
|
||||||
|
) -> None:
|
||||||
|
"""Initialize MetadataWriter to write the metadata and model asset bundle.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||||
|
the TFLite hand detector model file.
|
||||||
|
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||||
|
loaded from the TFLite hand landmarks detector model file.
|
||||||
|
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
|
||||||
|
from the TFLite gesture embedder model file.
|
||||||
|
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
|
||||||
|
loaded from the TFLite canned gesture classifier model file.
|
||||||
|
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
||||||
|
gesture classifier metadata into the TFLite file.
|
||||||
|
"""
|
||||||
|
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||||
|
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||||
|
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
||||||
|
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
||||||
|
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
||||||
|
self._temp_folder = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if os.path.exists(self._temp_folder.name):
|
||||||
|
self._temp_folder.cleanup()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
hand_detector_model_buffer: bytearray,
|
||||||
|
hand_landmarks_detector_model_buffer: bytearray,
|
||||||
|
gesture_embedder_model_buffer: bytearray,
|
||||||
|
canned_gesture_classifier_model_buffer: bytearray,
|
||||||
|
custom_gesture_classifier_options: GestureClassifierOptions,
|
||||||
|
) -> "MetadataWriter":
|
||||||
|
"""Creates MetadataWriter to write the metadata for gesture recognizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||||
|
the TFLite hand detector model file.
|
||||||
|
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||||
|
loaded from the TFLite hand landmarks detector model file.
|
||||||
|
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
|
||||||
|
from the TFLite gesture embedder model file.
|
||||||
|
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
|
||||||
|
loaded from the TFLite canned gesture classifier model file.
|
||||||
|
custom_gesture_classifier_options: Custom gesture classifier options to
|
||||||
|
write custom gesture classifier metadata into the TFLite file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An MetadataWrite object.
|
||||||
|
"""
|
||||||
|
writer = metadata_writer.MetadataWriter.create(
|
||||||
|
custom_gesture_classifier_options.model_buffer)
|
||||||
|
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||||
|
writer.add_feature_input(name=_INPUT_NAME, description=_INPUT_DESCRIPTION)
|
||||||
|
writer.add_classification_output(
|
||||||
|
labels=custom_gesture_classifier_options.labels,
|
||||||
|
score_thresholding=custom_gesture_classifier_options.score_thresholding,
|
||||||
|
name=_OUTPUT_NAME,
|
||||||
|
description=_OUTPUT_DESCRIPTION)
|
||||||
|
return cls(hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
|
||||||
|
gesture_embedder_model_buffer,
|
||||||
|
canned_gesture_classifier_model_buffer, writer)
|
||||||
|
|
||||||
|
def populate(self):
|
||||||
|
"""Populates the metadata and creates model asset bundle.
|
||||||
|
|
||||||
|
Note that only the output model asset bundle is used for deployment.
|
||||||
|
The output JSON content is used to interpret the custom gesture classifier
|
||||||
|
metadata content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
||||||
|
"""
|
||||||
|
# Creates the model asset bundle for hand landmarker task.
|
||||||
|
landmark_models = {
|
||||||
|
_HAND_DETECTOR_TFLITE_NAME:
|
||||||
|
self._hand_detector_model_buffer,
|
||||||
|
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
||||||
|
self._hand_landmarks_detector_model_buffer
|
||||||
|
}
|
||||||
|
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||||
|
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||||
|
writer_utils.create_model_asset_bundle(landmark_models,
|
||||||
|
output_hand_landmarker_path)
|
||||||
|
|
||||||
|
# Write metadata into custom gesture classifier model.
|
||||||
|
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
|
||||||
|
)
|
||||||
|
# Creates the model asset bundle for hand gesture recognizer sub graph.
|
||||||
|
hand_gesture_recognizer_models = {
|
||||||
|
_GESTURE_EMBEDDER_TFLITE_NAME:
|
||||||
|
self._gesture_embedder_model_buffer,
|
||||||
|
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME:
|
||||||
|
self._canned_gesture_classifier_model_buffer,
|
||||||
|
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME:
|
||||||
|
self._custom_gesture_classifier_model_buffer
|
||||||
|
}
|
||||||
|
output_hand_gesture_recognizer_path = os.path.join(
|
||||||
|
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
|
||||||
|
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
|
||||||
|
output_hand_gesture_recognizer_path)
|
||||||
|
|
||||||
|
# Creates the model asset bundle for end-to-end hand gesture recognizer
|
||||||
|
# graph.
|
||||||
|
gesture_recognizer_models = {
|
||||||
|
_HAND_LANDMARKER_BUNDLE_NAME:
|
||||||
|
read_file(output_hand_landmarker_path),
|
||||||
|
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
|
||||||
|
read_file(output_hand_gesture_recognizer_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
output_file_path = os.path.join(self._temp_folder.name,
|
||||||
|
"gesture_recognizer.task")
|
||||||
|
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
|
||||||
|
output_file_path)
|
||||||
|
with open(output_file_path, "rb") as f:
|
||||||
|
gesture_recognizer_model_buffer = f.read()
|
||||||
|
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for metadata_writer."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||||
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata"
|
||||||
|
|
||||||
|
_EXPECTED_JSON = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json"))
|
||||||
|
_CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier.tflite"))
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataWriterTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
||||||
|
# Use dummy model buffer for unit test only.
|
||||||
|
hand_detector_model_buffer = b"\x11\x12"
|
||||||
|
hand_landmarks_detector_model_buffer = b"\x22"
|
||||||
|
gesture_embedder_model_buffer = b"\x33"
|
||||||
|
canned_gesture_classifier_model_buffer = b"\x44"
|
||||||
|
custom_gesture_classifier_metadata_writer = metadata_writer.GestureClassifierOptions(
|
||||||
|
model_buffer=metadata_writer.read_file(_CUSTOM_GESTURE_CLASSIFIER_PATH),
|
||||||
|
labels=base_metadata_writer.Labels().add(
|
||||||
|
["None", "Paper", "Rock", "Scissors"]),
|
||||||
|
score_thresholding=base_metadata_writer.ScoreThresholding(
|
||||||
|
global_score_threshold=0.5))
|
||||||
|
writer = metadata_writer.MetadataWriter.create(
|
||||||
|
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
|
||||||
|
gesture_embedder_model_buffer, canned_gesture_classifier_model_buffer,
|
||||||
|
custom_gesture_classifier_metadata_writer)
|
||||||
|
model_bundle_content, metadata_json = writer.populate()
|
||||||
|
with open(_EXPECTED_JSON, "r") as f:
|
||||||
|
expected_json = f.read()
|
||||||
|
self.assertEqual(metadata_json, expected_json)
|
||||||
|
|
||||||
|
# Checks the top-level model bundle can be extracted successfully.
|
||||||
|
model_bundle_filepath = os.path.join(self.get_temp_dir(),
|
||||||
|
"gesture_recognition.task")
|
||||||
|
|
||||||
|
with open(model_bundle_filepath, "wb") as f:
|
||||||
|
f.write(model_bundle_content)
|
||||||
|
|
||||||
|
with zipfile.ZipFile(model_bundle_filepath) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set(["hand_landmarker.task", "hand_gesture_recognizer.task"]))
|
||||||
|
zf.extractall(self.get_temp_dir())
|
||||||
|
|
||||||
|
# Checks the model bundles for sub-task can be extracted successfully.
|
||||||
|
hand_landmarker_bundle_filepath = os.path.join(self.get_temp_dir(),
|
||||||
|
"hand_landmarker.task")
|
||||||
|
with zipfile.ZipFile(hand_landmarker_bundle_filepath) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
|
||||||
|
|
||||||
|
hand_gesture_recognizer_bundle_filepath = os.path.join(
|
||||||
|
self.get_temp_dir(), "hand_gesture_recognizer.task")
|
||||||
|
with zipfile.ZipFile(hand_gesture_recognizer_bundle_filepath) as zf:
|
||||||
|
self.assertEqual(
|
||||||
|
set(zf.namelist()),
|
||||||
|
set([
|
||||||
|
"canned_gesture_classifier.tflite",
|
||||||
|
"custom_gesture_classifier.tflite", "gesture_embedder.tflite"
|
||||||
|
]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurable model options for gesture recognizer models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class GestureRecognizerModelOptions:
|
||||||
|
"""Configurable options for gesture recognizer model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||||
|
layer.
|
||||||
|
"""
|
||||||
|
dropout_rate: float = 0.05
|
|
@ -0,0 +1,56 @@
|
||||||
|
{
|
||||||
|
"name": "HandGestureRecognition",
|
||||||
|
"description": "Recognize the hand gesture in the image.",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"input_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "embedding",
|
||||||
|
"description": "Embedding feature vector from gesture embedder.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stats": {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"name": "scores",
|
||||||
|
"description": "Hand gesture category scores.",
|
||||||
|
"content": {
|
||||||
|
"content_properties_type": "FeatureProperties",
|
||||||
|
"content_properties": {
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"process_units": [
|
||||||
|
{
|
||||||
|
"options_type": "ScoreThresholdingOptions",
|
||||||
|
"options": {
|
||||||
|
"global_score_threshold": 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stats": {
|
||||||
|
"max": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"min": [
|
||||||
|
0.0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "labels.txt",
|
||||||
|
"description": "Labels for categories that the model can recognize.",
|
||||||
|
"type": "TENSOR_AXIS_LABELS"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"min_parser_version": "1.0.0"
|
||||||
|
}
|
After Width: | Height: | Size: 28 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 40 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 18 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 32 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 38 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 34 KiB |
After Width: | Height: | Size: 43 KiB |
After Width: | Height: | Size: 35 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 30 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 34 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 20 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 40 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 30 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 18 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 43 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 31 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 17 KiB |
|
@ -121,6 +121,34 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "solution_base",
|
||||||
|
srcs = ["solution_base.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe/python:__subpackages__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":_framework_bindings",
|
||||||
|
":packet_creator",
|
||||||
|
":packet_getter",
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/image:image_transformation_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:landmarks_smoothing_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:logic_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
|
||||||
|
"//mediapipe/framework:calculator_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:classification_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:detection_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:rect_py_pb2",
|
||||||
|
"//mediapipe/modules/objectron/calculators:annotation_py_pb2",
|
||||||
|
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator_py_pb2",
|
||||||
|
"@com_google_protobuf//:protobuf_python",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "calculator_graph_test",
|
name = "calculator_graph_test",
|
||||||
srcs = ["calculator_graph_test.py"],
|
srcs = ["calculator_graph_test.py"],
|
||||||
|
@ -176,3 +204,24 @@ py_test(
|
||||||
":_framework_bindings",
|
":_framework_bindings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "solution_base_test",
|
||||||
|
srcs = ["solution_base_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":solution_base",
|
||||||
|
"//file/google_src",
|
||||||
|
"//file/localfile",
|
||||||
|
"//mediapipe/calculators/core:gate_calculator",
|
||||||
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
|
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
|
||||||
|
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||||
|
"//mediapipe/calculators/util:detection_unique_id_calculator",
|
||||||
|
"//mediapipe/calculators/util:to_image_calculator",
|
||||||
|
"//mediapipe/framework:calculator_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:detection_py_pb2",
|
||||||
|
"@com_google_protobuf//:protobuf_python",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -40,7 +40,6 @@ from mediapipe.calculators.util import landmarks_smoothing_calculator_pb2
|
||||||
from mediapipe.calculators.util import logic_calculator_pb2
|
from mediapipe.calculators.util import logic_calculator_pb2
|
||||||
from mediapipe.calculators.util import thresholding_calculator_pb2
|
from mediapipe.calculators.util import thresholding_calculator_pb2
|
||||||
from mediapipe.framework import calculator_pb2
|
from mediapipe.framework import calculator_pb2
|
||||||
from mediapipe.framework.formats import body_rig_pb2
|
|
||||||
from mediapipe.framework.formats import classification_pb2
|
from mediapipe.framework.formats import classification_pb2
|
||||||
from mediapipe.framework.formats import detection_pb2
|
from mediapipe.framework.formats import detection_pb2
|
||||||
from mediapipe.framework.formats import landmark_pb2
|
from mediapipe.framework.formats import landmark_pb2
|
||||||
|
|
105
mediapipe/python/solutions/BUILD
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright 2020 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "hands",
|
||||||
|
srcs = [
|
||||||
|
"hands.py",
|
||||||
|
"hands_connections.py",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||||
|
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
|
||||||
|
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_graph",
|
||||||
|
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||||
|
"//mediapipe/modules/palm_detection:palm_detection_full.tflite",
|
||||||
|
"//mediapipe/modules/palm_detection:palm_detection_lite.tflite",
|
||||||
|
],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/core:gate_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/core:split_vector_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tensor:inference_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/tflite:ssd_anchors_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:association_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:detections_to_rects_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:logic_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:non_max_suppression_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:rect_transformation_calculator_py_pb2",
|
||||||
|
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
|
||||||
|
"//mediapipe/python:solution_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "drawing_styles",
|
||||||
|
srcs = ["drawing_styles.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"drawing_utils",
|
||||||
|
"face_mesh",
|
||||||
|
"hands",
|
||||||
|
"pose",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "drawing_utils",
|
||||||
|
srcs = ["drawing_utils.py"],
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:detection_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:location_data_py_pb2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "drawing_utils_test",
|
||||||
|
srcs = ["drawing_utils_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":drawing_utils",
|
||||||
|
"//mediapipe/framework/formats:detection_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||||
|
"@com_google_protobuf//:protobuf_python",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "hands_test",
|
||||||
|
srcs = ["hands_test.py"],
|
||||||
|
data = [
|
||||||
|
":testdata/asl_hand.25fps.mp4",
|
||||||
|
":testdata/asl_hand.full.npz",
|
||||||
|
":testdata/hands.jpg",
|
||||||
|
],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":drawing_styles",
|
||||||
|
":drawing_utils",
|
||||||
|
":hands",
|
||||||
|
],
|
||||||
|
)
|
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
|
||||||
|
option java_outer_classname = "AudioClassifierGraphOptionsProto";
|
||||||
|
|
||||||
message AudioClassifierGraphOptions {
|
message AudioClassifierGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional AudioClassifierGraphOptions ext = 451755788;
|
optional AudioClassifierGraphOptions ext = 451755788;
|
||||||
|
|
|
@ -21,15 +21,6 @@ cc_library(
|
||||||
hdrs = ["rect.h"],
|
hdrs = ["rect.h"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "gesture_recognition_result",
|
|
||||||
hdrs = ["gesture_recognition_result.h"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "category",
|
name = "category",
|
||||||
srcs = ["category.cc"],
|
srcs = ["category.cc"],
|
||||||
|
|
|
@ -124,12 +124,22 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gesture_recognizer_result",
|
||||||
|
hdrs = ["gesture_recognizer_result.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gesture_recognizer",
|
name = "gesture_recognizer",
|
||||||
srcs = ["gesture_recognizer.cc"],
|
srcs = ["gesture_recognizer.cc"],
|
||||||
hdrs = ["gesture_recognizer.h"],
|
hdrs = ["gesture_recognizer.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":gesture_recognizer_graph",
|
":gesture_recognizer_graph",
|
||||||
|
":gesture_recognizer_result",
|
||||||
":hand_gesture_recognizer_graph",
|
":hand_gesture_recognizer_graph",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
|
@ -140,7 +150,6 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||||
"//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
|
|
||||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
|
|
|
@ -58,8 +58,6 @@ namespace {
|
||||||
using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision::
|
using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||||
gesture_recognizer::proto::GestureRecognizerGraphOptions;
|
gesture_recognizer::proto::GestureRecognizerGraphOptions;
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::GestureRecognitionResult;
|
|
||||||
|
|
||||||
constexpr char kHandGestureSubgraphTypeName[] =
|
constexpr char kHandGestureSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";
|
"mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";
|
||||||
|
|
||||||
|
@ -214,7 +212,7 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
|
||||||
std::move(packets_callback));
|
std::move(packets_callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
|
absl::StatusOr<GestureRecognizerResult> GestureRecognizer::Recognize(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -250,7 +248,7 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
|
absl::StatusOr<GestureRecognizerResult> GestureRecognizer::RecognizeForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
|
|
@ -24,12 +24,12 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -81,9 +81,8 @@ struct GestureRecognizerOptions {
|
||||||
// The user-defined result callback for processing live stream data.
|
// The user-defined result callback for processing live stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::LIVE_STREAM.
|
// to RunningMode::LIVE_STREAM.
|
||||||
std::function<void(
|
std::function<void(absl::StatusOr<GestureRecognizerResult>, const Image&,
|
||||||
absl::StatusOr<components::containers::GestureRecognitionResult>,
|
int64)>
|
||||||
const Image&, int64)>
|
|
||||||
result_callback = nullptr;
|
result_callback = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -104,7 +103,7 @@ struct GestureRecognizerOptions {
|
||||||
// 'width' and 'height' fields is NOT supported and will result in an
|
// 'width' and 'height' fields is NOT supported and will result in an
|
||||||
// invalid argument error being returned.
|
// invalid argument error being returned.
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// GestureRecognitionResult
|
// GestureRecognizerResult
|
||||||
// - The hand gesture recognition results.
|
// - The hand gesture recognition results.
|
||||||
class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
public:
|
public:
|
||||||
|
@ -139,7 +138,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA.
|
// The image can be of any size with format RGB or RGBA.
|
||||||
// TODO: Describes how the input image will be preprocessed
|
// TODO: Describes how the input image will be preprocessed
|
||||||
// after the yuv support is implemented.
|
// after the yuv support is implemented.
|
||||||
absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
|
absl::StatusOr<GestureRecognizerResult> Recognize(
|
||||||
Image image,
|
Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
@ -157,10 +156,10 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA. It's required to
|
// The image can be of any size with format RGB or RGBA. It's required to
|
||||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
absl::StatusOr<components::containers::GestureRecognitionResult>
|
absl::StatusOr<GestureRecognizerResult> RecognizeForVideo(
|
||||||
RecognizeForVideo(Image image, int64 timestamp_ms,
|
Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions>
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
image_processing_options = std::nullopt);
|
std::nullopt);
|
||||||
|
|
||||||
// Sends live image data to perform gesture recognition, and the results will
|
// Sends live image data to perform gesture recognition, and the results will
|
||||||
// be available via the "result_callback" provided in the
|
// be available via the "result_callback" provided in the
|
||||||
|
@ -179,7 +178,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// and will result in an invalid argument error being returned.
|
// and will result in an invalid argument error being returned.
|
||||||
//
|
//
|
||||||
// The "result_callback" provides
|
// The "result_callback" provides
|
||||||
// - A vector of GestureRecognitionResult, each is the recognized results
|
// - A vector of GestureRecognizerResult, each is the recognized results
|
||||||
// for a input frame.
|
// for a input frame.
|
||||||
// - The const reference to the corresponding input image that the gesture
|
// - The const reference to the corresponding input image that the gesture
|
||||||
// recognizer runs on. Note that the const reference to the image will no
|
// recognizer runs on. Note that the const reference to the image will no
|
||||||
|
|
|
@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
|
#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
|
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
|
||||||
|
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace vision {
|
||||||
namespace containers {
|
namespace gesture_recognizer {
|
||||||
|
|
||||||
// The gesture recognition result from GestureRecognizer, where each vector
|
// The gesture recognition result from GestureRecognizer, where each vector
|
||||||
// element represents a single hand detected in the image.
|
// element represents a single hand detected in the image.
|
||||||
struct GestureRecognitionResult {
|
struct GestureRecognizerResult {
|
||||||
// Recognized hand gestures with sorted order such that the winning label is
|
// Recognized hand gestures with sorted order such that the winning label is
|
||||||
// the first item in the list.
|
// the first item in the list.
|
||||||
std::vector<mediapipe::ClassificationList> gestures;
|
std::vector<mediapipe::ClassificationList> gestures;
|
||||||
|
@ -38,9 +38,9 @@ struct GestureRecognitionResult {
|
||||||
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
|
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace containers
|
} // namespace gesture_recognizer
|
||||||
} // namespace components
|
} // namespace vision
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
|
#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
|
|
@ -276,33 +276,44 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
.set_min_size(max_num_hands);
|
.set_min_size(max_num_hands);
|
||||||
auto has_enough_hands = min_size_node.Out("").Cast<bool>();
|
auto has_enough_hands = min_size_node.Out("").Cast<bool>();
|
||||||
|
|
||||||
auto image_for_hand_detector =
|
|
||||||
DisallowIf(image_in, has_enough_hands, graph);
|
|
||||||
auto norm_rect_in_for_hand_detector =
|
|
||||||
DisallowIf(norm_rect_in, has_enough_hands, graph);
|
|
||||||
|
|
||||||
auto& hand_detector =
|
auto& hand_detector =
|
||||||
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
|
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
|
||||||
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
|
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
|
||||||
tasks_options.hand_detector_graph_options());
|
tasks_options.hand_detector_graph_options());
|
||||||
|
auto& clip_hand_rects =
|
||||||
|
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
|
||||||
|
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
|
||||||
|
.set_max_vec_size(max_num_hands);
|
||||||
|
|
||||||
|
if (tasks_options.base_options().use_stream_mode()) {
|
||||||
|
// While in stream mode, skip hand detector graph when we successfully
|
||||||
|
// track the hands from the last frame.
|
||||||
|
auto image_for_hand_detector =
|
||||||
|
DisallowIf(image_in, has_enough_hands, graph);
|
||||||
|
auto norm_rect_in_for_hand_detector =
|
||||||
|
DisallowIf(norm_rect_in, has_enough_hands, graph);
|
||||||
image_for_hand_detector >> hand_detector.In("IMAGE");
|
image_for_hand_detector >> hand_detector.In("IMAGE");
|
||||||
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT");
|
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT");
|
||||||
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
||||||
|
|
||||||
auto& hand_association = graph.AddNode("HandAssociationCalculator");
|
auto& hand_association = graph.AddNode("HandAssociationCalculator");
|
||||||
hand_association.GetOptions<HandAssociationCalculatorOptions>()
|
hand_association.GetOptions<HandAssociationCalculatorOptions>()
|
||||||
.set_min_similarity_threshold(tasks_options.min_tracking_confidence());
|
.set_min_similarity_threshold(
|
||||||
|
tasks_options.min_tracking_confidence());
|
||||||
prev_hand_rects_from_landmarks >>
|
prev_hand_rects_from_landmarks >>
|
||||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
|
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
|
||||||
hand_rects_from_hand_detector >>
|
hand_rects_from_hand_detector >>
|
||||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
|
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
|
||||||
auto hand_rects = hand_association.Out("");
|
auto hand_rects = hand_association.Out("");
|
||||||
|
|
||||||
auto& clip_hand_rects =
|
|
||||||
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
|
|
||||||
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
|
|
||||||
.set_max_vec_size(max_num_hands);
|
|
||||||
hand_rects >> clip_hand_rects.In("");
|
hand_rects >> clip_hand_rects.In("");
|
||||||
|
} else {
|
||||||
|
// While not in stream mode, the input images are not guaranteed to be in
|
||||||
|
// series, and we don't want to enable the tracking and hand associations
|
||||||
|
// between input images. Always use the hand detector graph.
|
||||||
|
image_in >> hand_detector.In("IMAGE");
|
||||||
|
norm_rect_in >> hand_detector.In("NORM_RECT");
|
||||||
|
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
||||||
|
hand_rects_from_hand_detector >> clip_hand_rects.In("");
|
||||||
|
}
|
||||||
auto clipped_hand_rects = clip_hand_rects.Out("");
|
auto clipped_hand_rects = clip_hand_rects.Out("");
|
||||||
|
|
||||||
auto& hand_landmarks_detector_graph = graph.AddNode(
|
auto& hand_landmarks_detector_graph = graph.AddNode(
|
||||||
|
|
84
mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "core",
|
||||||
|
srcs = glob(["core/*.java"]),
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":libmediapipe_tasks_audio_jni_lib",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# The native library of all MediaPipe audio tasks.
|
||||||
|
cc_binary(
|
||||||
|
name = "libmediapipe_tasks_audio_jni.so",
|
||||||
|
linkshared = 1,
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "libmediapipe_tasks_audio_jni_lib",
|
||||||
|
srcs = [":libmediapipe_tasks_audio_jni.so"],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "audioclassifier",
|
||||||
|
srcs = [
|
||||||
|
"audioclassifier/AudioClassifier.java",
|
||||||
|
"audioclassifier/AudioClassifierResult.java",
|
||||||
|
],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
manifest = "audioclassifier/AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
":core",
|
||||||
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_audio_aar")
|
||||||
|
|
||||||
|
mediapipe_tasks_audio_aar(
|
||||||
|
name = "tasks_audio",
|
||||||
|
srcs = glob(["**/*.java"]),
|
||||||
|
native_library = ":libmediapipe_tasks_audio_jni_lib",
|
||||||
|
)
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.audio.audioclassifier">
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,399 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.os.ParcelFileDescriptor;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
|
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
|
||||||
|
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
|
||||||
|
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs audio classification on audio clips or audio stream.
|
||||||
|
*
|
||||||
|
* <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
|
||||||
|
* mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
|
||||||
|
* items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
|
||||||
|
*
|
||||||
|
* <p>Input tensor: (kTfLiteFloat32)
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>input audio buffer of size `[batch * samples]`.
|
||||||
|
* <li>batch inference is not supported (`batch` is required to be 1).
|
||||||
|
* <li>for multi-channel models, the channels need be interleaved.
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* <p>At least one output tensor with: (kTfLiteFloat32)
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>`[1 x N]` array with `N` represents the number of categories.
|
||||||
|
* <li>optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
|
||||||
|
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
|
||||||
|
* `category_name` field of the results. The `display_name` field is filled from the
|
||||||
|
* AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
|
||||||
|
* `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
|
||||||
|
* these are available, only the `index` field of the results will be filled.
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
|
private static final String TAG = AudioClassifier.class.getSimpleName();
|
||||||
|
private static final String AUDIO_IN_STREAM_NAME = "audio_in";
|
||||||
|
private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
|
||||||
|
private static final List<String> INPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList(
|
||||||
|
"AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
|
||||||
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList(
|
||||||
|
"CLASSIFICATIONS:classifications_out",
|
||||||
|
"TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
|
||||||
|
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||||
|
private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
|
||||||
|
private static final String TASK_GRAPH_NAME =
|
||||||
|
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
|
||||||
|
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||||
|
|
||||||
|
static {
|
||||||
|
ProtoUtil.registerTypeName(
|
||||||
|
ClassificationsProto.ClassificationResult.class,
|
||||||
|
"mediapipe.tasks.components.containers.proto.ClassificationResult");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from a model file and default {@link
|
||||||
|
* AudioClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelPath path to the classification model in the assets.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromFile(Context context, String modelPath) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from a model file and default {@link
|
||||||
|
* AudioClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelFile the classification model {@link File} instance.
|
||||||
|
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
|
||||||
|
try (ParcelFileDescriptor descriptor =
|
||||||
|
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||||
|
BaseOptions baseOptions =
|
||||||
|
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from a model buffer and default {@link
|
||||||
|
* AudioClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
|
||||||
|
* classification model.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param options an {@link AudioClassifierOptions} instance.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
|
||||||
|
OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
|
||||||
|
handler.setOutputPacketConverter(
|
||||||
|
new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
|
||||||
|
@Override
|
||||||
|
public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||||
|
try {
|
||||||
|
if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
|
// For audio stream mode.
|
||||||
|
return AudioClassifierResult.createFromProto(
|
||||||
|
PacketGetter.getProto(
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
|
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
|
||||||
|
/ MICROSECONDS_PER_MILLISECOND);
|
||||||
|
} else {
|
||||||
|
// For audio clips mode.
|
||||||
|
return AudioClassifierResult.createFromProtoList(
|
||||||
|
PacketGetter.getProtoVector(
|
||||||
|
packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
|
ClassificationsProto.ClassificationResult.parser()),
|
||||||
|
-1);
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Void convertToTaskInput(List<Packet> packets) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (options.resultListener().isPresent()) {
|
||||||
|
ResultListener<AudioClassifierResult, Void> resultListener =
|
||||||
|
new ResultListener<AudioClassifierResult, Void>() {
|
||||||
|
@Override
|
||||||
|
public void run(AudioClassifierResult audioClassifierResult, Void input) {
|
||||||
|
options.resultListener().get().run(audioClassifierResult);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
handler.setResultListener(resultListener);
|
||||||
|
}
|
||||||
|
options.errorListener().ifPresent(handler::setErrorListener);
|
||||||
|
// Audio tasks should not drop input audio due to flow limiting, which may cause data
|
||||||
|
// inconsistency.
|
||||||
|
TaskRunner runner =
|
||||||
|
TaskRunner.create(
|
||||||
|
context,
|
||||||
|
TaskInfo.<AudioClassifierOptions>builder()
|
||||||
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
|
.setInputStreams(INPUT_STREAMS)
|
||||||
|
.setOutputStreams(OUTPUT_STREAMS)
|
||||||
|
.setTaskOptions(options)
|
||||||
|
.setEnableFlowLimiting(false)
|
||||||
|
.build(),
|
||||||
|
handler);
|
||||||
|
return new AudioClassifier(runner, options.runningMode());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
|
||||||
|
* RunningMode}.
|
||||||
|
*
|
||||||
|
* @param taskRunner a {@link TaskRunner}.
|
||||||
|
* @param runningMode a mediapipe audio task {@link RunningMode}.
|
||||||
|
*/
|
||||||
|
private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
|
||||||
|
super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Performs audio classification on the provided audio clip. Only use this method when the
|
||||||
|
* AudioClassifier is created with the audio clips mode.
|
||||||
|
*
|
||||||
|
* <p>The audio clip is represented as a MediaPipe {@link AudioData} object The method accepts
|
||||||
|
* audio clips with various length and audio sample rate. It's required to provide the
|
||||||
|
* corresponding audio sample rate within the {@link AudioData} object.
|
||||||
|
*
|
||||||
|
* <p>The input audio clip may be longer than what the model is able to process in a single
|
||||||
|
* inference. When this occurs, the input audio clip is split into multiple chunks starting at
|
||||||
|
* different timestamps. For this reason, this function returns a vector of ClassificationResult
|
||||||
|
* objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
|
||||||
|
* chunk data that was classified, e.g:
|
||||||
|
*
|
||||||
|
* ClassificationResult #0 (first chunk of data):
|
||||||
|
* timestamp_ms: 0 (starts at 0ms)
|
||||||
|
* classifications #0 (single head model):
|
||||||
|
* category #0:
|
||||||
|
* category_name: "Speech"
|
||||||
|
* score: 0.6
|
||||||
|
* category #1:
|
||||||
|
* category_name: "Music"
|
||||||
|
* score: 0.2
|
||||||
|
* ClassificationResult #1 (second chunk of data):
|
||||||
|
* timestamp_ms: 800 (starts at 800ms)
|
||||||
|
* classifications #0 (single head model):
|
||||||
|
* category #0:
|
||||||
|
* category_name: "Speech"
|
||||||
|
* score: 0.5
|
||||||
|
* category #1:
|
||||||
|
* category_name: "Silence"
|
||||||
|
* score: 0.1
|
||||||
|
*
|
||||||
|
* @param audioClip a MediaPipe {@link AudioData} object for processing.
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public AudioClassifierResult classify(AudioData audioClip) {
|
||||||
|
return (AudioClassifierResult) processAudioClip(audioClip);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
|
||||||
|
* use this method when the AudioClassifier is created with the audio stream mode.
|
||||||
|
*
|
||||||
|
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
|
||||||
|
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
|
||||||
|
* It's required to provide the corresponding audio sample rate within {@link AudioData} object as
|
||||||
|
* well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
|
||||||
|
* timestamps must be monotonically increasing. This method will return immediately after
|
||||||
|
* the input audio data is accepted. The results will be available in the `resultListener`
|
||||||
|
* provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
|
||||||
|
* auido stream data such as microphone input.
|
||||||
|
*
|
||||||
|
* <p>The input audio block may be longer than what the model is able to process in a single
|
||||||
|
* inference. When this occurs, the input audio block is split into multiple chunks. For this
|
||||||
|
* reason, the callback may be called multiple times (once per chunk) for each call to this
|
||||||
|
* function.
|
||||||
|
*
|
||||||
|
* @param audioBlock a MediaPipe {@link AudioData} object for processing.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public void classifyAsync(AudioData audioBlock, long timestampMs) {
|
||||||
|
checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
|
||||||
|
sendAudioStreamData(audioBlock, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Options for setting up and {@link AudioClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract static class AudioClassifierOptions extends TaskOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link AudioClassifierOptions}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/** Sets the {@link BaseOptions} for the audio classifier task. */
|
||||||
|
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
|
||||||
|
* mode. Image classifier has two modes:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
|
||||||
|
* audio clips to the `classify` method, and will receive the classification results as
|
||||||
|
* the return value.
|
||||||
|
* <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
|
||||||
|
* from microphone. Users call `classifyAsync` to push the audio data into the
|
||||||
|
* AudioClassifier, the classification results will be available in the result callback
|
||||||
|
* when the audio classifier finishes the work.
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||||
|
* score threshold, number of results, etc.
|
||||||
|
*/
|
||||||
|
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||||
|
* the audio classifier is in the audio stream mode.
|
||||||
|
*/
|
||||||
|
public abstract Builder setResultListener(
|
||||||
|
PureResultListener<AudioClassifierResult> resultListener);
|
||||||
|
|
||||||
|
/** Sets an optional {@link ErrorListener}. */
|
||||||
|
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||||
|
|
||||||
|
abstract AudioClassifierOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link AudioClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||||
|
* properly configured. The result listener should only be set when the audio classifier
|
||||||
|
* is in the audio stream mode.
|
||||||
|
*/
|
||||||
|
public final AudioClassifierOptions build() {
|
||||||
|
AudioClassifierOptions options = autoBuild();
|
||||||
|
if (options.runningMode() == RunningMode.AUDIO_STREAM) {
|
||||||
|
if (!options.resultListener().isPresent()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"The audio classifier is in the audio stream mode, a user-defined result listener"
|
||||||
|
+ " must be provided in the AudioClassifierOptions.");
|
||||||
|
}
|
||||||
|
} else if (options.resultListener().isPresent()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"The audio classifier is in the audio clips mode, a user-defined result listener"
|
||||||
|
+ " shouldn't be provided in AudioClassifierOptions.");
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
|
abstract RunningMode runningMode();
|
||||||
|
|
||||||
|
abstract Optional<ClassifierOptions> classifierOptions();
|
||||||
|
|
||||||
|
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
|
||||||
|
|
||||||
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
|
||||||
|
.setRunningMode(RunningMode.AUDIO_CLIPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||||
|
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||||
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
|
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
||||||
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||||
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
|
||||||
|
.setBaseOptions(baseOptionsBuilder);
|
||||||
|
if (classifierOptions().isPresent()) {
|
||||||
|
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||||
|
}
|
||||||
|
return CalculatorOptions.newBuilder()
|
||||||
|
.setExtension(
|
||||||
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
|
||||||
|
taskOptionsBuilder.build())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/** Represents the classification results generated by {@link AudioClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class AudioClassifierResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifierResult} instance from a list of {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf messages.
|
||||||
|
*
|
||||||
|
* @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf message
|
||||||
|
* to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static AudioClassifierResult createFromProtoList(
|
||||||
|
List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
|
||||||
|
List<ClassificationResult> classificationResultList = new ArrayList<>();
|
||||||
|
for (ClassificationsProto.ClassificationResult proto : protoList) {
|
||||||
|
classificationResultList.add(ClassificationResult.createFromProto(proto));
|
||||||
|
}
|
||||||
|
return new AutoValue_AudioClassifierResult(
|
||||||
|
Optional.of(classificationResultList), Optional.empty(), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifierResult} instance from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static AudioClassifierResult createFromProto(
|
||||||
|
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||||
|
return new AutoValue_AudioClassifierResult(
|
||||||
|
Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list of of timpstamed {@link ClassificationResult} objects, each contains one set of results
|
||||||
|
* per classifier head. The list represents the audio classification result of an audio clip, and
|
||||||
|
* is only available when running with the audio clips mode.
|
||||||
|
*/
|
||||||
|
public abstract Optional<List<ClassificationResult>> classificationResultList();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains one set of results per classifier head. A {@link ClassificationResult} usually
|
||||||
|
* represents one audio classification result in an audio stream, and s only available when
|
||||||
|
* running with the audio stream mode.
|
||||||
|
*/
|
||||||
|
public abstract Optional<ClassificationResult> classificationResult();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -0,0 +1,151 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.core;
|
||||||
|
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/** The base class of MediaPipe audio tasks. */
|
||||||
|
public class BaseAudioTaskApi implements AutoCloseable {
|
||||||
|
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||||
|
private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;
|
||||||
|
|
||||||
|
private final TaskRunner runner;
|
||||||
|
private final RunningMode runningMode;
|
||||||
|
private final String audioStreamName;
|
||||||
|
private final String sampleRateStreamName;
|
||||||
|
private double defaultSampleRate;
|
||||||
|
|
||||||
|
static {
|
||||||
|
System.loadLibrary("mediapipe_tasks_audio_jni");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize a {@link BaseAudioTaskApi}.
|
||||||
|
*
|
||||||
|
* @param runner a {@link TaskRunner}.
|
||||||
|
* @param runningMode a mediapipe audio task {@link RunningMode}.
|
||||||
|
* @param audioStreamName the name of the input audio stream.
|
||||||
|
* @param sampleRateStreamName the name of the audio sample rate stream.
|
||||||
|
*/
|
||||||
|
public BaseAudioTaskApi(
|
||||||
|
TaskRunner runner,
|
||||||
|
RunningMode runningMode,
|
||||||
|
String audioStreamName,
|
||||||
|
String sampleRateStreamName) {
|
||||||
|
this.runner = runner;
|
||||||
|
this.runningMode = runningMode;
|
||||||
|
this.audioStreamName = audioStreamName;
|
||||||
|
this.sampleRateStreamName = sampleRateStreamName;
|
||||||
|
this.defaultSampleRate = -1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A synchronous method to process audio clips. The call blocks the current thread until a failure
|
||||||
|
* status or a successful result is returned.
|
||||||
|
*
|
||||||
|
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||||
|
* @throws MediaPipeException if the task is not in the audio clips mode.
|
||||||
|
*/
|
||||||
|
protected TaskResult processAudioClip(AudioData audioClip) {
|
||||||
|
if (runningMode != RunningMode.AUDIO_CLIPS) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Task is not initialized with the audio clips mode. Current running mode:"
|
||||||
|
+ runningMode.name());
|
||||||
|
}
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(
|
||||||
|
audioStreamName,
|
||||||
|
runner
|
||||||
|
.getPacketCreator()
|
||||||
|
.createMatrix(
|
||||||
|
audioClip.getFormat().getNumOfChannels(),
|
||||||
|
audioClip.getBufferLength(),
|
||||||
|
audioClip.getBuffer()));
|
||||||
|
inputPackets.put(
|
||||||
|
sampleRateStreamName,
|
||||||
|
runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
|
||||||
|
return runner.process(inputPackets);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks or sets the audio sample rate in the audio stream mode.
|
||||||
|
*
|
||||||
|
* @param sampleRate the audio sample rate.
|
||||||
|
* @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
|
||||||
|
* rate is inconsisent with the previously recevied.
|
||||||
|
*/
|
||||||
|
protected void checkOrSetSampleRate(double sampleRate) {
|
||||||
|
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||||
|
+ runningMode.name());
|
||||||
|
}
|
||||||
|
if (defaultSampleRate > 0) {
|
||||||
|
if (Double.compare(sampleRate, defaultSampleRate) != 0) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
|
||||||
|
"The input audio sample rate: "
|
||||||
|
+ sampleRate
|
||||||
|
+ " is inconsistent with the previously provided: "
|
||||||
|
+ defaultSampleRate);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
|
||||||
|
runner.send(inputPackets, PRESTREAM_TIMESTAMP);
|
||||||
|
defaultSampleRate = sampleRate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
|
||||||
|
* available in the user-defined result listener.
|
||||||
|
*
|
||||||
|
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||||
|
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||||
|
* @throws MediaPipeException if the task is not in the stream mode.
|
||||||
|
*/
|
||||||
|
protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
|
||||||
|
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||||
|
+ runningMode.name());
|
||||||
|
}
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(
|
||||||
|
audioStreamName,
|
||||||
|
runner
|
||||||
|
.getPacketCreator()
|
||||||
|
.createMatrix(
|
||||||
|
audioClip.getFormat().getNumOfChannels(),
|
||||||
|
audioClip.getBufferLength(),
|
||||||
|
audioClip.getBuffer()));
|
||||||
|
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Closes and cleans up the MediaPipe audio task. */
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
runner.close();
|
||||||
|
}
|
||||||
|
}
|