Merge branch 'google:master' into image-embedder-python
|
@ -155,6 +155,7 @@ http_archive(
|
|||
name = "com_google_audio_tools",
|
||||
strip_prefix = "multichannel-audio-tools-master",
|
||||
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
|
||||
repo_mapping = {"@com_github_glog_glog" : "@com_github_glog_glog_no_gflags"},
|
||||
)
|
||||
|
||||
http_archive(
|
||||
|
|
18
docs/BUILD
|
@ -12,3 +12,21 @@ py_binary(
|
|||
"//third_party/py/tensorflow_docs/api_generator:public_api",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_java_api_docs",
|
||||
srcs = ["build_java_api_docs.py"],
|
||||
data = [
|
||||
"//third_party/java/doclava/current:doclava.jar",
|
||||
"//third_party/java/jsilver:jsilver_jar",
|
||||
],
|
||||
env = {
|
||||
"DOCLAVA_JAR": "$(location //third_party/java/doclava/current:doclava.jar)",
|
||||
"JSILVER_JAR": "$(location //third_party/java/jsilver:jsilver_jar)",
|
||||
},
|
||||
deps = [
|
||||
"//third_party/py/absl:app",
|
||||
"//third_party/py/absl/flags",
|
||||
"//third_party/py/tensorflow_docs/api_generator/gen_java",
|
||||
],
|
||||
)
|
||||
|
|
58
docs/build_java_api_docs.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Generate Java reference docs for MediaPipe."""
|
||||
import pathlib
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from tensorflow_docs.api_generator import gen_java
|
||||
|
||||
_JAVA_ROOT = flags.DEFINE_string('java_src', None,
|
||||
'Override the Java source path.',
|
||||
required=False)
|
||||
|
||||
_OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/',
|
||||
'Write docs here.')
|
||||
|
||||
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/java',
|
||||
'Path prefix in the _toc.yaml')
|
||||
|
||||
_ = flags.DEFINE_string('code_url_prefix', None,
|
||||
'[UNUSED] The url prefix for links to code.')
|
||||
|
||||
_ = flags.DEFINE_bool(
|
||||
'search_hints', True,
|
||||
'[UNUSED] Include metadata search hints in the generated files')
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
if not (java_root := _JAVA_ROOT.value):
|
||||
# Default to using a relative path to find the Java source.
|
||||
mp_root = pathlib.Path(__file__)
|
||||
while (mp_root := mp_root.parent).name != 'mediapipe':
|
||||
# Find the nearest `mediapipe` dir.
|
||||
pass
|
||||
java_root = mp_root / 'tasks/java'
|
||||
|
||||
gen_java.gen_java_docs(
|
||||
package='com.google.mediapipe',
|
||||
source_path=pathlib.Path(java_root),
|
||||
output_dir=pathlib.Path(_OUT_DIR.value),
|
||||
site_path=pathlib.Path(_SITE_PATH.value))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
"""Copyright 2019 - 2020 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
# Copyright 2019 - 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
|
@ -277,6 +277,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp/mfcc",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
)
|
||||
|
@ -352,6 +353,8 @@ cc_test(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"//mediapipe/util:time_series_test_util",
|
||||
"@com_google_audio_tools//audio/dsp:resampler",
|
||||
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||
"@com_google_audio_tools//audio/dsp:signal_vector_util",
|
||||
"@eigen_archive//:eigen3",
|
||||
],
|
||||
|
|
|
@ -130,7 +130,7 @@ class BypassCalculator : public Node {
|
|||
pass_out.insert(entry.second);
|
||||
auto& packet = cc->Inputs().Get(entry.first).Value();
|
||||
if (packet.Timestamp() == cc->InputTimestamp()) {
|
||||
cc->Outputs().Get(entry.first).AddPacket(packet);
|
||||
cc->Outputs().Get(entry.second).AddPacket(packet);
|
||||
}
|
||||
}
|
||||
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
|
||||
|
|
|
@ -42,10 +42,10 @@ constexpr char kTestGraphConfig1[] = R"pb(
|
|||
node {
|
||||
calculator: "BypassCalculator"
|
||||
input_stream: "PASS:appearances"
|
||||
input_stream: "TRUNCATE:0:video_frame"
|
||||
input_stream: "TRUNCATE:1:feature_config"
|
||||
input_stream: "IGNORE:0:video_frame"
|
||||
input_stream: "IGNORE:1:feature_config"
|
||||
output_stream: "PASS:passthrough_appearances"
|
||||
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
|
||||
output_stream: "IGNORE:passthrough_federated_gaze_output"
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
|
||||
pass_input_stream: "PASS"
|
||||
|
|
|
@ -156,6 +156,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -750,19 +750,12 @@ objc_library(
|
|||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
mediapipe_proto_library(
|
||||
name = "scale_mode_proto",
|
||||
srcs = ["scale_mode.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "scale_mode_cc_proto",
|
||||
srcs = ["scale_mode.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":scale_mode_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gl_quad_renderer",
|
||||
srcs = ["gl_quad_renderer.cc"],
|
||||
|
|
51
mediapipe/model_maker/models/gesture_recognizer/BUILD
Normal file
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//mediapipe/framework/tool:mediapipe_files.bzl",
|
||||
"mediapipe_files",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"],
|
||||
)
|
||||
|
||||
mediapipe_files(
|
||||
srcs = [
|
||||
"canned_gesture_classifier.tflite",
|
||||
"gesture_embedder.tflite",
|
||||
"gesture_embedder/keras_metadata.pb",
|
||||
"gesture_embedder/saved_model.pb",
|
||||
"gesture_embedder/variables/variables.data-00000-of-00001",
|
||||
"gesture_embedder/variables/variables.index",
|
||||
"hand_landmark_full.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "models",
|
||||
srcs = [
|
||||
"canned_gesture_classifier.tflite",
|
||||
"gesture_embedder.tflite",
|
||||
"gesture_embedder/keras_metadata.pb",
|
||||
"gesture_embedder/saved_model.pb",
|
||||
"gesture_embedder/variables/variables.data-00000-of-00001",
|
||||
"gesture_embedder/variables/variables.index",
|
||||
"hand_landmark_full.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
],
|
||||
)
|
35
mediapipe/model_maker/python/text/core/BUILD
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "bert_model_options",
|
||||
srcs = ["bert_model_options.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "bert_model_spec",
|
||||
srcs = ["bert_model_spec.py"],
|
||||
deps = [
|
||||
":bert_model_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
13
mediapipe/model_maker/python/text/core/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
33
mediapipe/model_maker/python/text/core/bert_model_options.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for a BERT model."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BertModelOptions:
|
||||
"""Configurable model options for a BERT model.
|
||||
|
||||
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
|
||||
Transformers for Language Understanding) for more details.
|
||||
|
||||
Attributes:
|
||||
seq_len: Length of the sequence to feed into the model.
|
||||
do_fine_tuning: If true, then the BERT model is not frozen for training.
|
||||
dropout_rate: The rate for dropout.
|
||||
"""
|
||||
seq_len: int = 128
|
||||
do_fine_tuning: bool = True
|
||||
dropout_rate: float = 0.1
|
58
mediapipe/model_maker/python/text/core/bert_model_spec.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Specification for a BERT model."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Dict
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||
|
||||
_DEFAULT_TFLITE_INPUT_NAME = {
|
||||
'ids': 'serving_default_input_word_ids:0',
|
||||
'mask': 'serving_default_input_mask:0',
|
||||
'segment_ids': 'serving_default_input_type_ids:0'
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BertModelSpec:
|
||||
"""Specification for a BERT model.
|
||||
|
||||
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
|
||||
Transformers for Language Understanding) for more details.
|
||||
|
||||
Attributes:
|
||||
hparams: Hyperparameters used for training.
|
||||
model_options: Configurable options for a BERT model.
|
||||
do_lower_case: boolean, whether to lower case the input text. Should be
|
||||
True / False for uncased / cased models respectively, where the models
|
||||
are specified by the `uri`.
|
||||
tflite_input_name: Dict, input names for the TFLite model.
|
||||
uri: URI for the BERT module.
|
||||
name: The name of the object.
|
||||
"""
|
||||
|
||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||
epochs=3,
|
||||
batch_size=32,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='mirrored')
|
||||
model_options: bert_model_options.BertModelOptions = (
|
||||
bert_model_options.BertModelOptions())
|
||||
do_lower_case: bool = True
|
||||
tflite_input_name: Dict[str, str] = dataclasses.field(
|
||||
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
|
||||
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
|
||||
name: str = 'Bert'
|
146
mediapipe/model_maker/python/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "model_options",
|
||||
srcs = ["model_options.py"],
|
||||
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_spec",
|
||||
srcs = ["model_spec.py"],
|
||||
deps = [
|
||||
":model_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_spec_test",
|
||||
srcs = ["model_spec_test.py"],
|
||||
deps = [
|
||||
":model_options",
|
||||
":model_spec",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "preprocessor",
|
||||
srcs = ["preprocessor.py"],
|
||||
deps = [":dataset"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "preprocessor_test",
|
||||
srcs = ["preprocessor_test.py"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "text_classifier_options",
|
||||
srcs = ["text_classifier_options.py"],
|
||||
deps = [
|
||||
":model_options",
|
||||
":model_spec",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "text_classifier",
|
||||
srcs = ["text_classifier.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":preprocessor",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/data:dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "text_classifier_test",
|
||||
size = "large",
|
||||
srcs = ["text_classifier_test.py"],
|
||||
data = [
|
||||
"//mediapipe/model_maker/python/text/text_classifier/testdata",
|
||||
],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "text_classifier_demo_lib",
|
||||
srcs = ["text_classifier_demo.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_spec",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "text_classifier_demo",
|
||||
srcs = ["text_classifier_demo.py"],
|
||||
deps = [
|
||||
":text_classifier_demo_lib",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
88
mediapipe/model_maker/python/text/text_classifier/dataset.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Text classifier dataset library."""
|
||||
|
||||
import csv
|
||||
import dataclasses
|
||||
import random
|
||||
|
||||
from typing import Optional, Sequence
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CSVParameters:
|
||||
"""Parameters used when reading a CSV file.
|
||||
|
||||
Attributes:
|
||||
text_column: Column name for the input text.
|
||||
label_column: Column name for the labels.
|
||||
fieldnames: Sequence of keys for the CSV columns. If None, the first row of
|
||||
the CSV file is used as the keys.
|
||||
delimiter: Character that separates fields.
|
||||
quotechar: Character used to quote fields that contain special characters
|
||||
like the `delimiter`.
|
||||
"""
|
||||
text_column: str
|
||||
label_column: str
|
||||
fieldnames: Optional[Sequence[str]] = None
|
||||
delimiter: str = ","
|
||||
quotechar: str = '"'
|
||||
|
||||
|
||||
class Dataset(classification_dataset.ClassificationDataset):
|
||||
"""Dataset library for text classifier."""
|
||||
|
||||
@classmethod
|
||||
def from_csv(cls,
|
||||
filename: str,
|
||||
csv_params: CSVParameters,
|
||||
shuffle: bool = True) -> "Dataset":
|
||||
"""Loads text with labels from a CSV file.
|
||||
|
||||
Args:
|
||||
filename: Name of the CSV file.
|
||||
csv_params: Parameters used for reading the CSV file.
|
||||
shuffle: If True, randomly shuffle the data.
|
||||
|
||||
Returns:
|
||||
Dataset containing (text, label) pairs and other related info.
|
||||
"""
|
||||
with tf.io.gfile.GFile(filename, "r") as f:
|
||||
reader = csv.DictReader(
|
||||
f,
|
||||
fieldnames=csv_params.fieldnames,
|
||||
delimiter=csv_params.delimiter,
|
||||
quotechar=csv_params.quotechar)
|
||||
|
||||
lines = list(reader)
|
||||
if shuffle:
|
||||
random.shuffle(lines)
|
||||
|
||||
label_names = sorted(set([line[csv_params.label_column] for line in lines]))
|
||||
index_by_label = {label: index for index, label in enumerate(label_names)}
|
||||
|
||||
texts = [line[csv_params.text_column] for line in lines]
|
||||
text_ds = tf.data.Dataset.from_tensor_slices(tf.cast(texts, tf.string))
|
||||
label_indices = [
|
||||
index_by_label[line[csv_params.label_column]] for line in lines
|
||||
]
|
||||
label_index_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(label_indices, tf.int64))
|
||||
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
|
||||
|
||||
return Dataset(
|
||||
dataset=text_label_ds, size=len(texts), label_names=label_names)
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import csv
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase):
|
||||
|
||||
def _get_csv_file(self):
|
||||
labels_and_text = (('neutral', 'indifferent'), ('pos', 'extremely great'),
|
||||
('neg', 'totally awful'), ('pos', 'super good'),
|
||||
('neg', 'really bad'))
|
||||
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||
if os.path.exists(csv_file):
|
||||
return csv_file
|
||||
fieldnames = ['text', 'label']
|
||||
with open(csv_file, 'w') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for label, text in labels_and_text:
|
||||
writer.writerow({'text': text, 'label': label})
|
||||
return csv_file
|
||||
|
||||
def test_from_csv(self):
|
||||
csv_file = self._get_csv_file()
|
||||
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||
data = dataset.Dataset.from_csv(filename=csv_file, csv_params=csv_params)
|
||||
self.assertLen(data, 5)
|
||||
self.assertEqual(data.num_classes, 3)
|
||||
self.assertEqual(data.label_names, ['neg', 'neutral', 'pos'])
|
||||
data_values = set([(text.numpy()[0], label.numpy()[0])
|
||||
for text, label in data.gen_tf_dataset()])
|
||||
expected_data_values = set([(b'indifferent', 1), (b'extremely great', 2),
|
||||
(b'totally awful', 0), (b'super good', 2),
|
||||
(b'really bad', 0)])
|
||||
self.assertEqual(data_values, expected_data_values)
|
||||
|
||||
def test_split(self):
|
||||
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
|
||||
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
|
||||
train_data, test_data = data.split(0.5)
|
||||
expected_train_data = [b'good', b'bad']
|
||||
expected_test_data = [b'neutral', b'odd']
|
||||
|
||||
self.assertLen(train_data, 2)
|
||||
train_data_values = [elem.numpy() for elem in train_data._dataset]
|
||||
self.assertEqual(train_data_values, expected_train_data)
|
||||
self.assertEqual(train_data.num_classes, 2)
|
||||
self.assertEqual(train_data.label_names, ['pos', 'neg'])
|
||||
|
||||
self.assertLen(test_data, 2)
|
||||
test_data_values = [elem.numpy() for elem in test_data._dataset]
|
||||
self.assertEqual(test_data_values, expected_test_data)
|
||||
self.assertEqual(test_data.num_classes, 2)
|
||||
self.assertEqual(test_data.label_names, ['pos', 'neg'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for text classifier models."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Union
|
||||
|
||||
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||
|
||||
# BERT text classifier options inherited from BertModelOptions.
|
||||
BertClassifierOptions = bert_model_options.BertModelOptions
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AverageWordEmbeddingClassifierOptions:
|
||||
"""Configurable model options for an Average Word Embedding classifier.
|
||||
|
||||
Attributes:
|
||||
seq_len: Length of the sequence to feed into the model.
|
||||
wordvec_dim: Dimension of the word embedding.
|
||||
do_lower_case: Whether to convert all uppercase characters to lowercase
|
||||
during preprocessing.
|
||||
vocab_size: Number of words to generate the vocabulary from data.
|
||||
dropout_rate: The rate for dropout.
|
||||
"""
|
||||
seq_len: int = 256
|
||||
wordvec_dim: int = 16
|
||||
do_lower_case: bool = True
|
||||
vocab_size: int = 10000
|
||||
dropout_rate: float = 0.2
|
||||
|
||||
|
||||
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
|
||||
BertClassifierOptions]
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Specifications for text classifier models."""
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
|
||||
# BERT-based text classifier spec inherited from BertModelSpec
|
||||
BertClassifierSpec = bert_model_spec.BertModelSpec
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AverageWordEmbeddingClassifierSpec:
|
||||
"""Specification for an average word embedding classifier model.
|
||||
|
||||
Attributes:
|
||||
hparams: Configurable hyperparameters for training.
|
||||
model_options: Configurable options for the average word embedding model.
|
||||
name: The name of the object.
|
||||
"""
|
||||
|
||||
# `learning_rate` is unused for the average word embedding model
|
||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0)
|
||||
model_options: mo.AverageWordEmbeddingClassifierOptions = (
|
||||
mo.AverageWordEmbeddingClassifierOptions())
|
||||
name: str = 'AverageWordEmbedding'
|
||||
|
||||
|
||||
average_word_embedding_classifier_spec = functools.partial(
|
||||
AverageWordEmbeddingClassifierSpec)
|
||||
|
||||
mobilebert_classifier_spec = functools.partial(
|
||||
BertClassifierSpec,
|
||||
hparams=hp.BaseHParams(
|
||||
epochs=3,
|
||||
batch_size=48,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='off'),
|
||||
name='MobileBert',
|
||||
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
|
||||
tflite_input_name={
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
'segment_ids': 'serving_default_input_2:0'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@enum.unique
|
||||
class SupportedModels(enum.Enum):
|
||||
"""Predefined text classifier model specs supported by Model Maker."""
|
||||
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for model_spec."""
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
|
||||
|
||||
class ModelSpecTest(tf.test.TestCase):
|
||||
|
||||
def test_predefined_bert_spec(self):
|
||||
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
|
||||
self.assertEqual(model_spec_obj.name, 'MobileBert')
|
||||
self.assertEqual(
|
||||
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
|
||||
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
|
||||
self.assertTrue(model_spec_obj.do_lower_case)
|
||||
self.assertEqual(
|
||||
model_spec_obj.tflite_input_name, {
|
||||
'ids': 'serving_default_input_1:0',
|
||||
'mask': 'serving_default_input_3:0',
|
||||
'segment_ids': 'serving_default_input_2:0'
|
||||
})
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.BertClassifierOptions(
|
||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||
self.assertEqual(
|
||||
model_spec_obj.hparams,
|
||||
hp.BaseHParams(
|
||||
epochs=3,
|
||||
batch_size=48,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='off'))
|
||||
|
||||
def test_predefined_average_word_embedding_spec(self):
|
||||
model_spec_obj = (
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value())
|
||||
self.assertIsInstance(model_spec_obj, ms.AverageWordEmbeddingClassifierSpec)
|
||||
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||
seq_len=256,
|
||||
wordvec_dim=16,
|
||||
do_lower_case=True,
|
||||
vocab_size=10000,
|
||||
dropout_rate=0.2))
|
||||
self.assertEqual(
|
||||
model_spec_obj.hparams,
|
||||
hp.BaseHParams(
|
||||
epochs=10,
|
||||
batch_size=32,
|
||||
learning_rate=0,
|
||||
steps_per_epoch=None,
|
||||
shuffle=False,
|
||||
distribution_strategy='off',
|
||||
num_gpus=-1,
|
||||
tpu=''))
|
||||
|
||||
def test_custom_bert_spec(self):
|
||||
custom_bert_classifier_options = (
|
||||
classifier_model_options.BertClassifierOptions(
|
||||
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||
model_spec_obj = (
|
||||
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||
model_options=custom_bert_classifier_options))
|
||||
self.assertEqual(model_spec_obj.model_options,
|
||||
custom_bert_classifier_options)
|
||||
|
||||
def test_custom_average_word_embedding_spec(self):
|
||||
custom_hparams = hp.BaseHParams(
|
||||
learning_rate=0.4,
|
||||
batch_size=64,
|
||||
epochs=10,
|
||||
steps_per_epoch=10,
|
||||
shuffle=True,
|
||||
export_dir='foo/bar',
|
||||
distribution_strategy='mirrored',
|
||||
num_gpus=3,
|
||||
tpu='tpu/address')
|
||||
custom_average_word_embedding_model_options = (
|
||||
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||
seq_len=512,
|
||||
wordvec_dim=32,
|
||||
do_lower_case=False,
|
||||
vocab_size=5000,
|
||||
dropout_rate=0.5))
|
||||
model_spec_obj = (
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value(
|
||||
model_options=custom_average_word_embedding_model_options,
|
||||
hparams=custom_hparams))
|
||||
self.assertEqual(model_spec_obj.model_options,
|
||||
custom_average_word_embedding_model_options)
|
||||
self.assertEqual(model_spec_obj.hparams, custom_hparams)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load compressed models from tensorflow_hub
|
||||
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||
tf.test.main()
|
|
@ -0,0 +1,285 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Preprocessors for text classification."""
|
||||
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Mapping, Sequence, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub
|
||||
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||
from official.nlp.data import classifier_data_lib
|
||||
from official.nlp.tools import tokenization
|
||||
|
||||
|
||||
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
|
||||
"""Validates the shape and type of `text` and `label`.
|
||||
|
||||
Args:
|
||||
text: Stores text data. Should have shape [1] and dtype tf.string.
|
||||
label: Stores the label for the corresponding `text`. Should have shape [1]
|
||||
and dtype tf.int64.
|
||||
|
||||
Raises:
|
||||
ValueError: If either tensor has the wrong shape or type.
|
||||
"""
|
||||
if text.shape != [1]:
|
||||
raise ValueError(f"`text` should have shape [1], got {text.shape}")
|
||||
if text.dtype != tf.string:
|
||||
raise ValueError(f"Expected dtype string for `text`, got {text.dtype}")
|
||||
if label.shape != [1]:
|
||||
raise ValueError(f"`label` should have shape [1], got {text.shape}")
|
||||
if label.dtype != tf.int64:
|
||||
raise ValueError(f"Expected dtype int64 for `label`, got {label.dtype}")
|
||||
|
||||
|
||||
def _decode_record(
|
||||
record: tf.Tensor, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||
) -> Tuple[Mapping[str, tf.Tensor], tf.Tensor]:
|
||||
"""Decodes a record into input for a BERT model.
|
||||
|
||||
Args:
|
||||
record: Stores serialized example.
|
||||
name_to_features: Maps record keys to feature types.
|
||||
|
||||
Returns:
|
||||
BERT model input features and label for the record.
|
||||
"""
|
||||
example = tf.io.parse_single_example(record, name_to_features)
|
||||
|
||||
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||
for name in list(example.keys()):
|
||||
example[name] = tf.cast(example[name], tf.int32)
|
||||
|
||||
bert_features = {
|
||||
"input_word_ids": example["input_ids"],
|
||||
"input_mask": example["input_mask"],
|
||||
"input_type_ids": example["segment_ids"]
|
||||
}
|
||||
return bert_features, example["label_ids"]
|
||||
|
||||
|
||||
def _single_file_dataset(
|
||||
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
|
||||
) -> tf.data.TFRecordDataset:
|
||||
"""Creates a single-file dataset to be passed for BERT custom training.
|
||||
|
||||
Args:
|
||||
input_file: Filepath for the dataset.
|
||||
name_to_features: Maps record keys to feature types.
|
||||
|
||||
Returns:
|
||||
Dataset containing BERT model input features and labels.
|
||||
"""
|
||||
d = tf.data.TFRecordDataset(input_file)
|
||||
d = d.map(
|
||||
lambda record: _decode_record(record, name_to_features),
|
||||
num_parallel_calls=tf.data.AUTOTUNE)
|
||||
return d
|
||||
|
||||
|
||||
class AverageWordEmbeddingClassifierPreprocessor:
|
||||
"""Preprocessor for an Average Word Embedding model.
|
||||
|
||||
Takes (text, label) data and applies regex tokenization and padding to the
|
||||
text to generate (token IDs, label) data.
|
||||
|
||||
Attributes:
|
||||
seq_len: Length of the input sequence to the model.
|
||||
do_lower_case: Whether text inputs should be converted to lower-case.
|
||||
vocab: Vocabulary of tokens used by the model.
|
||||
"""
|
||||
|
||||
PAD: str = "<PAD>" # Index: 0
|
||||
START: str = "<START>" # Index: 1
|
||||
UNKNOWN: str = "<UNKNOWN>" # Index: 2
|
||||
|
||||
def __init__(self, seq_len: int, do_lower_case: bool, texts: Sequence[str],
|
||||
vocab_size: int):
|
||||
self._seq_len = seq_len
|
||||
self._do_lower_case = do_lower_case
|
||||
self._vocab = self._gen_vocab(texts, vocab_size)
|
||||
|
||||
def _gen_vocab(self, texts: Sequence[str],
|
||||
vocab_size: int) -> Mapping[str, int]:
|
||||
"""Generates vocabulary list in `texts` with size `vocab_size`.
|
||||
|
||||
Args:
|
||||
texts: All texts (across training and validation data) that will be
|
||||
preprocessed by the model.
|
||||
vocab_size: Size of the vocab.
|
||||
|
||||
Returns:
|
||||
The vocab mapping tokens to IDs.
|
||||
"""
|
||||
vocab_counter = collections.Counter()
|
||||
|
||||
for text in texts:
|
||||
tokens = self._regex_tokenize(text)
|
||||
for token in tokens:
|
||||
vocab_counter[token] += 1
|
||||
|
||||
vocab_freq = vocab_counter.most_common(vocab_size)
|
||||
vocab_list = [self.PAD, self.START, self.UNKNOWN
|
||||
] + [word for word, _ in vocab_freq]
|
||||
return collections.OrderedDict(((v, i) for i, v in enumerate(vocab_list)))
|
||||
|
||||
def get_vocab(self) -> Mapping[str, int]:
|
||||
"""Returns the vocab of the AverageWordEmbeddingClassifierPreprocessor."""
|
||||
return self._vocab
|
||||
|
||||
# TODO: Align with MediaPipe's RegexTokenizer.
|
||||
def _regex_tokenize(self, text: str) -> Sequence[str]:
|
||||
"""Splits `text` by words but does not split on single quotes.
|
||||
|
||||
Args:
|
||||
text: Text to be tokenized.
|
||||
|
||||
Returns:
|
||||
List of tokens.
|
||||
"""
|
||||
text = tf.compat.as_text(text)
|
||||
if self._do_lower_case:
|
||||
text = text.lower()
|
||||
tokens = re.compile(r"[^\w\']+").split(text.strip())
|
||||
# Filters out any empty strings in `tokens`.
|
||||
return list(filter(None, tokens))
|
||||
|
||||
def _tokenize_and_pad(self, text: str) -> Sequence[int]:
|
||||
"""Tokenizes `text` and pads the tokens to `seq_len`.
|
||||
|
||||
Args:
|
||||
text: Text to be tokenized and padded.
|
||||
|
||||
Returns:
|
||||
List of token IDs padded to have length `seq_len`.
|
||||
"""
|
||||
tokens = self._regex_tokenize(text)
|
||||
|
||||
# Gets ids for START, PAD and UNKNOWN tokens.
|
||||
start_id = self._vocab[self.START]
|
||||
pad_id = self._vocab[self.PAD]
|
||||
unknown_id = self._vocab[self.UNKNOWN]
|
||||
|
||||
token_ids = [self._vocab.get(token, unknown_id) for token in tokens]
|
||||
token_ids = [start_id] + token_ids
|
||||
|
||||
if len(token_ids) < self._seq_len:
|
||||
pad_length = self._seq_len - len(token_ids)
|
||||
token_ids = token_ids + pad_length * [pad_id]
|
||||
else:
|
||||
token_ids = token_ids[:self._seq_len]
|
||||
return token_ids
|
||||
|
||||
def preprocess(
|
||||
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||
"""Preprocesses data into input for an Average Word Embedding model.
|
||||
|
||||
Args:
|
||||
dataset: Stores (text, label) data.
|
||||
|
||||
Returns:
|
||||
Dataset containing (token IDs, label) data.
|
||||
"""
|
||||
token_ids_list = []
|
||||
labels_list = []
|
||||
for text, label in dataset.gen_tf_dataset():
|
||||
_validate_text_and_label(text, label)
|
||||
token_ids = self._tokenize_and_pad(text.numpy()[0].decode("utf-8"))
|
||||
token_ids_list.append(token_ids)
|
||||
labels_list.append(label.numpy()[0])
|
||||
|
||||
token_ids_ds = tf.data.Dataset.from_tensor_slices(token_ids_list)
|
||||
labels_ds = tf.data.Dataset.from_tensor_slices(labels_list)
|
||||
preprocessed_ds = tf.data.Dataset.zip((token_ids_ds, labels_ds))
|
||||
return text_classifier_ds.Dataset(
|
||||
dataset=preprocessed_ds,
|
||||
size=dataset.size,
|
||||
label_names=dataset.label_names)
|
||||
|
||||
|
||||
class BertClassifierPreprocessor:
|
||||
"""Preprocessor for a BERT-based classifier.
|
||||
|
||||
Attributes:
|
||||
seq_len: Length of the input sequence to the model.
|
||||
vocab_file: File containing the BERT vocab.
|
||||
tokenizer: BERT tokenizer.
|
||||
"""
|
||||
|
||||
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
|
||||
self._seq_len = seq_len
|
||||
# Vocab filepath is tied to the BERT module's URI.
|
||||
self._vocab_file = os.path.join(
|
||||
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
|
||||
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
|
||||
do_lower_case)
|
||||
|
||||
def _get_name_to_features(self):
|
||||
"""Gets the dictionary mapping record keys to feature types."""
|
||||
return {
|
||||
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
|
||||
"label_ids": tf.io.FixedLenFeature([], tf.int64),
|
||||
}
|
||||
|
||||
def get_vocab_file(self) -> str:
|
||||
"""Returns the vocab file of the BertClassifierPreprocessor."""
|
||||
return self._vocab_file
|
||||
|
||||
def preprocess(
|
||||
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
|
||||
"""Preprocesses data into input for a BERT-based classifier.
|
||||
|
||||
Args:
|
||||
dataset: Stores (text, label) data.
|
||||
|
||||
Returns:
|
||||
Dataset containing (bert_features, label) data.
|
||||
"""
|
||||
examples = []
|
||||
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
|
||||
_validate_text_and_label(text, label)
|
||||
examples.append(
|
||||
classifier_data_lib.InputExample(
|
||||
guid=str(index),
|
||||
text_a=text.numpy()[0].decode("utf-8"),
|
||||
text_b=None,
|
||||
# InputExample expects the label name rather than the int ID
|
||||
label=dataset.label_names[label.numpy()[0]]))
|
||||
|
||||
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
|
||||
classifier_data_lib.file_based_convert_examples_to_features(
|
||||
examples=examples,
|
||||
label_list=dataset.label_names,
|
||||
max_seq_length=self._seq_len,
|
||||
tokenizer=self._tokenizer,
|
||||
output_file=tfrecord_file)
|
||||
preprocessed_ds = _single_file_dataset(tfrecord_file,
|
||||
self._get_name_to_features())
|
||||
return text_classifier_ds.Dataset(
|
||||
dataset=preprocessed_ds,
|
||||
size=dataset.size,
|
||||
label_names=dataset.label_names)
|
||||
|
||||
|
||||
TextClassifierPreprocessor = (
|
||||
Union[BertClassifierPreprocessor,
|
||||
AverageWordEmbeddingClassifierPreprocessor])
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import csv
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||
|
||||
|
||||
class PreprocessorTest(tf.test.TestCase):
|
||||
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
|
||||
text_column='text', label_column='label')
|
||||
|
||||
def _get_csv_file(self):
|
||||
labels_and_text = (('pos', 'super super super super good'),
|
||||
(('neg', 'really bad')))
|
||||
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||
if os.path.exists(csv_file):
|
||||
return csv_file
|
||||
fieldnames = ['text', 'label']
|
||||
with open(csv_file, 'w') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for label, text in labels_and_text:
|
||||
writer.writerow({'text': text, 'label': label})
|
||||
return csv_file
|
||||
|
||||
def test_average_word_embedding_preprocessor(self):
|
||||
csv_file = self._get_csv_file()
|
||||
dataset = text_classifier_ds.Dataset.from_csv(
|
||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||
average_word_embedding_preprocessor = (
|
||||
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
|
||||
seq_len=5,
|
||||
do_lower_case=True,
|
||||
texts=['super super super super good', 'really bad'],
|
||||
vocab_size=7))
|
||||
preprocessed_dataset = (
|
||||
average_word_embedding_preprocessor.preprocess(dataset))
|
||||
labels = []
|
||||
features_list = []
|
||||
for features, label in preprocessed_dataset.gen_tf_dataset():
|
||||
self.assertEqual(label.shape, [1])
|
||||
labels.append(label.numpy()[0])
|
||||
self.assertEqual(features.shape, [1, 5])
|
||||
features_list.append(features.numpy()[0])
|
||||
self.assertEqual(labels, [1, 0])
|
||||
npt.assert_array_equal(
|
||||
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
|
||||
|
||||
def test_bert_preprocessor(self):
|
||||
csv_file = self._get_csv_file()
|
||||
dataset = text_classifier_ds.Dataset.from_csv(
|
||||
filename=csv_file, csv_params=self.CSV_PARAMS_)
|
||||
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
|
||||
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
|
||||
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
|
||||
labels = []
|
||||
input_masks = []
|
||||
for features, label in preprocessed_dataset.gen_tf_dataset():
|
||||
self.assertEqual(label.shape, [1])
|
||||
labels.append(label.numpy()[0])
|
||||
self.assertSameElements(
|
||||
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
|
||||
for feature in features.values():
|
||||
self.assertEqual(feature.shape, [1, 5])
|
||||
input_masks.append(features['input_mask'].numpy()[0])
|
||||
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
|
||||
[0, 0, 0, 0, 0])
|
||||
npt.assert_array_equal(
|
||||
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
|
||||
self.assertEqual(labels, [1, 0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load compressed models from tensorflow_hub
|
||||
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||
tf.test.main()
|
23
mediapipe/model_maker/python/text/text_classifier/testdata/BUILD
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = ["average_word_embedding_metadata.json"],
|
||||
)
|
63
mediapipe/model_maker/python/text/text_classifier/testdata/average_word_embedding_metadata.json
vendored
Normal file
|
@ -0,0 +1,63 @@
|
|||
{
|
||||
"name": "TextClassifier",
|
||||
"description": "Classify the input text into a set of known categories.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Embedding vectors representing the input text to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "RegexTokenizerOptions",
|
||||
"options": {
|
||||
"delim_regex_pattern": "[^\\w\\']+",
|
||||
"vocab_file": [
|
||||
{
|
||||
"name": "vocab.txt",
|
||||
"description": "Vocabulary file to convert natural language words to embedding vectors.",
|
||||
"type": "VOCABULARY"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "score",
|
||||
"description": "Score of the labels respectively.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.2.1"
|
||||
}
|
|
@ -0,0 +1,437 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""API for text classification."""
|
||||
|
||||
import abc
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
|
||||
from official.nlp import optimization
|
||||
|
||||
|
||||
def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||
"""Validates that `model_options` and `supported_model` are compatible.
|
||||
|
||||
Args:
|
||||
options: Options for creating and training a text classifier.
|
||||
|
||||
Raises:
|
||||
ValueError if there is a mismatch between `model_options` and
|
||||
`supported_model`.
|
||||
"""
|
||||
|
||||
if options.model_options is None:
|
||||
return
|
||||
|
||||
if (isinstance(options.model_options,
|
||||
mo.AverageWordEmbeddingClassifierOptions) and
|
||||
(options.supported_model !=
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||
f" got {options.supported_model}")
|
||||
if (isinstance(options.model_options, mo.BertClassifierOptions) and
|
||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||
raise ValueError(
|
||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||
|
||||
|
||||
class TextClassifier(classifier.Classifier):
|
||||
"""API for creating and training a text classification model."""
|
||||
|
||||
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
|
||||
label_names: Sequence[str]):
|
||||
super().__init__(
|
||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||
self._model_spec = model_spec
|
||||
self._hparams = hparams
|
||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||
options: text_classifier_options.TextClassifierOptions
|
||||
) -> "TextClassifier":
|
||||
"""Factory function that creates and trains a text classifier.
|
||||
|
||||
Note that `train_data` and `validation_data` are expected to share the same
|
||||
`label_names` since they should be split from the same dataset.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
options: Options for creating and training the text classifier.
|
||||
|
||||
Returns:
|
||||
A text classifier.
|
||||
|
||||
Raises:
|
||||
ValueError if `train_data` and `validation_data` do not have the
|
||||
same label_names or `options` contains an unknown `supported_model`
|
||||
"""
|
||||
if train_data.label_names != validation_data.label_names:
|
||||
raise ValueError(
|
||||
f"Training data label names {train_data.label_names} not equal to "
|
||||
f"validation data label names {validation_data.label_names}")
|
||||
|
||||
_validate(options)
|
||||
if options.model_options is None:
|
||||
options.model_options = options.supported_model.value().model_options
|
||||
|
||||
if options.hparams is None:
|
||||
options.hparams = options.supported_model.value().hparams
|
||||
|
||||
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||
text_classifier = (
|
||||
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||
options,
|
||||
train_data.label_names))
|
||||
elif (options.supported_model ==
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
text_classifier = (
|
||||
_AverageWordEmbeddingClassifier
|
||||
.create_average_word_embedding_classifier(train_data, validation_data,
|
||||
options,
|
||||
train_data.label_names))
|
||||
else:
|
||||
raise ValueError(f"Unknown model {options.supported_model}")
|
||||
|
||||
return text_classifier
|
||||
|
||||
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
|
||||
"""Overrides Classifier.evaluate().
|
||||
|
||||
Args:
|
||||
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
||||
batch_size: Number of samples per evaluation step.
|
||||
|
||||
Returns:
|
||||
The loss value and accuracy.
|
||||
|
||||
Raises:
|
||||
ValueError if `data` is not a TextClassifier Dataset.
|
||||
"""
|
||||
# This override is needed because TextClassifier preprocesses its data
|
||||
# outside of the `gen_tf_dataset()` method. The preprocess call also
|
||||
# requires a TextClassifier Dataset instead of a core Dataset.
|
||||
if not isinstance(data, text_ds.Dataset):
|
||||
raise ValueError("Need a TextClassifier Dataset.")
|
||||
|
||||
processed_data = self._text_preprocessor.preprocess(data)
|
||||
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
||||
return self._model.evaluate(dataset)
|
||||
|
||||
def export_model(
|
||||
self,
|
||||
model_name: str = "model.tflite",
|
||||
quantization_config: Optional[quantization.QuantizationConfig] = None):
|
||||
"""Converts and saves the model to a TFLite file with metadata included.
|
||||
|
||||
Note that only the TFLite file is needed for deployment. This function also
|
||||
saves a metadata.json file to the same directory as the TFLite file which
|
||||
can be used to interpret the metadata content in the TFLite file.
|
||||
|
||||
Args:
|
||||
model_name: File name to save TFLite model with metadata. The full export
|
||||
path is {self._hparams.export_dir}/{model_name}.
|
||||
quantization_config: The configuration for model quantization.
|
||||
"""
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
||||
|
||||
tflite_model = model_util.convert_to_tflite(
|
||||
model=self._model, quantization_config=quantization_config)
|
||||
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
|
||||
self._save_vocab(vocab_filepath)
|
||||
|
||||
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
|
||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||
with open(metadata_file, "w") as f:
|
||||
f.write(metadata_json)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _save_vocab(self, vocab_filepath: str):
|
||||
"""Saves the preprocessor's vocab to `vocab_filepath`."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||
"""Gets the metadata writer for the text classifier TFLite model."""
|
||||
|
||||
|
||||
class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||
"""APIs to help create and train an Average Word Embedding text classifier."""
|
||||
|
||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||
|
||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||
model_options: mo.AverageWordEmbeddingClassifierOptions,
|
||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
self._loss_function = "sparse_categorical_crossentropy"
|
||||
self._metric_function = "accuracy"
|
||||
self._text_preprocessor: (
|
||||
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
||||
|
||||
@classmethod
|
||||
def create_average_word_embedding_classifier(
|
||||
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||
options: text_classifier_options.TextClassifierOptions,
|
||||
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
|
||||
"""Creates, trains, and returns an Average Word Embedding classifier.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
options: Options for creating and training the text classifier.
|
||||
label_names: Label names used in the data.
|
||||
|
||||
Returns:
|
||||
An Average Word Embedding classifier.
|
||||
"""
|
||||
average_word_embedding_classifier = _AverageWordEmbeddingClassifier(
|
||||
model_spec=options.supported_model.value(),
|
||||
model_options=options.model_options,
|
||||
hparams=options.hparams,
|
||||
label_names=train_data.label_names)
|
||||
average_word_embedding_classifier._create_and_train_model(
|
||||
train_data, validation_data)
|
||||
return average_word_embedding_classifier
|
||||
|
||||
def _create_and_train_model(self, train_data: text_ds.Dataset,
|
||||
validation_data: text_ds.Dataset):
|
||||
"""Creates the Average Word Embedding classifier keras model and trains it.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
"""
|
||||
(processed_train_data, processed_validation_data) = (
|
||||
self._load_and_run_preprocessor(train_data, validation_data))
|
||||
self._create_model()
|
||||
self._optimizer = "rmsprop"
|
||||
self._train_model(processed_train_data, processed_validation_data)
|
||||
|
||||
def _load_and_run_preprocessor(
|
||||
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
|
||||
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
|
||||
"""Runs an AverageWordEmbeddingClassifierPreprocessor on the data.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
|
||||
Returns:
|
||||
Preprocessed training data and preprocessed validation data.
|
||||
"""
|
||||
train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()]
|
||||
validation_texts = [
|
||||
text.numpy()[0] for text, _ in validation_data.gen_tf_dataset()
|
||||
]
|
||||
self._text_preprocessor = (
|
||||
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
|
||||
seq_len=self._model_options.seq_len,
|
||||
do_lower_case=self._model_options.do_lower_case,
|
||||
texts=train_texts + validation_texts,
|
||||
vocab_size=self._model_options.vocab_size))
|
||||
return self._text_preprocessor.preprocess(
|
||||
train_data), self._text_preprocessor.preprocess(validation_data)
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates an Average Word Embedding model."""
|
||||
self._model = tf.keras.Sequential([
|
||||
tf.keras.layers.InputLayer(
|
||||
input_shape=[self._model_options.seq_len], dtype=tf.int32),
|
||||
tf.keras.layers.Embedding(
|
||||
len(self._text_preprocessor.get_vocab()),
|
||||
self._model_options.wordvec_dim,
|
||||
input_length=self._model_options.seq_len),
|
||||
tf.keras.layers.GlobalAveragePooling1D(),
|
||||
tf.keras.layers.Dense(
|
||||
self._model_options.wordvec_dim, activation=tf.nn.relu),
|
||||
tf.keras.layers.Dropout(self._model_options.dropout_rate),
|
||||
tf.keras.layers.Dense(self._num_classes, activation="softmax")
|
||||
])
|
||||
|
||||
def _save_vocab(self, vocab_filepath: str):
|
||||
with tf.io.gfile.GFile(vocab_filepath, "w") as f:
|
||||
for token, index in self._text_preprocessor.get_vocab().items():
|
||||
f.write(f"{token} {index}\n")
|
||||
|
||||
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||
return text_classifier_writer.MetadataWriter.create_for_regex_model(
|
||||
model_buffer=tflite_model,
|
||||
regex_tokenizer=metadata_writer.RegexTokenizer(
|
||||
# TODO: Align with MediaPipe's RegexTokenizer.
|
||||
delim_regex_pattern=self._DELIM_REGEX_PATTERN,
|
||||
vocab_file_path=vocab_filepath),
|
||||
labels=metadata_writer.Labels().add(list(self._label_names)))
|
||||
|
||||
|
||||
class _BertClassifier(TextClassifier):
|
||||
"""APIs to help create and train a BERT-based text classifier."""
|
||||
|
||||
_INITIALIZER_RANGE = 0.02
|
||||
|
||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
|
||||
label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
"test_accuracy", dtype=tf.float32)
|
||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||
|
||||
@classmethod
|
||||
def create_bert_classifier(
|
||||
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
|
||||
options: text_classifier_options.TextClassifierOptions,
|
||||
label_names: Sequence[str]) -> "_BertClassifier":
|
||||
"""Creates, trains, and returns a BERT-based classifier.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
options: Options for creating and training the text classifier.
|
||||
label_names: Label names used in the data.
|
||||
|
||||
Returns:
|
||||
A BERT-based classifier.
|
||||
"""
|
||||
bert_classifier = _BertClassifier(
|
||||
model_spec=options.supported_model.value(),
|
||||
model_options=options.model_options,
|
||||
hparams=options.hparams,
|
||||
label_names=train_data.label_names)
|
||||
bert_classifier._create_and_train_model(train_data, validation_data)
|
||||
return bert_classifier
|
||||
|
||||
def _create_and_train_model(self, train_data: text_ds.Dataset,
|
||||
validation_data: text_ds.Dataset):
|
||||
"""Creates the BERT-based classifier keras model and trains it.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
"""
|
||||
(processed_train_data, processed_validation_data) = (
|
||||
self._load_and_run_preprocessor(train_data, validation_data))
|
||||
self._create_model()
|
||||
self._create_optimizer(processed_train_data)
|
||||
self._train_model(processed_train_data, processed_validation_data)
|
||||
|
||||
def _load_and_run_preprocessor(
|
||||
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
|
||||
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
|
||||
"""Loads a BertClassifierPreprocessor and runs it on the data.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
|
||||
Returns:
|
||||
Preprocessed training data and preprocessed validation data.
|
||||
"""
|
||||
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
|
||||
seq_len=self._model_options.seq_len,
|
||||
do_lower_case=self._model_spec.do_lower_case,
|
||||
uri=self._model_spec.uri)
|
||||
return (self._text_preprocessor.preprocess(train_data),
|
||||
self._text_preprocessor.preprocess(validation_data))
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates a BERT-based classifier model.
|
||||
|
||||
The model architecture consists of stacking a dense classification layer and
|
||||
dropout layer on top of the BERT encoder outputs.
|
||||
"""
|
||||
encoder_inputs = dict(
|
||||
input_word_ids=tf.keras.layers.Input(
|
||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||
input_mask=tf.keras.layers.Input(
|
||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||
input_type_ids=tf.keras.layers.Input(
|
||||
shape=(self._model_options.seq_len,), dtype=tf.int32),
|
||||
)
|
||||
encoder = hub.KerasLayer(
|
||||
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
|
||||
encoder_outputs = encoder(encoder_inputs)
|
||||
pooled_output = encoder_outputs["pooled_output"]
|
||||
|
||||
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
|
||||
pooled_output)
|
||||
initializer = tf.keras.initializers.TruncatedNormal(
|
||||
stddev=self._INITIALIZER_RANGE)
|
||||
output = tf.keras.layers.Dense(
|
||||
self._num_classes,
|
||||
kernel_initializer=initializer,
|
||||
name="output",
|
||||
activation="softmax",
|
||||
dtype=tf.float32)(
|
||||
output)
|
||||
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
|
||||
|
||||
def _create_optimizer(self, train_data: text_ds.Dataset):
|
||||
"""Loads an optimizer with a learning rate schedule.
|
||||
|
||||
The decay steps in the learning rate schedule depend on the
|
||||
`steps_per_epoch` which may depend on the size of the training data.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
"""
|
||||
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=self._hparams.steps_per_epoch,
|
||||
batch_size=self._hparams.batch_size,
|
||||
train_data=train_data)
|
||||
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
|
||||
warmup_steps = int(total_steps * 0.1)
|
||||
initial_lr = self._hparams.learning_rate
|
||||
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
|
||||
warmup_steps)
|
||||
|
||||
def _save_vocab(self, vocab_filepath: str):
|
||||
tf.io.gfile.copy(
|
||||
self._text_preprocessor.get_vocab_file(),
|
||||
vocab_filepath,
|
||||
overwrite=True)
|
||||
|
||||
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
|
||||
return text_classifier_writer.MetadataWriter.create_for_bert_model(
|
||||
model_buffer=tflite_model,
|
||||
tokenizer=metadata_writer.BertTokenizer(vocab_filepath),
|
||||
labels=metadata_writer.Labels().add(list(self._label_names)),
|
||||
ids_name=self._model_spec.tflite_input_name["ids"],
|
||||
mask_name=self._model_spec.tflite_input_name["mask"],
|
||||
segment_name=self._model_spec.tflite_input_name["segment_ids"])
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Demo for making a text classifier model by MediaPipe Model Maker."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Dependency imports
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def define_flags():
|
||||
flags.DEFINE_string('export_dir', None,
|
||||
'The directory to save exported files.')
|
||||
flags.DEFINE_enum('supported_model', 'average_word_embedding',
|
||||
['average_word_embedding', 'bert'],
|
||||
'The text classifier to run.')
|
||||
flags.mark_flag_as_required('export_dir')
|
||||
|
||||
|
||||
def download_demo_data():
|
||||
"""Downloads demo data, and returns directory path."""
|
||||
data_path = tf.keras.utils.get_file(
|
||||
fname='SST-2.zip',
|
||||
origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
|
||||
extract=True)
|
||||
return os.path.join(os.path.dirname(data_path), 'SST-2') # folder name
|
||||
|
||||
|
||||
def run(data_dir,
|
||||
export_dir=tempfile.mkdtemp(),
|
||||
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
"""Runs demo."""
|
||||
|
||||
# Gets training data and validation data.
|
||||
csv_params = text_ds.CSVParameters(
|
||||
text_column='sentence', label_column='label', delimiter='\t')
|
||||
train_data = text_ds.Dataset.from_csv(
|
||||
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
|
||||
csv_params=csv_params)
|
||||
validation_data = text_ds.Dataset.from_csv(
|
||||
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
|
||||
csv_params=csv_params)
|
||||
|
||||
quantization_config = None
|
||||
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
|
||||
hparams = hp.BaseHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
||||
# Warning: This takes extremely long to run on CPU
|
||||
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||
hparams = hp.BaseHParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
||||
|
||||
# Fine-tunes the model.
|
||||
options = text_classifier_options.TextClassifierOptions(
|
||||
supported_model=supported_model, hparams=hparams)
|
||||
model = text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options)
|
||||
|
||||
# Gets evaluation results.
|
||||
_, acc = model.evaluate(validation_data)
|
||||
print('Eval accuracy: %f' % acc)
|
||||
|
||||
model.export_model(quantization_config=quantization_config)
|
||||
model.export_labels(export_dir=options.hparams.export_dir)
|
||||
|
||||
|
||||
def main(_):
|
||||
logging.set_verbosity(logging.INFO)
|
||||
data_dir = download_demo_data()
|
||||
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||
|
||||
if FLAGS.supported_model == 'average_word_embedding':
|
||||
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||
elif FLAGS.supported_model == 'bert':
|
||||
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
|
||||
run(data_dir, export_dir, supported_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
define_flags()
|
||||
app.run(main)
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""User-facing customization options to create and train a text classifier."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TextClassifierOptions:
|
||||
"""User-facing options for creating the text classifier.
|
||||
|
||||
Attributes:
|
||||
supported_model: A preconfigured model spec.
|
||||
hparams: Training hyperparameters the user can set to override the ones in
|
||||
`supported_model`.
|
||||
model_options: Model options the user can set to override the ones in
|
||||
`supported_model`. The model options type should be consistent with the
|
||||
architecture of the `supported_model`.
|
||||
"""
|
||||
supported_model: ms.SupportedModels
|
||||
hparams: Optional[hp.BaseHParams] = None
|
||||
model_options: Optional[mo.TextClassifierModelOptions] = None
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import csv
|
||||
import filecmp
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
class TextClassifierTest(tf.test.TestCase):
|
||||
|
||||
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
|
||||
|
||||
def _get_data(self):
|
||||
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
|
||||
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
|
||||
if os.path.exists(csv_file):
|
||||
return csv_file
|
||||
fieldnames = ['text', 'label']
|
||||
with open(csv_file, 'w') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for label, text in labels_and_text:
|
||||
writer.writerow({'text': text, 'label': label})
|
||||
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||
all_data = dataset.Dataset.from_csv(
|
||||
filename=csv_file, csv_params=csv_params)
|
||||
return all_data.split(0.5)
|
||||
|
||||
def test_create_and_train_average_word_embedding_model(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
options = text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
|
||||
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
|
||||
average_word_embedding_classifier = text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options)
|
||||
|
||||
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(accuracy, 0.0)
|
||||
|
||||
# Test export_model
|
||||
average_word_embedding_classifier.export_model()
|
||||
output_metadata_file = os.path.join(options.hparams.export_dir,
|
||||
'metadata.json')
|
||||
output_tflite_file = os.path.join(options.hparams.export_dir,
|
||||
'model.tflite')
|
||||
|
||||
self.assertTrue(os.path.exists(output_tflite_file))
|
||||
self.assertGreater(os.path.getsize(output_tflite_file), 0)
|
||||
|
||||
self.assertTrue(os.path.exists(output_metadata_file))
|
||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
self.assertTrue(
|
||||
filecmp.cmp(output_metadata_file,
|
||||
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
|
||||
|
||||
def test_create_and_train_bert(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
options = text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
|
||||
hparams=hp.BaseHParams(
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
learning_rate=3e-5,
|
||||
distribution_strategy='off'))
|
||||
bert_classifier = text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options)
|
||||
|
||||
_, accuracy = bert_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(accuracy, 0.0)
|
||||
# TODO: Add a unit test that does not run OOM.
|
||||
|
||||
def test_label_mismatch(self):
|
||||
options = (
|
||||
text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
|
||||
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
|
||||
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Training data label names .* not equal to validation data label names'
|
||||
):
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options)
|
||||
|
||||
def test_options_mismatch(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
|
||||
avg_options = (
|
||||
text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=mo.AverageWordEmbeddingClassifierOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
avg_options)
|
||||
|
||||
bert_options = (
|
||||
text_classifier_options.TextClassifierOptions(
|
||||
supported_model=(
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
model_options=mo.BertClassifierOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
bert_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Load compressed models from tensorflow_hub
|
||||
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
||||
tf.test.main()
|
165
mediapipe/model_maker/python/vision/gesture_recognizer/BUILD
Normal file
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
# Placeholder for internal Python strict library and test compatibility macro.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_visibility = ["//mediapipe:__subpackages__"],
|
||||
)
|
||||
|
||||
# TODO: Remove the unncessary test data once the demo data are moved to an open-sourced
|
||||
# directory.
|
||||
filegroup(
|
||||
name = "test_data",
|
||||
srcs = glob([
|
||||
"test_data/**",
|
||||
]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "constants",
|
||||
srcs = ["constants.py"],
|
||||
)
|
||||
|
||||
# TODO: Change to py_library after migrating the MediaPipe hand solution
|
||||
# library to MediaPipe hand task library.
|
||||
py_library(
|
||||
name = "dataset",
|
||||
srcs = ["dataset.py"],
|
||||
deps = [
|
||||
":constants",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/data:data_util",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/python/solutions:hands",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_test",
|
||||
srcs = ["dataset_test.py"],
|
||||
data = [
|
||||
":test_data",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//mediapipe/python/solutions:hands",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "hyperparameters",
|
||||
srcs = ["hyperparameters.py"],
|
||||
deps = [
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_options",
|
||||
srcs = ["model_options.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gesture_recognizer_options",
|
||||
srcs = ["gesture_recognizer_options.py"],
|
||||
deps = [
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gesture_recognizer",
|
||||
srcs = ["gesture_recognizer.py"],
|
||||
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
|
||||
deps = [
|
||||
":constants",
|
||||
":gesture_recognizer_options",
|
||||
":hyperparameters",
|
||||
":metadata_writer",
|
||||
":model_options",
|
||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||
"//mediapipe/model_maker/python/core/utils:loss_functions",
|
||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gesture_recognizer_import",
|
||||
srcs = ["__init__.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":gesture_recognizer",
|
||||
":gesture_recognizer_options",
|
||||
":hyperparameters",
|
||||
":model_options",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "metadata_writer",
|
||||
srcs = ["metadata_writer.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "gesture_recognizer_test",
|
||||
size = "large",
|
||||
srcs = ["gesture_recognizer_test.py"],
|
||||
data = [
|
||||
":test_data",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
shard_count = 2,
|
||||
deps = [
|
||||
":gesture_recognizer_import",
|
||||
"//mediapipe/model_maker/python/core/utils:test_util",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metadata_writer_test",
|
||||
srcs = ["metadata_writer_test.py"],
|
||||
data = [
|
||||
":test_data",
|
||||
],
|
||||
deps = [
|
||||
":metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "gesture_recognizer_demo",
|
||||
srcs = ["gesture_recognizer_demo.py"],
|
||||
data = [
|
||||
":test_data",
|
||||
"//mediapipe/model_maker/models/gesture_recognizer:models",
|
||||
],
|
||||
python_version = "PY3",
|
||||
deps = [":gesture_recognizer_import"],
|
||||
)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe Model Maker Python Public API For Gesture Recognizer."""
|
||||
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
|
||||
|
||||
GestureRecognizer = gesture_recognizer.GestureRecognizer
|
||||
ModelOptions = model_options.GestureRecognizerModelOptions
|
||||
HParams = hyperparameters.HParams
|
||||
Dataset = dataset.Dataset
|
||||
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
|
||||
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Gesture recognition constants."""
|
||||
|
||||
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
|
||||
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
|
||||
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
|
||||
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
|
||||
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'
|
|
@ -0,0 +1,238 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Gesture recognition dataset library."""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import random
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
import cv2
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||
from mediapipe.model_maker.python.core.data import data_util
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||
from mediapipe.python.solutions import hands as mp_hands
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HandDataPreprocessingParams:
|
||||
"""A dataclass wraps the hand data preprocessing hyperparameters.
|
||||
|
||||
Attributes:
|
||||
shuffle: A boolean controlling if shuffle the dataset. Default to true.
|
||||
min_detection_confidence: confidence threshold for hand detection.
|
||||
"""
|
||||
shuffle: bool = True
|
||||
min_detection_confidence: float = 0.7
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HandData:
|
||||
"""A dataclass represents hand data for training gesture recognizer model.
|
||||
|
||||
See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
|
||||
more details of the hand gesture data API.
|
||||
|
||||
Attributes:
|
||||
hand: normalized hand landmarks of shape 21x3 from the screen based
|
||||
hand-landmark model.
|
||||
world_hand: hand landmarks of shape 21x3 in world coordinates.
|
||||
handedness: Collection of handedness confidence of the detected hands (i.e.
|
||||
is it a left or right hand).
|
||||
"""
|
||||
hand: List[List[float]]
|
||||
world_hand: List[List[float]]
|
||||
handedness: List[float]
|
||||
|
||||
|
||||
def _validate_data_sample(data: NamedTuple) -> bool:
|
||||
"""Validates the input hand data sample.
|
||||
|
||||
Args:
|
||||
data: input hand data sample.
|
||||
|
||||
Returns:
|
||||
False if the input data namedtuple does not contain the fields including
|
||||
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
|
||||
or any of these attributes' values are none. Otherwise, True.
|
||||
"""
|
||||
if (not hasattr(data, 'multi_hand_landmarks') or
|
||||
data.multi_hand_landmarks is None):
|
||||
return False
|
||||
if (not hasattr(data, 'multi_hand_world_landmarks') or
|
||||
data.multi_hand_world_landmarks is None):
|
||||
return False
|
||||
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_hand_data(all_image_paths: List[str],
|
||||
min_detection_confidence: float) -> Optional[HandData]:
|
||||
"""Computes hand data (landmarks and handedness) in the input image.
|
||||
|
||||
Args:
|
||||
all_image_paths: all input image paths.
|
||||
min_detection_confidence: hand detection confidence threshold
|
||||
|
||||
Returns:
|
||||
A HandData object. Returns None if no hand is detected.
|
||||
"""
|
||||
hand_data_result = []
|
||||
with mp_hands.Hands(
|
||||
static_image_mode=True,
|
||||
max_num_hands=1,
|
||||
min_detection_confidence=min_detection_confidence) as hands:
|
||||
for path in all_image_paths:
|
||||
tf.compat.v1.logging.info('Loading image %s', path)
|
||||
image = data_util.load_image(path)
|
||||
# Flip image around y-axis for correct handedness output
|
||||
image = cv2.flip(image, 1)
|
||||
data = hands.process(image)
|
||||
if not _validate_data_sample(data):
|
||||
hand_data_result.append(None)
|
||||
continue
|
||||
hand_landmarks = [[
|
||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
|
||||
hand_world_landmarks = [[
|
||||
hand_landmark.x, hand_landmark.y, hand_landmark.z
|
||||
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
|
||||
handedness_scores = [
|
||||
handedness.score
|
||||
for handedness in data.multi_handedness[0].classification
|
||||
]
|
||||
hand_data_result.append(
|
||||
HandData(
|
||||
hand=hand_landmarks,
|
||||
world_hand=hand_world_landmarks,
|
||||
handedness=handedness_scores))
|
||||
return hand_data_result
|
||||
|
||||
|
||||
class Dataset(classification_dataset.ClassificationDataset):
|
||||
"""Dataset library for hand gesture recognizer."""
|
||||
|
||||
@classmethod
|
||||
def from_folder(
|
||||
cls,
|
||||
dirname: str,
|
||||
hparams: Optional[HandDataPreprocessingParams] = None
|
||||
) -> classification_dataset.ClassificationDataset:
|
||||
"""Loads images and labels from the given directory.
|
||||
|
||||
Directory contents are expected to be in the format:
|
||||
<root_dir>/<gesture_name>/*.jpg". One of the `gesture_name` must be `none`
|
||||
(case insensitive). The `none` sub-directory is expected to contain images
|
||||
of hands that don't belong to other gesture classes in <root_dir>. Assumes
|
||||
the image data of the same label are in the same subdirectory.
|
||||
|
||||
Args:
|
||||
dirname: Name of the directory containing the data files.
|
||||
hparams: Optional hyperparameters for processing input hand gesture
|
||||
images.
|
||||
|
||||
Returns:
|
||||
Dataset containing landmarks, labels, and other related info.
|
||||
|
||||
Raises:
|
||||
ValueError: if the input data directory is empty or the label set does not
|
||||
contain label 'none' (case insensitive).
|
||||
"""
|
||||
data_root = os.path.abspath(dirname)
|
||||
|
||||
# Assumes the image data of the same label are in the same subdirectory,
|
||||
# gets image path and label names.
|
||||
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||
if not all_image_paths:
|
||||
raise ValueError('Image dataset directory is empty.')
|
||||
|
||||
if not hparams:
|
||||
hparams = HandDataPreprocessingParams()
|
||||
|
||||
if hparams.shuffle:
|
||||
# Random shuffle data.
|
||||
random.shuffle(all_image_paths)
|
||||
|
||||
label_names = sorted(
|
||||
name for name in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, name)))
|
||||
if 'none' not in [v.lower() for v in label_names]:
|
||||
raise ValueError('Label set does not contain label "None".')
|
||||
# Move label 'none' to the front of label list.
|
||||
none_idx = [v.lower() for v in label_names].index('none')
|
||||
none_value = label_names.pop(none_idx)
|
||||
label_names.insert(0, none_value)
|
||||
|
||||
index_by_label = dict(
|
||||
(name, index) for index, name in enumerate(label_names))
|
||||
all_gesture_indices = [
|
||||
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||
for path in all_image_paths
|
||||
]
|
||||
|
||||
# Compute hand data (including local hand landmark, world hand landmark, and
|
||||
# handedness) for all the input images.
|
||||
hand_data = _get_hand_data(
|
||||
all_image_paths=all_image_paths,
|
||||
min_detection_confidence=hparams.min_detection_confidence)
|
||||
|
||||
# Get a list of the valid hand landmark sample in the hand data list.
|
||||
valid_indices = [
|
||||
i for i in range(len(hand_data)) if hand_data[i] is not None
|
||||
]
|
||||
# Remove 'None' element from the hand data and label list.
|
||||
valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices]
|
||||
if not valid_hand_data:
|
||||
raise ValueError('No valid hand is detected.')
|
||||
|
||||
valid_label = [all_gesture_indices[i] for i in valid_indices]
|
||||
|
||||
# Convert list of dictionaries to dictionary of lists.
|
||||
hand_data_dict = {
|
||||
k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0]
|
||||
}
|
||||
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
|
||||
|
||||
embedder_model = model_util.load_keras_model(
|
||||
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
|
||||
|
||||
hand_ds = hand_ds.batch(batch_size=1)
|
||||
hand_embedding_ds = hand_ds.map(
|
||||
map_func=lambda feature: embedder_model(dict(feature)),
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
hand_embedding_ds = hand_embedding_ds.unbatch()
|
||||
|
||||
# Create label dataset
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(valid_label, tf.int64))
|
||||
|
||||
label_one_hot_ds = label_ds.map(
|
||||
map_func=lambda index: tf.one_hot(index, len(label_names)),
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
# Create a dataset with (hand_embedding, one_hot_label) pairs
|
||||
hand_embedding_label_ds = tf.data.Dataset.zip(
|
||||
(hand_embedding_ds, label_one_hot_ds))
|
||||
|
||||
tf.compat.v1.logging.info(
|
||||
'Load valid hands with size: {}, num_label: {}, labels: {}.'.format(
|
||||
len(valid_hand_data), len(label_names), ','.join(label_names)))
|
||||
return Dataset(
|
||||
dataset=hand_embedding_label_ds,
|
||||
size=len(valid_hand_data),
|
||||
label_names=label_names)
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.s
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
from typing import NamedTuple
|
||||
import unittest
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
|
||||
from mediapipe.python.solutions import hands as mp_hands
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_TEST_DATA_DIRNAME = 'raw_data'
|
||||
|
||||
|
||||
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_split(self):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
data = dataset.Dataset.from_folder(
|
||||
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||
train_data, test_data = data.split(0.5)
|
||||
|
||||
self.assertLen(train_data, 17)
|
||||
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertEqual(train_data.num_classes, 4)
|
||||
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
self.assertLen(test_data, 18)
|
||||
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertEqual(test_data.num_classes, 4)
|
||||
self.assertEqual(test_data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
def test_from_folder(self):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
data = dataset.Dataset.from_folder(
|
||||
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertLen(data, 35)
|
||||
self.assertEqual(data.num_classes, 4)
|
||||
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
|
||||
|
||||
def test_create_dataset_from_empty_folder_raise_value_error(self):
|
||||
with self.assertRaisesRegex(ValueError, 'Image dataset directory is empty'):
|
||||
dataset.Dataset.from_folder(
|
||||
dirname=self.get_temp_dir(),
|
||||
hparams=dataset.HandDataPreprocessingParams())
|
||||
|
||||
def test_create_dataset_from_folder_without_none_raise_value_error(self):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
tmp_dir = self.create_tempdir()
|
||||
# Copy input dataset to a temporary directory and skip 'None' directory.
|
||||
for name in os.listdir(input_data_dir):
|
||||
if name == 'none':
|
||||
continue
|
||||
src_dir = os.path.join(input_data_dir, name)
|
||||
dst_dir = os.path.join(tmp_dir, name)
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Label set does not contain label "None"'):
|
||||
dataset.Dataset.from_folder(
|
||||
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||
|
||||
def test_create_dataset_from_folder_with_capital_letter_in_folder_name(self):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
tmp_dir = self.create_tempdir()
|
||||
# Copy input dataset to a temporary directory and change the base folder
|
||||
# name to upper case letter, e.g. 'none' -> 'NONE'
|
||||
for name in os.listdir(input_data_dir):
|
||||
src_dir = os.path.join(input_data_dir, name)
|
||||
dst_dir = os.path.join(tmp_dir, name.upper())
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
upper_base_folder_name = list(os.listdir(tmp_dir))
|
||||
self.assertCountEqual(upper_base_folder_name,
|
||||
['CALL', 'FOUR', 'NONE', 'ROCK'])
|
||||
|
||||
data = dataset.Dataset.from_folder(
|
||||
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
|
||||
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
|
||||
self.assertEqual(elem[0].shape, (1, 128))
|
||||
self.assertEqual(elem[1].shape, ([1, 4]))
|
||||
self.assertLen(data, 35)
|
||||
self.assertEqual(data.num_classes, 4)
|
||||
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_hand_landmark',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmark', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, 2, 3)),
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_hand_world_landmarks',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmark',
|
||||
'multi_handedness'
|
||||
])(1, 2, 3)),
|
||||
dict(
|
||||
testcase_name='invalid_field_name_multi_handed',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handed'
|
||||
])(1, 2, 3)),
|
||||
dict(
|
||||
testcase_name='multi_hand_landmarks_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(None, 2, 3)),
|
||||
dict(
|
||||
testcase_name='multi_hand_world_landmarks_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, None, 3)),
|
||||
dict(
|
||||
testcase_name='multi_handedness_is_none',
|
||||
hand=collections.namedtuple('Hand', [
|
||||
'multi_hand_landmarks', 'multi_hand_world_landmarks',
|
||||
'multi_handedness'
|
||||
])(1, 2, None)),
|
||||
)
|
||||
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
|
||||
with unittest.mock.patch.object(
|
||||
mp_hands.Hands, 'process', return_value=hand):
|
||||
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
|
||||
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
|
||||
dataset.Dataset.from_folder(
|
||||
dirname=input_data_dir,
|
||||
hparams=dataset.HandDataPreprocessingParams())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,239 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""APIs to train gesture recognizer model."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
|
||||
from mediapipe.model_maker.python.core.tasks import classifier
|
||||
from mediapipe.model_maker.python.core.utils import loss_functions
|
||||
from mediapipe.model_maker.python.core.utils import model_util
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
|
||||
|
||||
_EMBEDDING_SIZE = 128
|
||||
|
||||
|
||||
class GestureRecognizer(classifier.Classifier):
|
||||
"""GestureRecognizer for building hand gesture recognizer model.
|
||||
|
||||
Attributes:
|
||||
embedding_size: Size of the input gesture embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(self, label_names: List[str],
|
||||
model_options: model_opt.GestureRecognizerModelOptions,
|
||||
hparams: hp.HParams):
|
||||
"""Initializes GestureRecognizer class.
|
||||
|
||||
Args:
|
||||
label_names: A list of label names for the classes.
|
||||
model_options: options to create gesture recognizer model.
|
||||
hparams: The hyperparameters for training hand gesture recognizer model.
|
||||
"""
|
||||
super().__init__(
|
||||
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
|
||||
self._model_options = model_options
|
||||
self._hparams = hparams
|
||||
self._history = None
|
||||
self.embedding_size = _EMBEDDING_SIZE
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
train_data: classification_ds.ClassificationDataset,
|
||||
validation_data: classification_ds.ClassificationDataset,
|
||||
options: gesture_recognizer_options.GestureRecognizerOptions,
|
||||
) -> 'GestureRecognizer':
|
||||
"""Creates and trains a hand gesture recognizer with input datasets.
|
||||
|
||||
If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
|
||||
directory, the training process will load the weight from the checkpoint
|
||||
file for continual training.
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data. If None, skips validation process.
|
||||
options: options for creating and training gesture recognizer model.
|
||||
|
||||
Returns:
|
||||
An instance of GestureRecognizer.
|
||||
"""
|
||||
if options.model_options is None:
|
||||
options.model_options = model_opt.GestureRecognizerModelOptions()
|
||||
|
||||
if options.hparams is None:
|
||||
options.hparams = hp.HParams()
|
||||
|
||||
gesture_recognizer = cls(
|
||||
label_names=train_data.label_names,
|
||||
model_options=options.model_options,
|
||||
hparams=options.hparams)
|
||||
|
||||
gesture_recognizer._create_model()
|
||||
|
||||
train_dataset = train_data.gen_tf_dataset(
|
||||
batch_size=options.hparams.batch_size,
|
||||
is_training=True,
|
||||
shuffle=options.hparams.shuffle)
|
||||
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
|
||||
steps_per_epoch=options.hparams.steps_per_epoch,
|
||||
batch_size=options.hparams.batch_size,
|
||||
train_data=train_data)
|
||||
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
|
||||
|
||||
validation_dataset = validation_data.gen_tf_dataset(
|
||||
batch_size=options.hparams.batch_size, is_training=False)
|
||||
|
||||
tf.compat.v1.logging.info('Training the gesture recognizer model...')
|
||||
gesture_recognizer._train(
|
||||
train_data=train_dataset, validation_data=validation_dataset)
|
||||
|
||||
return gesture_recognizer
|
||||
|
||||
def _train(self, train_data: tf.data.Dataset,
|
||||
validation_data: tf.data.Dataset):
|
||||
"""Trains the model with input train_data.
|
||||
|
||||
The training results are recorded by a self.History object returned by
|
||||
tf.keras.Model.fit().
|
||||
|
||||
Args:
|
||||
train_data: Training data.
|
||||
validation_data: Validation data.
|
||||
"""
|
||||
hparams = self._hparams
|
||||
|
||||
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
|
||||
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
|
||||
|
||||
job_dir = hparams.export_dir
|
||||
checkpoint_path = os.path.join(job_dir, 'epoch_models')
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||
save_weights_only=True)
|
||||
|
||||
best_model_path = os.path.join(job_dir, 'best_model_weights')
|
||||
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
best_model_path,
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_best_only=True,
|
||||
save_weights_only=True)
|
||||
|
||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(
|
||||
log_dir=os.path.join(job_dir, 'logs'))
|
||||
|
||||
self._model.compile(
|
||||
optimizer='adam',
|
||||
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
|
||||
metrics=['categorical_accuracy'])
|
||||
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||
if latest_checkpoint:
|
||||
print(f'Resuming from {latest_checkpoint}')
|
||||
self._model.load_weights(latest_checkpoint)
|
||||
|
||||
self._history = self._model.fit(
|
||||
x=train_data,
|
||||
epochs=hparams.epochs,
|
||||
validation_data=validation_data,
|
||||
validation_freq=1,
|
||||
callbacks=[
|
||||
checkpoint_callback, best_model_callback, scheduler_callback,
|
||||
tensorboard_callback
|
||||
],
|
||||
)
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates the hand gesture recognizer model.
|
||||
|
||||
The gesture embedding model is pretrained and loaded from a tf.saved_model.
|
||||
"""
|
||||
inputs = tf.keras.Input(
|
||||
shape=[self.embedding_size],
|
||||
batch_size=None,
|
||||
dtype=tf.float32,
|
||||
name='hand_embedding')
|
||||
|
||||
x = tf.keras.layers.BatchNormalization()(inputs)
|
||||
x = tf.keras.layers.ReLU()(x)
|
||||
dropout_rate = self._model_options.dropout_rate
|
||||
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
|
||||
outputs = tf.keras.layers.Dense(
|
||||
self._num_classes,
|
||||
activation='softmax',
|
||||
name='custom_gesture_recognizer')(
|
||||
x)
|
||||
|
||||
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
print(self._model.summary())
|
||||
|
||||
def export_model(self, model_name: str = 'gesture_recognizer.task'):
|
||||
"""Converts the model to TFLite and exports as a model bundle file.
|
||||
|
||||
Saves a model bundle file and metadata json file to hparams.export_dir. The
|
||||
resulting model bundle file will contain necessary models for hand
|
||||
detection, canned gesture classification, and customized gesture
|
||||
classification. Only the model bundle file is needed for the downstream
|
||||
gesture recognition task. The metadata.json file is saved only to
|
||||
interpret the contents of the model bundle file.
|
||||
|
||||
The customized gesture model is in float without quantization. The model is
|
||||
lightweight and there is no need to balance performance and efficiency by
|
||||
quantization. The default score_thresholding is set to 0.5 as it can be
|
||||
adjusted during inference.
|
||||
|
||||
Args:
|
||||
model_name: File name to save model bundle file. The full export path is
|
||||
{export_dir}/{model_name}.
|
||||
"""
|
||||
# TODO: Convert keras embedder model instead of using tflite
|
||||
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
|
||||
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
|
||||
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
|
||||
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
|
||||
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
|
||||
|
||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
||||
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
|
||||
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
|
||||
|
||||
gesture_classifier_options = metadata_writer.GestureClassifierOptions(
|
||||
model_buffer=model_util.convert_to_tflite(self._model),
|
||||
labels=base_metadata_writer.Labels().add(list(self._label_names)),
|
||||
score_thresholding=base_metadata_writer.ScoreThresholding(
|
||||
global_score_threshold=0.5))
|
||||
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
|
||||
gesture_embedding_model_buffer, canned_gesture_model_buffer,
|
||||
gesture_classifier_options)
|
||||
model_bundle_content, metadata_json = writer.populate()
|
||||
with open(model_bundle_file, 'wb') as f:
|
||||
f.write(model_bundle_content)
|
||||
with open(metadata_file, 'w') as f:
|
||||
f.write(metadata_json)
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Demo for making an gesture recognizer model by Mediapipe Model Maker."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
# Dependency imports
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from mediapipe.model_maker.python.vision import gesture_recognizer
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# TODO: Move hand gesture recognizer demo dataset to an
|
||||
# open-sourced directory.
|
||||
TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data'
|
||||
|
||||
|
||||
def define_flags():
|
||||
flags.DEFINE_string('export_dir', None,
|
||||
'The directory to save exported files.')
|
||||
flags.DEFINE_string('input_data_dir', None,
|
||||
'The directory with input training data.')
|
||||
flags.mark_flag_as_required('export_dir')
|
||||
|
||||
|
||||
def run(data_dir: str, export_dir: str):
|
||||
"""Runs demo."""
|
||||
data = gesture_recognizer.Dataset.from_folder(dirname=data_dir)
|
||||
train_data, rest_data = data.split(0.8)
|
||||
validation_data, test_data = rest_data.split(0.5)
|
||||
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=train_data,
|
||||
validation_data=validation_data,
|
||||
options=gesture_recognizer.GestureRecognizerOptions(
|
||||
hparams=gesture_recognizer.HParams(export_dir=export_dir)))
|
||||
|
||||
metric = model.evaluate(test_data, batch_size=2)
|
||||
print('Evaluation metric')
|
||||
print(metric)
|
||||
|
||||
model.export_model()
|
||||
|
||||
|
||||
def main(_):
|
||||
logging.set_verbosity(logging.INFO)
|
||||
|
||||
if FLAGS.input_data_dir is None:
|
||||
data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
|
||||
else:
|
||||
data_dir = FLAGS.input_data_dir
|
||||
|
||||
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||
run(data_dir=data_dir, export_dir=export_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
define_flags()
|
||||
app.run(main)
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Options for building gesture recognizer."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GestureRecognizerOptions:
|
||||
"""Configurable options for building gesture recognizer.
|
||||
|
||||
Attributes:
|
||||
model_options: A set of options for configuring the selected model.
|
||||
hparams: A set of hyperparameters used to train the gesture recognizer.
|
||||
"""
|
||||
model_options: Optional[model_opt.GestureRecognizerModelOptions] = None
|
||||
hparams: Optional[hyperparameters.HParams] = None
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import os
|
||||
from unittest import mock as unittest_mock
|
||||
import zipfile
|
||||
|
||||
import mock
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import test_util
|
||||
from mediapipe.model_maker.python.vision import gesture_recognizer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data'
|
||||
|
||||
|
||||
class GestureRecognizerTest(tf.test.TestCase):
|
||||
|
||||
def _load_data(self):
|
||||
input_data_dir = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'raw_data'))
|
||||
|
||||
data = gesture_recognizer.Dataset.from_folder(
|
||||
dirname=input_data_dir,
|
||||
hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
|
||||
return data
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._model_options = gesture_recognizer.ModelOptions()
|
||||
self._hparams = gesture_recognizer.HParams(epochs=2)
|
||||
self._gesture_recognizer_options = (
|
||||
gesture_recognizer.GestureRecognizerOptions(
|
||||
model_options=self._model_options, hparams=self._hparams))
|
||||
all_data = self._load_data()
|
||||
# Splits data, 90% data for training, 10% for testing
|
||||
self._train_data, self._test_data = all_data.split(0.9)
|
||||
|
||||
def test_gesture_recognizer_model(self):
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
|
||||
self._test_accuracy(model)
|
||||
|
||||
def test_export_gesture_recognizer_model(self):
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
model.export_model()
|
||||
model_bundle_file = os.path.join(self._hparams.export_dir,
|
||||
'gesture_recognizer.task')
|
||||
with zipfile.ZipFile(model_bundle_file) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set(['hand_landmarker.task', 'hand_gesture_recognizer.task']))
|
||||
zf.extractall(self.get_temp_dir())
|
||||
hand_gesture_recognizer_bundle_file = os.path.join(
|
||||
self.get_temp_dir(), 'hand_gesture_recognizer.task')
|
||||
with zipfile.ZipFile(hand_gesture_recognizer_bundle_file) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set([
|
||||
'canned_gesture_classifier.tflite',
|
||||
'custom_gesture_classifier.tflite', 'gesture_embedder.tflite'
|
||||
]))
|
||||
zf.extractall(self.get_temp_dir())
|
||||
gesture_classifier_tflite_file = os.path.join(
|
||||
self.get_temp_dir(), 'custom_gesture_classifier.tflite')
|
||||
test_util.test_tflite_file(
|
||||
keras_model=model._model,
|
||||
tflite_file=gesture_classifier_tflite_file,
|
||||
size=[1, model.embedding_size])
|
||||
|
||||
def _test_accuracy(self, model, threshold=0.5):
|
||||
_, accuracy = model.evaluate(self._test_data)
|
||||
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
|
||||
self.assertGreaterEqual(accuracy, threshold)
|
||||
|
||||
@unittest_mock.patch.object(
|
||||
gesture_recognizer.hyperparameters,
|
||||
'HParams',
|
||||
autospec=True,
|
||||
return_value=gesture_recognizer.HParams(epochs=1))
|
||||
@unittest_mock.patch.object(
|
||||
gesture_recognizer.model_options,
|
||||
'GestureRecognizerModelOptions',
|
||||
autospec=True,
|
||||
return_value=gesture_recognizer.ModelOptions())
|
||||
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
|
||||
self, mock_hparams, mock_model_options):
|
||||
options = gesture_recognizer.GestureRecognizerOptions()
|
||||
gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=options)
|
||||
mock_hparams.assert_called_once()
|
||||
mock_model_options.assert_called_once()
|
||||
|
||||
def test_continual_training_by_loading_checkpoint(self):
|
||||
mock_stdout = io.StringIO()
|
||||
with mock.patch('sys.stdout', mock_stdout):
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
model = gesture_recognizer.GestureRecognizer.create(
|
||||
train_data=self._train_data,
|
||||
validation_data=self._test_data,
|
||||
options=self._gesture_recognizer_options)
|
||||
self._test_accuracy(model)
|
||||
|
||||
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Hyperparameters for training customized gesture recognizer models."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HParams(hp.BaseHParams):
|
||||
"""The hyperparameters for training gesture recognizer.
|
||||
|
||||
Attributes:
|
||||
learning_rate: Learning rate to use for gradient descent training.
|
||||
batch_size: Batch size for training.
|
||||
epochs: Number of training iterations over the dataset.
|
||||
lr_decay: Learning rate decay to use for gradient descent training.
|
||||
gamma: Gamma parameter for focal loss.
|
||||
"""
|
||||
# Parameters from BaseHParams class.
|
||||
learning_rate: float = 0.001
|
||||
batch_size: int = 2
|
||||
epochs: int = 10
|
||||
|
||||
# Parameters about training configuration
|
||||
# TODO: Move lr_decay to hp.baseHParams.
|
||||
lr_decay: float = 0.99
|
||||
gamma: int = 2
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and creates model asset bundle for gesture recognizer."""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Union
|
||||
|
||||
import tensorflow as tf
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||
|
||||
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
|
||||
_HAND_LANDMARKER_BUNDLE_NAME = "hand_landmarker.task"
|
||||
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME = "hand_gesture_recognizer.task"
|
||||
_GESTURE_EMBEDDER_TFLITE_NAME = "gesture_embedder.tflite"
|
||||
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME = "canned_gesture_classifier.tflite"
|
||||
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME = "custom_gesture_classifier.tflite"
|
||||
|
||||
_MODEL_NAME = "HandGestureRecognition"
|
||||
_MODEL_DESCRIPTION = "Recognize the hand gesture in the image."
|
||||
|
||||
_INPUT_NAME = "embedding"
|
||||
_INPUT_DESCRIPTION = "Embedding feature vector from gesture embedder."
|
||||
_OUTPUT_NAME = "scores"
|
||||
_OUTPUT_DESCRIPTION = "Hand gesture category scores."
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GestureClassifierOptions:
|
||||
"""Options to write metadata for gesture classifier.
|
||||
|
||||
Attributes:
|
||||
model_buffer: Gesture classifier TFLite model buffer.
|
||||
labels: Labels for the gesture classifier.
|
||||
score_thresholding: Parameters to performs thresholding on output tensor
|
||||
values [1].
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
|
||||
"""
|
||||
model_buffer: bytearray
|
||||
labels: metadata_writer.Labels
|
||||
score_thresholding: metadata_writer.ScoreThresholding
|
||||
|
||||
|
||||
def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
|
||||
with tf.io.gfile.GFile(file_path, mode) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
class MetadataWriter:
|
||||
"""MetadataWriter to write the metadata and the model asset bundle."""
|
||||
|
||||
def __init__(
|
||||
self, hand_detector_model_buffer: bytearray,
|
||||
hand_landmarks_detector_model_buffer: bytearray,
|
||||
gesture_embedder_model_buffer: bytearray,
|
||||
canned_gesture_classifier_model_buffer: bytearray,
|
||||
custom_gesture_classifier_metadata_writer: metadata_writer.MetadataWriter
|
||||
) -> None:
|
||||
"""Initialize MetadataWriter to write the metadata and model asset bundle.
|
||||
|
||||
Args:
|
||||
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||
the TFLite hand detector model file.
|
||||
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite hand landmarks detector model file.
|
||||
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
|
||||
from the TFLite gesture embedder model file.
|
||||
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite canned gesture classifier model file.
|
||||
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
|
||||
gesture classifier metadata into the TFLite file.
|
||||
"""
|
||||
self._hand_detector_model_buffer = hand_detector_model_buffer
|
||||
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
|
||||
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
|
||||
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
|
||||
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
|
||||
self._temp_folder = tempfile.TemporaryDirectory()
|
||||
|
||||
def __del__(self):
|
||||
if os.path.exists(self._temp_folder.name):
|
||||
self._temp_folder.cleanup()
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
hand_detector_model_buffer: bytearray,
|
||||
hand_landmarks_detector_model_buffer: bytearray,
|
||||
gesture_embedder_model_buffer: bytearray,
|
||||
canned_gesture_classifier_model_buffer: bytearray,
|
||||
custom_gesture_classifier_options: GestureClassifierOptions,
|
||||
) -> "MetadataWriter":
|
||||
"""Creates MetadataWriter to write the metadata for gesture recognizer.
|
||||
|
||||
Args:
|
||||
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
|
||||
the TFLite hand detector model file.
|
||||
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite hand landmarks detector model file.
|
||||
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
|
||||
from the TFLite gesture embedder model file.
|
||||
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
|
||||
loaded from the TFLite canned gesture classifier model file.
|
||||
custom_gesture_classifier_options: Custom gesture classifier options to
|
||||
write custom gesture classifier metadata into the TFLite file.
|
||||
|
||||
Returns:
|
||||
An MetadataWrite object.
|
||||
"""
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
custom_gesture_classifier_options.model_buffer)
|
||||
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
writer.add_feature_input(name=_INPUT_NAME, description=_INPUT_DESCRIPTION)
|
||||
writer.add_classification_output(
|
||||
labels=custom_gesture_classifier_options.labels,
|
||||
score_thresholding=custom_gesture_classifier_options.score_thresholding,
|
||||
name=_OUTPUT_NAME,
|
||||
description=_OUTPUT_DESCRIPTION)
|
||||
return cls(hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
|
||||
gesture_embedder_model_buffer,
|
||||
canned_gesture_classifier_model_buffer, writer)
|
||||
|
||||
def populate(self):
|
||||
"""Populates the metadata and creates model asset bundle.
|
||||
|
||||
Note that only the output model asset bundle is used for deployment.
|
||||
The output JSON content is used to interpret the custom gesture classifier
|
||||
metadata content.
|
||||
|
||||
Returns:
|
||||
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
|
||||
"""
|
||||
# Creates the model asset bundle for hand landmarker task.
|
||||
landmark_models = {
|
||||
_HAND_DETECTOR_TFLITE_NAME:
|
||||
self._hand_detector_model_buffer,
|
||||
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
|
||||
self._hand_landmarks_detector_model_buffer
|
||||
}
|
||||
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
|
||||
_HAND_LANDMARKER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(landmark_models,
|
||||
output_hand_landmarker_path)
|
||||
|
||||
# Write metadata into custom gesture classifier model.
|
||||
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
|
||||
)
|
||||
# Creates the model asset bundle for hand gesture recognizer sub graph.
|
||||
hand_gesture_recognizer_models = {
|
||||
_GESTURE_EMBEDDER_TFLITE_NAME:
|
||||
self._gesture_embedder_model_buffer,
|
||||
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME:
|
||||
self._canned_gesture_classifier_model_buffer,
|
||||
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME:
|
||||
self._custom_gesture_classifier_model_buffer
|
||||
}
|
||||
output_hand_gesture_recognizer_path = os.path.join(
|
||||
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
|
||||
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
|
||||
output_hand_gesture_recognizer_path)
|
||||
|
||||
# Creates the model asset bundle for end-to-end hand gesture recognizer
|
||||
# graph.
|
||||
gesture_recognizer_models = {
|
||||
_HAND_LANDMARKER_BUNDLE_NAME:
|
||||
read_file(output_hand_landmarker_path),
|
||||
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
|
||||
read_file(output_hand_gesture_recognizer_path),
|
||||
}
|
||||
|
||||
output_file_path = os.path.join(self._temp_folder.name,
|
||||
"gesture_recognizer.task")
|
||||
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
|
||||
output_file_path)
|
||||
with open(output_file_path, "rb") as f:
|
||||
gesture_recognizer_model_buffer = f.read()
|
||||
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata_writer."""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata"
|
||||
|
||||
_EXPECTED_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json"))
|
||||
_CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier.tflite"))
|
||||
|
||||
|
||||
class MetadataWriterTest(tf.test.TestCase):
|
||||
|
||||
def test_write_metadata_and_create_model_asset_bundle_successful(self):
|
||||
# Use dummy model buffer for unit test only.
|
||||
hand_detector_model_buffer = b"\x11\x12"
|
||||
hand_landmarks_detector_model_buffer = b"\x22"
|
||||
gesture_embedder_model_buffer = b"\x33"
|
||||
canned_gesture_classifier_model_buffer = b"\x44"
|
||||
custom_gesture_classifier_metadata_writer = metadata_writer.GestureClassifierOptions(
|
||||
model_buffer=metadata_writer.read_file(_CUSTOM_GESTURE_CLASSIFIER_PATH),
|
||||
labels=base_metadata_writer.Labels().add(
|
||||
["None", "Paper", "Rock", "Scissors"]),
|
||||
score_thresholding=base_metadata_writer.ScoreThresholding(
|
||||
global_score_threshold=0.5))
|
||||
writer = metadata_writer.MetadataWriter.create(
|
||||
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
|
||||
gesture_embedder_model_buffer, canned_gesture_classifier_model_buffer,
|
||||
custom_gesture_classifier_metadata_writer)
|
||||
model_bundle_content, metadata_json = writer.populate()
|
||||
with open(_EXPECTED_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
# Checks the top-level model bundle can be extracted successfully.
|
||||
model_bundle_filepath = os.path.join(self.get_temp_dir(),
|
||||
"gesture_recognition.task")
|
||||
|
||||
with open(model_bundle_filepath, "wb") as f:
|
||||
f.write(model_bundle_content)
|
||||
|
||||
with zipfile.ZipFile(model_bundle_filepath) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set(["hand_landmarker.task", "hand_gesture_recognizer.task"]))
|
||||
zf.extractall(self.get_temp_dir())
|
||||
|
||||
# Checks the model bundles for sub-task can be extracted successfully.
|
||||
hand_landmarker_bundle_filepath = os.path.join(self.get_temp_dir(),
|
||||
"hand_landmarker.task")
|
||||
with zipfile.ZipFile(hand_landmarker_bundle_filepath) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
|
||||
|
||||
hand_gesture_recognizer_bundle_filepath = os.path.join(
|
||||
self.get_temp_dir(), "hand_gesture_recognizer.task")
|
||||
with zipfile.ZipFile(hand_gesture_recognizer_bundle_filepath) as zf:
|
||||
self.assertEqual(
|
||||
set(zf.namelist()),
|
||||
set([
|
||||
"canned_gesture_classifier.tflite",
|
||||
"custom_gesture_classifier.tflite", "gesture_embedder.tflite"
|
||||
]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurable model options for gesture recognizer models."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GestureRecognizerModelOptions:
|
||||
"""Configurable options for gesture recognizer model.
|
||||
|
||||
Attributes:
|
||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||
layer.
|
||||
"""
|
||||
dropout_rate: float = 0.05
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
"name": "HandGestureRecognition",
|
||||
"description": "Recognize the hand gesture in the image.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "embedding",
|
||||
"description": "Embedding feature vector from gesture embedder.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "scores",
|
||||
"description": "Hand gesture category scores.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "ScoreThresholdingOptions",
|
||||
"options": {
|
||||
"global_score_threshold": 0.5
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
0.0
|
||||
]
|
||||
},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
After Width: | Height: | Size: 28 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 40 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 18 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 32 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 38 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 34 KiB |
After Width: | Height: | Size: 43 KiB |
After Width: | Height: | Size: 35 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 30 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 34 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 20 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 40 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 30 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 18 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 17 KiB |
After Width: | Height: | Size: 27 KiB |
After Width: | Height: | Size: 43 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 31 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 17 KiB |
|
@ -121,6 +121,34 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "solution_base",
|
||||
srcs = ["solution_base.py"],
|
||||
srcs_version = "PY3",
|
||||
visibility = [
|
||||
"//mediapipe/python:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":_framework_bindings",
|
||||
":packet_creator",
|
||||
":packet_getter",
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:landmarks_smoothing_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:logic_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
|
||||
"//mediapipe/framework:calculator_py_pb2",
|
||||
"//mediapipe/framework/formats:classification_py_pb2",
|
||||
"//mediapipe/framework/formats:detection_py_pb2",
|
||||
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||
"//mediapipe/framework/formats:rect_py_pb2",
|
||||
"//mediapipe/modules/objectron/calculators:annotation_py_pb2",
|
||||
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator_py_pb2",
|
||||
"@com_google_protobuf//:protobuf_python",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "calculator_graph_test",
|
||||
srcs = ["calculator_graph_test.py"],
|
||||
|
@ -176,3 +204,24 @@ py_test(
|
|||
":_framework_bindings",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "solution_base_test",
|
||||
srcs = ["solution_base_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":solution_base",
|
||||
"//file/google_src",
|
||||
"//file/localfile",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||
"//mediapipe/calculators/util:detection_unique_id_calculator",
|
||||
"//mediapipe/calculators/util:to_image_calculator",
|
||||
"//mediapipe/framework:calculator_py_pb2",
|
||||
"//mediapipe/framework/formats:detection_py_pb2",
|
||||
"@com_google_protobuf//:protobuf_python",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -40,7 +40,6 @@ from mediapipe.calculators.util import landmarks_smoothing_calculator_pb2
|
|||
from mediapipe.calculators.util import logic_calculator_pb2
|
||||
from mediapipe.calculators.util import thresholding_calculator_pb2
|
||||
from mediapipe.framework import calculator_pb2
|
||||
from mediapipe.framework.formats import body_rig_pb2
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
from mediapipe.framework.formats import detection_pb2
|
||||
from mediapipe.framework.formats import landmark_pb2
|
||||
|
|
105
mediapipe/python/solutions/BUILD
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2020 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
py_library(
|
||||
name = "hands",
|
||||
srcs = [
|
||||
"hands.py",
|
||||
"hands_connections.py",
|
||||
],
|
||||
data = [
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
|
||||
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_graph",
|
||||
"//mediapipe/modules/hand_landmark:handedness.txt",
|
||||
"//mediapipe/modules/palm_detection:palm_detection_full.tflite",
|
||||
"//mediapipe/modules/palm_detection:palm_detection_lite.tflite",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
|
||||
"//mediapipe/calculators/core:gate_calculator_py_pb2",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_py_pb2",
|
||||
"//mediapipe/calculators/tflite:ssd_anchors_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:association_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:logic_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:non_max_suppression_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator_py_pb2",
|
||||
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
|
||||
"//mediapipe/python:solution_base",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "drawing_styles",
|
||||
srcs = ["drawing_styles.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"drawing_utils",
|
||||
"face_mesh",
|
||||
"hands",
|
||||
"pose",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "drawing_utils",
|
||||
srcs = ["drawing_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:detection_py_pb2",
|
||||
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||
"//mediapipe/framework/formats:location_data_py_pb2",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "drawing_utils_test",
|
||||
srcs = ["drawing_utils_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":drawing_utils",
|
||||
"//mediapipe/framework/formats:detection_py_pb2",
|
||||
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||
"@com_google_protobuf//:protobuf_python",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "hands_test",
|
||||
srcs = ["hands_test.py"],
|
||||
data = [
|
||||
":testdata/asl_hand.25fps.mp4",
|
||||
":testdata/asl_hand.full.npz",
|
||||
":testdata/hands.jpg",
|
||||
],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":drawing_styles",
|
||||
":drawing_utils",
|
||||
":hands",
|
||||
],
|
||||
)
|
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
|
||||
option java_outer_classname = "AudioClassifierGraphOptionsProto";
|
||||
|
||||
message AudioClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional AudioClassifierGraphOptions ext = 451755788;
|
||||
|
|
|
@ -21,15 +21,6 @@ cc_library(
|
|||
hdrs = ["rect.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gesture_recognition_result",
|
||||
hdrs = ["gesture_recognition_result.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "category",
|
||||
srcs = ["category.cc"],
|
||||
|
|
|
@ -124,12 +124,22 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gesture_recognizer_result",
|
||||
hdrs = ["gesture_recognizer_result.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gesture_recognizer",
|
||||
srcs = ["gesture_recognizer.cc"],
|
||||
hdrs = ["gesture_recognizer.h"],
|
||||
deps = [
|
||||
":gesture_recognizer_graph",
|
||||
":gesture_recognizer_result",
|
||||
":hand_gesture_recognizer_graph",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
|
@ -140,7 +150,6 @@ cc_library(
|
|||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
|
|
|
@ -58,8 +58,6 @@ namespace {
|
|||
using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
gesture_recognizer::proto::GestureRecognizerGraphOptions;
|
||||
|
||||
using ::mediapipe::tasks::components::containers::GestureRecognitionResult;
|
||||
|
||||
constexpr char kHandGestureSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";
|
||||
|
||||
|
@ -214,7 +212,7 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
|
|||
std::move(packets_callback));
|
||||
}
|
||||
|
||||
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
|
||||
absl::StatusOr<GestureRecognizerResult> GestureRecognizer::Recognize(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
@ -250,7 +248,7 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
|
|||
};
|
||||
}
|
||||
|
||||
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
|
||||
absl::StatusOr<GestureRecognizerResult> GestureRecognizer::RecognizeForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
|
|
@ -24,12 +24,12 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -81,9 +81,8 @@ struct GestureRecognizerOptions {
|
|||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM.
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::GestureRecognitionResult>,
|
||||
const Image&, int64)>
|
||||
std::function<void(absl::StatusOr<GestureRecognizerResult>, const Image&,
|
||||
int64)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
|
@ -104,7 +103,7 @@ struct GestureRecognizerOptions {
|
|||
// 'width' and 'height' fields is NOT supported and will result in an
|
||||
// invalid argument error being returned.
|
||||
// Outputs:
|
||||
// GestureRecognitionResult
|
||||
// GestureRecognizerResult
|
||||
// - The hand gesture recognition results.
|
||||
class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
||||
public:
|
||||
|
@ -139,7 +138,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA.
|
||||
// TODO: Describes how the input image will be preprocessed
|
||||
// after the yuv support is implemented.
|
||||
absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
|
||||
absl::StatusOr<GestureRecognizerResult> Recognize(
|
||||
Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
@ -157,10 +156,10 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
absl::StatusOr<components::containers::GestureRecognitionResult>
|
||||
RecognizeForVideo(Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions>
|
||||
image_processing_options = std::nullopt);
|
||||
absl::StatusOr<GestureRecognizerResult> RecognizeForVideo(
|
||||
Image image, int64 timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
||||
// Sends live image data to perform gesture recognition, and the results will
|
||||
// be available via the "result_callback" provided in the
|
||||
|
@ -179,7 +178,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// and will result in an invalid argument error being returned.
|
||||
//
|
||||
// The "result_callback" provides
|
||||
// - A vector of GestureRecognitionResult, each is the recognized results
|
||||
// - A vector of GestureRecognizerResult, each is the recognized results
|
||||
// for a input frame.
|
||||
// - The const reference to the corresponding input image that the gesture
|
||||
// recognizer runs on. Note that the const reference to the image will no
|
||||
|
|
|
@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace containers {
|
||||
namespace vision {
|
||||
namespace gesture_recognizer {
|
||||
|
||||
// The gesture recognition result from GestureRecognizer, where each vector
|
||||
// element represents a single hand detected in the image.
|
||||
struct GestureRecognitionResult {
|
||||
struct GestureRecognizerResult {
|
||||
// Recognized hand gestures with sorted order such that the winning label is
|
||||
// the first item in the list.
|
||||
std::vector<mediapipe::ClassificationList> gestures;
|
||||
|
@ -38,9 +38,9 @@ struct GestureRecognitionResult {
|
|||
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
|
||||
};
|
||||
|
||||
} // namespace containers
|
||||
} // namespace components
|
||||
} // namespace gesture_recognizer
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
|
|
@ -276,33 +276,44 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
|||
.set_min_size(max_num_hands);
|
||||
auto has_enough_hands = min_size_node.Out("").Cast<bool>();
|
||||
|
||||
auto image_for_hand_detector =
|
||||
DisallowIf(image_in, has_enough_hands, graph);
|
||||
auto norm_rect_in_for_hand_detector =
|
||||
DisallowIf(norm_rect_in, has_enough_hands, graph);
|
||||
|
||||
auto& hand_detector =
|
||||
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
|
||||
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
|
||||
tasks_options.hand_detector_graph_options());
|
||||
auto& clip_hand_rects =
|
||||
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
|
||||
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
|
||||
.set_max_vec_size(max_num_hands);
|
||||
|
||||
if (tasks_options.base_options().use_stream_mode()) {
|
||||
// While in stream mode, skip hand detector graph when we successfully
|
||||
// track the hands from the last frame.
|
||||
auto image_for_hand_detector =
|
||||
DisallowIf(image_in, has_enough_hands, graph);
|
||||
auto norm_rect_in_for_hand_detector =
|
||||
DisallowIf(norm_rect_in, has_enough_hands, graph);
|
||||
image_for_hand_detector >> hand_detector.In("IMAGE");
|
||||
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT");
|
||||
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
||||
|
||||
auto& hand_association = graph.AddNode("HandAssociationCalculator");
|
||||
hand_association.GetOptions<HandAssociationCalculatorOptions>()
|
||||
.set_min_similarity_threshold(tasks_options.min_tracking_confidence());
|
||||
.set_min_similarity_threshold(
|
||||
tasks_options.min_tracking_confidence());
|
||||
prev_hand_rects_from_landmarks >>
|
||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
|
||||
hand_rects_from_hand_detector >>
|
||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
|
||||
auto hand_rects = hand_association.Out("");
|
||||
|
||||
auto& clip_hand_rects =
|
||||
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
|
||||
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
|
||||
.set_max_vec_size(max_num_hands);
|
||||
hand_rects >> clip_hand_rects.In("");
|
||||
} else {
|
||||
// While not in stream mode, the input images are not guaranteed to be in
|
||||
// series, and we don't want to enable the tracking and hand associations
|
||||
// between input images. Always use the hand detector graph.
|
||||
image_in >> hand_detector.In("IMAGE");
|
||||
norm_rect_in >> hand_detector.In("NORM_RECT");
|
||||
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
||||
hand_rects_from_hand_detector >> clip_hand_rects.In("");
|
||||
}
|
||||
auto clipped_hand_rects = clip_hand_rects.Out("");
|
||||
|
||||
auto& hand_landmarks_detector_graph = graph.AddNode(
|
||||
|
|
84
mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
Normal file
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
android_library(
|
||||
name = "core",
|
||||
srcs = glob(["core/*.java"]),
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
":libmediapipe_tasks_audio_jni_lib",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
# The native library of all MediaPipe audio tasks.
|
||||
cc_binary(
|
||||
name = "libmediapipe_tasks_audio_jni.so",
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "libmediapipe_tasks_audio_jni_lib",
|
||||
srcs = [":libmediapipe_tasks_audio_jni.so"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "audioclassifier",
|
||||
srcs = [
|
||||
"audioclassifier/AudioClassifier.java",
|
||||
"audioclassifier/AudioClassifierResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
manifest = "audioclassifier/AndroidManifest.xml",
|
||||
deps = [
|
||||
":core",
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_audio_aar")
|
||||
|
||||
mediapipe_tasks_audio_aar(
|
||||
name = "tasks_audio",
|
||||
srcs = glob(["**/*.java"]),
|
||||
native_library = ":libmediapipe_tasks_audio_jni_lib",
|
||||
)
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.audio.audioclassifier">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,399 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
|
||||
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs audio classification on audio clips or audio stream.
|
||||
*
|
||||
* <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
|
||||
* mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
|
||||
* items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
|
||||
*
|
||||
* <p>Input tensor: (kTfLiteFloat32)
|
||||
*
|
||||
* <ul>
|
||||
* <li>input audio buffer of size `[batch * samples]`.
|
||||
* <li>batch inference is not supported (`batch` is required to be 1).
|
||||
* <li>for multi-channel models, the channels need be interleaved.
|
||||
* </ul>
|
||||
*
|
||||
* <p>At least one output tensor with: (kTfLiteFloat32)
|
||||
*
|
||||
* <ul>
|
||||
* <li>`[1 x N]` array with `N` represents the number of categories.
|
||||
* <li>optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
|
||||
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
|
||||
* `category_name` field of the results. The `display_name` field is filled from the
|
||||
* AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
|
||||
* `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
|
||||
* these are available, only the `index` field of the results will be filled.
|
||||
* </ul>
|
||||
*/
|
||||
public final class AudioClassifier extends BaseAudioTaskApi {
|
||||
private static final String TAG = AudioClassifier.class.getSimpleName();
|
||||
private static final String AUDIO_IN_STREAM_NAME = "audio_in";
|
||||
private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList(
|
||||
"AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList(
|
||||
"CLASSIFICATIONS:classifications_out",
|
||||
"TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
|
||||
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||
private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
|
||||
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||
|
||||
static {
|
||||
ProtoUtil.registerTypeName(
|
||||
ClassificationsProto.ClassificationResult.class,
|
||||
"mediapipe.tasks.components.containers.proto.ClassificationResult");
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link AudioClassifier} instance from a model file and default {@link
|
||||
* AudioClassifierOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelPath path to the classification model in the assets.
|
||||
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||
*/
|
||||
public static AudioClassifier createFromFile(Context context, String modelPath) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||
return createFromOptions(
|
||||
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link AudioClassifier} instance from a model file and default {@link
|
||||
* AudioClassifierOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelFile the classification model {@link File} instance.
|
||||
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||
*/
|
||||
public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
|
||||
try (ParcelFileDescriptor descriptor =
|
||||
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||
BaseOptions baseOptions =
|
||||
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||
return createFromOptions(
|
||||
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link AudioClassifier} instance from a model buffer and default {@link
|
||||
* AudioClassifierOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
|
||||
* classification model.
|
||||
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||
*/
|
||||
public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||
return createFromOptions(
|
||||
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param options an {@link AudioClassifierOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||
*/
|
||||
public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
|
||||
OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
|
||||
@Override
|
||||
public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||
try {
|
||||
if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||
// For audio stream mode.
|
||||
return AudioClassifierResult.createFromProto(
|
||||
PacketGetter.getProto(
|
||||
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
|
||||
/ MICROSECONDS_PER_MILLISECOND);
|
||||
} else {
|
||||
// For audio clips mode.
|
||||
return AudioClassifierResult.createFromProtoList(
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.parser()),
|
||||
-1);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void convertToTaskInput(List<Packet> packets) {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
if (options.resultListener().isPresent()) {
|
||||
ResultListener<AudioClassifierResult, Void> resultListener =
|
||||
new ResultListener<AudioClassifierResult, Void>() {
|
||||
@Override
|
||||
public void run(AudioClassifierResult audioClassifierResult, Void input) {
|
||||
options.resultListener().get().run(audioClassifierResult);
|
||||
}
|
||||
};
|
||||
handler.setResultListener(resultListener);
|
||||
}
|
||||
options.errorListener().ifPresent(handler::setErrorListener);
|
||||
// Audio tasks should not drop input audio due to flow limiting, which may cause data
|
||||
// inconsistency.
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<AudioClassifierOptions>builder()
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setTaskOptions(options)
|
||||
.setEnableFlowLimiting(false)
|
||||
.build(),
|
||||
handler);
|
||||
return new AudioClassifier(runner, options.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe audio task {@link RunningMode}.
|
||||
*/
|
||||
private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
|
||||
}
|
||||
|
||||
/*
|
||||
* Performs audio classification on the provided audio clip. Only use this method when the
|
||||
* AudioClassifier is created with the audio clips mode.
|
||||
*
|
||||
* <p>The audio clip is represented as a MediaPipe {@link AudioData} object The method accepts
|
||||
* audio clips with various length and audio sample rate. It's required to provide the
|
||||
* corresponding audio sample rate within the {@link AudioData} object.
|
||||
*
|
||||
* <p>The input audio clip may be longer than what the model is able to process in a single
|
||||
* inference. When this occurs, the input audio clip is split into multiple chunks starting at
|
||||
* different timestamps. For this reason, this function returns a vector of ClassificationResult
|
||||
* objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
|
||||
* chunk data that was classified, e.g:
|
||||
*
|
||||
* ClassificationResult #0 (first chunk of data):
|
||||
* timestamp_ms: 0 (starts at 0ms)
|
||||
* classifications #0 (single head model):
|
||||
* category #0:
|
||||
* category_name: "Speech"
|
||||
* score: 0.6
|
||||
* category #1:
|
||||
* category_name: "Music"
|
||||
* score: 0.2
|
||||
* ClassificationResult #1 (second chunk of data):
|
||||
* timestamp_ms: 800 (starts at 800ms)
|
||||
* classifications #0 (single head model):
|
||||
* category #0:
|
||||
* category_name: "Speech"
|
||||
* score: 0.5
|
||||
* category #1:
|
||||
* category_name: "Silence"
|
||||
* score: 0.1
|
||||
*
|
||||
* @param audioClip a MediaPipe {@link AudioData} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public AudioClassifierResult classify(AudioData audioClip) {
|
||||
return (AudioClassifierResult) processAudioClip(audioClip);
|
||||
}
|
||||
|
||||
/*
|
||||
* Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
|
||||
* use this method when the AudioClassifier is created with the audio stream mode.
|
||||
*
|
||||
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
|
||||
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
|
||||
* It's required to provide the corresponding audio sample rate within {@link AudioData} object as
|
||||
* well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
|
||||
* timestamps must be monotonically increasing. This method will return immediately after
|
||||
* the input audio data is accepted. The results will be available in the `resultListener`
|
||||
* provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
|
||||
* auido stream data such as microphone input.
|
||||
*
|
||||
* <p>The input audio block may be longer than what the model is able to process in a single
|
||||
* inference. When this occurs, the input audio block is split into multiple chunks. For this
|
||||
* reason, the callback may be called multiple times (once per chunk) for each call to this
|
||||
* function.
|
||||
*
|
||||
* @param audioBlock a MediaPipe {@link AudioData} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void classifyAsync(AudioData audioBlock, long timestampMs) {
|
||||
checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
|
||||
sendAudioStreamData(audioBlock, timestampMs);
|
||||
}
|
||||
|
||||
/** Options for setting up and {@link AudioClassifier}. */
|
||||
@AutoValue
|
||||
public abstract static class AudioClassifierOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link AudioClassifierOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the {@link BaseOptions} for the audio classifier task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||
|
||||
/**
|
||||
* Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
|
||||
* mode. Image classifier has two modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
|
||||
* audio clips to the `classify` method, and will receive the classification results as
|
||||
* the return value.
|
||||
* <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
|
||||
* from microphone. Users call `classifyAsync` to push the audio data into the
|
||||
* AudioClassifier, the classification results will be available in the result callback
|
||||
* when the audio classifier finishes the work.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||
* score threshold, number of results, etc.
|
||||
*/
|
||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||
* the audio classifier is in the audio stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
PureResultListener<AudioClassifierResult> resultListener);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}. */
|
||||
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||
|
||||
abstract AudioClassifierOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link AudioClassifierOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the audio classifier
|
||||
* is in the audio stream mode.
|
||||
*/
|
||||
public final AudioClassifierOptions build() {
|
||||
AudioClassifierOptions options = autoBuild();
|
||||
if (options.runningMode() == RunningMode.AUDIO_STREAM) {
|
||||
if (!options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The audio classifier is in the audio stream mode, a user-defined result listener"
|
||||
+ " must be provided in the AudioClassifierOptions.");
|
||||
}
|
||||
} else if (options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The audio classifier is in the audio clips mode, a user-defined result listener"
|
||||
+ " shouldn't be provided in AudioClassifierOptions.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<ClassifierOptions> classifierOptions();
|
||||
|
||||
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
|
||||
.setRunningMode(RunningMode.AUDIO_CLIPS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
|
||||
*/
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (classifierOptions().isPresent()) {
|
||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||
}
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
|
||||
taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Represents the classification results generated by {@link AudioClassifier}. */
|
||||
@AutoValue
|
||||
public abstract class AudioClassifierResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link AudioClassifierResult} instance from a list of {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf messages.
|
||||
*
|
||||
* @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf message
|
||||
* to convert.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
static AudioClassifierResult createFromProtoList(
|
||||
List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
|
||||
List<ClassificationResult> classificationResultList = new ArrayList<>();
|
||||
for (ClassificationsProto.ClassificationResult proto : protoList) {
|
||||
classificationResultList.add(ClassificationResult.createFromProto(proto));
|
||||
}
|
||||
return new AutoValue_AudioClassifierResult(
|
||||
Optional.of(classificationResultList), Optional.empty(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an {@link AudioClassifierResult} instance from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
static AudioClassifierResult createFromProto(
|
||||
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||
return new AutoValue_AudioClassifierResult(
|
||||
Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* A list of of timpstamed {@link ClassificationResult} objects, each contains one set of results
|
||||
* per classifier head. The list represents the audio classification result of an audio clip, and
|
||||
* is only available when running with the audio clips mode.
|
||||
*/
|
||||
public abstract Optional<List<ClassificationResult>> classificationResultList();
|
||||
|
||||
/**
|
||||
* Contains one set of results per classifier head. A {@link ClassificationResult} usually
|
||||
* represents one audio classification result in an audio stream, and s only available when
|
||||
* running with the audio stream mode.
|
||||
*/
|
||||
public abstract Optional<ClassificationResult> classificationResult();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -0,0 +1,151 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.core;
|
||||
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** The base class of MediaPipe audio tasks. */
|
||||
public class BaseAudioTaskApi implements AutoCloseable {
|
||||
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||
private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;
|
||||
|
||||
private final TaskRunner runner;
|
||||
private final RunningMode runningMode;
|
||||
private final String audioStreamName;
|
||||
private final String sampleRateStreamName;
|
||||
private double defaultSampleRate;
|
||||
|
||||
static {
|
||||
System.loadLibrary("mediapipe_tasks_audio_jni");
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize a {@link BaseAudioTaskApi}.
|
||||
*
|
||||
* @param runner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe audio task {@link RunningMode}.
|
||||
* @param audioStreamName the name of the input audio stream.
|
||||
* @param sampleRateStreamName the name of the audio sample rate stream.
|
||||
*/
|
||||
public BaseAudioTaskApi(
|
||||
TaskRunner runner,
|
||||
RunningMode runningMode,
|
||||
String audioStreamName,
|
||||
String sampleRateStreamName) {
|
||||
this.runner = runner;
|
||||
this.runningMode = runningMode;
|
||||
this.audioStreamName = audioStreamName;
|
||||
this.sampleRateStreamName = sampleRateStreamName;
|
||||
this.defaultSampleRate = -1.0;
|
||||
}
|
||||
|
||||
/**
|
||||
* A synchronous method to process audio clips. The call blocks the current thread until a failure
|
||||
* status or a successful result is returned.
|
||||
*
|
||||
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||
* @throws MediaPipeException if the task is not in the audio clips mode.
|
||||
*/
|
||||
protected TaskResult processAudioClip(AudioData audioClip) {
|
||||
if (runningMode != RunningMode.AUDIO_CLIPS) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the audio clips mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(
|
||||
audioStreamName,
|
||||
runner
|
||||
.getPacketCreator()
|
||||
.createMatrix(
|
||||
audioClip.getFormat().getNumOfChannels(),
|
||||
audioClip.getBufferLength(),
|
||||
audioClip.getBuffer()));
|
||||
inputPackets.put(
|
||||
sampleRateStreamName,
|
||||
runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
|
||||
return runner.process(inputPackets);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks or sets the audio sample rate in the audio stream mode.
|
||||
*
|
||||
* @param sampleRate the audio sample rate.
|
||||
* @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
|
||||
* rate is inconsisent with the previously recevied.
|
||||
*/
|
||||
protected void checkOrSetSampleRate(double sampleRate) {
|
||||
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
if (defaultSampleRate > 0) {
|
||||
if (Double.compare(sampleRate, defaultSampleRate) != 0) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
|
||||
"The input audio sample rate: "
|
||||
+ sampleRate
|
||||
+ " is inconsistent with the previously provided: "
|
||||
+ defaultSampleRate);
|
||||
}
|
||||
} else {
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
|
||||
runner.send(inputPackets, PRESTREAM_TIMESTAMP);
|
||||
defaultSampleRate = sampleRate;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
|
||||
* available in the user-defined result listener.
|
||||
*
|
||||
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||
* @throws MediaPipeException if the task is not in the stream mode.
|
||||
*/
|
||||
protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
|
||||
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(
|
||||
audioStreamName,
|
||||
runner
|
||||
.getPacketCreator()
|
||||
.createMatrix(
|
||||
audioClip.getFormat().getNumOfChannels(),
|
||||
audioClip.getBufferLength(),
|
||||
audioClip.getBuffer()));
|
||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
/** Closes and cleans up the MediaPipe audio task. */
|
||||
@Override
|
||||
public void close() {
|
||||
runner.close();
|
||||
}
|
||||
}
|