Merge branch 'google:master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-11 10:38:52 +05:30 committed by GitHub
commit f27068c6f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
149 changed files with 5818 additions and 592 deletions

View File

@ -155,6 +155,7 @@ http_archive(
name = "com_google_audio_tools",
strip_prefix = "multichannel-audio-tools-master",
urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"],
repo_mapping = {"@com_github_glog_glog" : "@com_github_glog_glog_no_gflags"},
)
http_archive(

View File

@ -12,3 +12,21 @@ py_binary(
"//third_party/py/tensorflow_docs/api_generator:public_api",
],
)
py_binary(
name = "build_java_api_docs",
srcs = ["build_java_api_docs.py"],
data = [
"//third_party/java/doclava/current:doclava.jar",
"//third_party/java/jsilver:jsilver_jar",
],
env = {
"DOCLAVA_JAR": "$(location //third_party/java/doclava/current:doclava.jar)",
"JSILVER_JAR": "$(location //third_party/java/jsilver:jsilver_jar)",
},
deps = [
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/tensorflow_docs/api_generator/gen_java",
],
)

View File

@ -0,0 +1,58 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate Java reference docs for MediaPipe."""
import pathlib
from absl import app
from absl import flags
from tensorflow_docs.api_generator import gen_java
_JAVA_ROOT = flags.DEFINE_string('java_src', None,
'Override the Java source path.',
required=False)
_OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/',
'Write docs here.')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/java',
'Path prefix in the _toc.yaml')
_ = flags.DEFINE_string('code_url_prefix', None,
'[UNUSED] The url prefix for links to code.')
_ = flags.DEFINE_bool(
'search_hints', True,
'[UNUSED] Include metadata search hints in the generated files')
def main(_) -> None:
if not (java_root := _JAVA_ROOT.value):
# Default to using a relative path to find the Java source.
mp_root = pathlib.Path(__file__)
while (mp_root := mp_root.parent).name != 'mediapipe':
# Find the nearest `mediapipe` dir.
pass
java_root = mp_root / 'tasks/java'
gen_java.gen_java_docs(
package='com.google.mediapipe',
source_path=pathlib.Path(java_root),
output_dir=pathlib.Path(_OUT_DIR.value),
site_path=pathlib.Path(_SITE_PATH.value))
if __name__ == '__main__':
app.run(main)

View File

@ -1,4 +1,4 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,14 +1,13 @@
"""Copyright 2019 - 2020 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Copyright 2019 - 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -277,6 +277,7 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp/mfcc",
"@eigen_archive//:eigen3",
],
)
@ -352,6 +353,8 @@ cc_test(
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:validate_type",
"//mediapipe/util:time_series_test_util",
"@com_google_audio_tools//audio/dsp:resampler",
"@com_google_audio_tools//audio/dsp:resampler_q",
"@com_google_audio_tools//audio/dsp:signal_vector_util",
"@eigen_archive//:eigen3",
],

View File

@ -130,7 +130,7 @@ class BypassCalculator : public Node {
pass_out.insert(entry.second);
auto& packet = cc->Inputs().Get(entry.first).Value();
if (packet.Timestamp() == cc->InputTimestamp()) {
cc->Outputs().Get(entry.first).AddPacket(packet);
cc->Outputs().Get(entry.second).AddPacket(packet);
}
}
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();

View File

@ -42,10 +42,10 @@ constexpr char kTestGraphConfig1[] = R"pb(
node {
calculator: "BypassCalculator"
input_stream: "PASS:appearances"
input_stream: "TRUNCATE:0:video_frame"
input_stream: "TRUNCATE:1:feature_config"
input_stream: "IGNORE:0:video_frame"
input_stream: "IGNORE:1:feature_config"
output_stream: "PASS:passthrough_appearances"
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
output_stream: "IGNORE:passthrough_federated_gaze_output"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "PASS"

View File

@ -156,6 +156,7 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:window_functions",
],
)

View File

@ -750,19 +750,12 @@ objc_library(
],
)
proto_library(
mediapipe_proto_library(
name = "scale_mode_proto",
srcs = ["scale_mode.proto"],
visibility = ["//visibility:public"],
)
mediapipe_cc_proto_library(
name = "scale_mode_cc_proto",
srcs = ["scale_mode.proto"],
visibility = ["//visibility:public"],
deps = [":scale_mode_proto"],
)
cc_library(
name = "gl_quad_renderer",
srcs = ["gl_quad_renderer.cc"],

View File

@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
licenses(["notice"])
package(
default_visibility = ["//mediapipe/model_maker/python/vision/gesture_recognizer:__subpackages__"],
)
mediapipe_files(
srcs = [
"canned_gesture_classifier.tflite",
"gesture_embedder.tflite",
"gesture_embedder/keras_metadata.pb",
"gesture_embedder/saved_model.pb",
"gesture_embedder/variables/variables.data-00000-of-00001",
"gesture_embedder/variables/variables.index",
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
)
filegroup(
name = "models",
srcs = [
"canned_gesture_classifier.tflite",
"gesture_embedder.tflite",
"gesture_embedder/keras_metadata.pb",
"gesture_embedder/saved_model.pb",
"gesture_embedder/variables/variables.data-00000-of-00001",
"gesture_embedder/variables/variables.index",
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
)

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "bert_model_options",
srcs = ["bert_model_options.py"],
)
py_library(
name = "bert_model_spec",
srcs = ["bert_model_spec.py"],
deps = [
":bert_model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for a BERT model."""
import dataclasses
@dataclasses.dataclass
class BertModelOptions:
"""Configurable model options for a BERT model.
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding) for more details.
Attributes:
seq_len: Length of the sequence to feed into the model.
do_fine_tuning: If true, then the BERT model is not frozen for training.
dropout_rate: The rate for dropout.
"""
seq_len: int = 128
do_fine_tuning: bool = True
dropout_rate: float = 0.1

View File

@ -0,0 +1,58 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specification for a BERT model."""
import dataclasses
from typing import Dict
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.core import bert_model_options
_DEFAULT_TFLITE_INPUT_NAME = {
'ids': 'serving_default_input_word_ids:0',
'mask': 'serving_default_input_mask:0',
'segment_ids': 'serving_default_input_type_ids:0'
}
@dataclasses.dataclass
class BertModelSpec:
"""Specification for a BERT model.
See https://arxiv.org/abs/1810.04805 (BERT: Pre-training of Deep Bidirectional
Transformers for Language Understanding) for more details.
Attributes:
hparams: Hyperparameters used for training.
model_options: Configurable options for a BERT model.
do_lower_case: boolean, whether to lower case the input text. Should be
True / False for uncased / cased models respectively, where the models
are specified by the `uri`.
tflite_input_name: Dict, input names for the TFLite model.
uri: URI for the BERT module.
name: The name of the object.
"""
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=3,
batch_size=32,
learning_rate=3e-5,
distribution_strategy='mirrored')
model_options: bert_model_options.BertModelOptions = (
bert_model_options.BertModelOptions())
do_lower_case: bool = True
tflite_input_name: Dict[str, str] = dataclasses.field(
default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME)
uri: str = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1'
name: str = 'Bert'

View File

@ -0,0 +1,146 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "model_options",
srcs = ["model_options.py"],
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
)
py_library(
name = "model_spec",
srcs = ["model_spec.py"],
deps = [
":model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/text/core:bert_model_spec",
],
)
py_test(
name = "model_spec_test",
srcs = ["model_spec_test.py"],
deps = [
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
deps = [":dataset"],
)
py_library(
name = "preprocessor",
srcs = ["preprocessor.py"],
deps = [":dataset"],
)
py_test(
name = "preprocessor_test",
srcs = ["preprocessor_test.py"],
tags = ["requires-net:external"],
deps = [
":dataset",
":model_spec",
":preprocessor",
],
)
py_library(
name = "text_classifier_options",
srcs = ["text_classifier_options.py"],
deps = [
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "text_classifier",
srcs = ["text_classifier.py"],
deps = [
":dataset",
":model_options",
":model_spec",
":preprocessor",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:text_classifier",
],
)
py_test(
name = "text_classifier_test",
size = "large",
srcs = ["text_classifier_test.py"],
data = [
"//mediapipe/model_maker/python/text/text_classifier/testdata",
],
tags = ["requires-net:external"],
deps = [
":dataset",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_library(
name = "text_classifier_demo_lib",
srcs = ["text_classifier_demo.py"],
deps = [
":dataset",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:quantization",
],
)
py_binary(
name = "text_classifier_demo",
srcs = ["text_classifier_demo.py"],
deps = [
":text_classifier_demo_lib",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,88 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text classifier dataset library."""
import csv
import dataclasses
import random
from typing import Optional, Sequence
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
@dataclasses.dataclass
class CSVParameters:
"""Parameters used when reading a CSV file.
Attributes:
text_column: Column name for the input text.
label_column: Column name for the labels.
fieldnames: Sequence of keys for the CSV columns. If None, the first row of
the CSV file is used as the keys.
delimiter: Character that separates fields.
quotechar: Character used to quote fields that contain special characters
like the `delimiter`.
"""
text_column: str
label_column: str
fieldnames: Optional[Sequence[str]] = None
delimiter: str = ","
quotechar: str = '"'
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for text classifier."""
@classmethod
def from_csv(cls,
filename: str,
csv_params: CSVParameters,
shuffle: bool = True) -> "Dataset":
"""Loads text with labels from a CSV file.
Args:
filename: Name of the CSV file.
csv_params: Parameters used for reading the CSV file.
shuffle: If True, randomly shuffle the data.
Returns:
Dataset containing (text, label) pairs and other related info.
"""
with tf.io.gfile.GFile(filename, "r") as f:
reader = csv.DictReader(
f,
fieldnames=csv_params.fieldnames,
delimiter=csv_params.delimiter,
quotechar=csv_params.quotechar)
lines = list(reader)
if shuffle:
random.shuffle(lines)
label_names = sorted(set([line[csv_params.label_column] for line in lines]))
index_by_label = {label: index for index, label in enumerate(label_names)}
texts = [line[csv_params.text_column] for line in lines]
text_ds = tf.data.Dataset.from_tensor_slices(tf.cast(texts, tf.string))
label_indices = [
index_by_label[line[csv_params.label_column]] for line in lines
]
label_index_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(label_indices, tf.int64))
text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds))
return Dataset(
dataset=text_label_ds, size=len(texts), label_names=label_names)

View File

@ -0,0 +1,75 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os
import tensorflow as tf
from mediapipe.model_maker.python.text.text_classifier import dataset
class DatasetTest(tf.test.TestCase):
def _get_csv_file(self):
labels_and_text = (('neutral', 'indifferent'), ('pos', 'extremely great'),
('neg', 'totally awful'), ('pos', 'super good'),
('neg', 'really bad'))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
if os.path.exists(csv_file):
return csv_file
fieldnames = ['text', 'label']
with open(csv_file, 'w') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
return csv_file
def test_from_csv(self):
csv_file = self._get_csv_file()
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
data = dataset.Dataset.from_csv(filename=csv_file, csv_params=csv_params)
self.assertLen(data, 5)
self.assertEqual(data.num_classes, 3)
self.assertEqual(data.label_names, ['neg', 'neutral', 'pos'])
data_values = set([(text.numpy()[0], label.numpy()[0])
for text, label in data.gen_tf_dataset()])
expected_data_values = set([(b'indifferent', 1), (b'extremely great', 2),
(b'totally awful', 0), (b'super good', 2),
(b'really bad', 0)])
self.assertEqual(data_values, expected_data_values)
def test_split(self):
ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd'])
data = dataset.Dataset(ds, 4, ['pos', 'neg'])
train_data, test_data = data.split(0.5)
expected_train_data = [b'good', b'bad']
expected_test_data = [b'neutral', b'odd']
self.assertLen(train_data, 2)
train_data_values = [elem.numpy() for elem in train_data._dataset]
self.assertEqual(train_data_values, expected_train_data)
self.assertEqual(train_data.num_classes, 2)
self.assertEqual(train_data.label_names, ['pos', 'neg'])
self.assertLen(test_data, 2)
test_data_values = [elem.numpy() for elem in test_data._dataset]
self.assertEqual(test_data_values, expected_test_data)
self.assertEqual(test_data.num_classes, 2)
self.assertEqual(test_data.label_names, ['pos', 'neg'])
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,45 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for text classifier models."""
import dataclasses
from typing import Union
from mediapipe.model_maker.python.text.core import bert_model_options
# BERT text classifier options inherited from BertModelOptions.
BertClassifierOptions = bert_model_options.BertModelOptions
@dataclasses.dataclass
class AverageWordEmbeddingClassifierOptions:
"""Configurable model options for an Average Word Embedding classifier.
Attributes:
seq_len: Length of the sequence to feed into the model.
wordvec_dim: Dimension of the word embedding.
do_lower_case: Whether to convert all uppercase characters to lowercase
during preprocessing.
vocab_size: Number of words to generate the vocabulary from data.
dropout_rate: The rate for dropout.
"""
seq_len: int = 256
wordvec_dim: int = 16
do_lower_case: bool = True
vocab_size: int = 10000
dropout_rate: float = 0.2
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
BertClassifierOptions]

View File

@ -0,0 +1,70 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specifications for text classifier models."""
import dataclasses
import enum
import functools
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.core import bert_model_spec
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
# BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec
@dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec:
"""Specification for an average word embedding classifier model.
Attributes:
hparams: Configurable hyperparameters for training.
model_options: Configurable options for the average word embedding model.
name: The name of the object.
"""
# `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0)
model_options: mo.AverageWordEmbeddingClassifierOptions = (
mo.AverageWordEmbeddingClassifierOptions())
name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial(
AverageWordEmbeddingClassifierSpec)
mobilebert_classifier_spec = functools.partial(
BertClassifierSpec,
hparams=hp.BaseHParams(
epochs=3,
batch_size=48,
learning_rate=3e-5,
distribution_strategy='off'),
name='MobileBert',
uri='https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0'
},
)
@enum.unique
class SupportedModels(enum.Enum):
"""Predefined text classifier model specs supported by Model Maker."""
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec

View File

@ -0,0 +1,118 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for model_spec."""
import os
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
class ModelSpecTest(tf.test.TestCase):
def test_predefined_bert_spec(self):
model_spec_obj = ms.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self.assertIsInstance(model_spec_obj, ms.BertClassifierSpec)
self.assertEqual(model_spec_obj.name, 'MobileBert')
self.assertEqual(
model_spec_obj.uri, 'https://tfhub.dev/tensorflow/'
'mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1')
self.assertTrue(model_spec_obj.do_lower_case)
self.assertEqual(
model_spec_obj.tflite_input_name, {
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0'
})
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.BertClassifierOptions(
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
epochs=3,
batch_size=48,
learning_rate=3e-5,
distribution_strategy='off'))
def test_predefined_average_word_embedding_spec(self):
model_spec_obj = (
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value())
self.assertIsInstance(model_spec_obj, ms.AverageWordEmbeddingClassifierSpec)
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.AverageWordEmbeddingClassifierOptions(
seq_len=256,
wordvec_dim=16,
do_lower_case=True,
vocab_size=10000,
dropout_rate=0.2))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
epochs=10,
batch_size=32,
learning_rate=0,
steps_per_epoch=None,
shuffle=False,
distribution_strategy='off',
num_gpus=-1,
tpu=''))
def test_custom_bert_spec(self):
custom_bert_classifier_options = (
classifier_model_options.BertClassifierOptions(
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
model_spec_obj = (
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
model_options=custom_bert_classifier_options))
self.assertEqual(model_spec_obj.model_options,
custom_bert_classifier_options)
def test_custom_average_word_embedding_spec(self):
custom_hparams = hp.BaseHParams(
learning_rate=0.4,
batch_size=64,
epochs=10,
steps_per_epoch=10,
shuffle=True,
export_dir='foo/bar',
distribution_strategy='mirrored',
num_gpus=3,
tpu='tpu/address')
custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingClassifierOptions(
seq_len=512,
wordvec_dim=32,
do_lower_case=False,
vocab_size=5000,
dropout_rate=0.5))
model_spec_obj = (
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER.value(
model_options=custom_average_word_embedding_model_options,
hparams=custom_hparams))
self.assertEqual(model_spec_obj.model_options,
custom_average_word_embedding_model_options)
self.assertEqual(model_spec_obj.hparams, custom_hparams)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,285 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessors for text classification."""
import collections
import os
import re
import tempfile
from typing import Mapping, Sequence, Tuple, Union
import tensorflow as tf
import tensorflow_hub
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
def _validate_text_and_label(text: tf.Tensor, label: tf.Tensor) -> None:
"""Validates the shape and type of `text` and `label`.
Args:
text: Stores text data. Should have shape [1] and dtype tf.string.
label: Stores the label for the corresponding `text`. Should have shape [1]
and dtype tf.int64.
Raises:
ValueError: If either tensor has the wrong shape or type.
"""
if text.shape != [1]:
raise ValueError(f"`text` should have shape [1], got {text.shape}")
if text.dtype != tf.string:
raise ValueError(f"Expected dtype string for `text`, got {text.dtype}")
if label.shape != [1]:
raise ValueError(f"`label` should have shape [1], got {text.shape}")
if label.dtype != tf.int64:
raise ValueError(f"Expected dtype int64 for `label`, got {label.dtype}")
def _decode_record(
record: tf.Tensor, name_to_features: Mapping[str, tf.io.FixedLenFeature]
) -> Tuple[Mapping[str, tf.Tensor], tf.Tensor]:
"""Decodes a record into input for a BERT model.
Args:
record: Stores serialized example.
name_to_features: Maps record keys to feature types.
Returns:
BERT model input features and label for the record.
"""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
for name in list(example.keys()):
example[name] = tf.cast(example[name], tf.int32)
bert_features = {
"input_word_ids": example["input_ids"],
"input_mask": example["input_mask"],
"input_type_ids": example["segment_ids"]
}
return bert_features, example["label_ids"]
def _single_file_dataset(
input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature]
) -> tf.data.TFRecordDataset:
"""Creates a single-file dataset to be passed for BERT custom training.
Args:
input_file: Filepath for the dataset.
name_to_features: Maps record keys to feature types.
Returns:
Dataset containing BERT model input features and labels.
"""
d = tf.data.TFRecordDataset(input_file)
d = d.map(
lambda record: _decode_record(record, name_to_features),
num_parallel_calls=tf.data.AUTOTUNE)
return d
class AverageWordEmbeddingClassifierPreprocessor:
"""Preprocessor for an Average Word Embedding model.
Takes (text, label) data and applies regex tokenization and padding to the
text to generate (token IDs, label) data.
Attributes:
seq_len: Length of the input sequence to the model.
do_lower_case: Whether text inputs should be converted to lower-case.
vocab: Vocabulary of tokens used by the model.
"""
PAD: str = "<PAD>" # Index: 0
START: str = "<START>" # Index: 1
UNKNOWN: str = "<UNKNOWN>" # Index: 2
def __init__(self, seq_len: int, do_lower_case: bool, texts: Sequence[str],
vocab_size: int):
self._seq_len = seq_len
self._do_lower_case = do_lower_case
self._vocab = self._gen_vocab(texts, vocab_size)
def _gen_vocab(self, texts: Sequence[str],
vocab_size: int) -> Mapping[str, int]:
"""Generates vocabulary list in `texts` with size `vocab_size`.
Args:
texts: All texts (across training and validation data) that will be
preprocessed by the model.
vocab_size: Size of the vocab.
Returns:
The vocab mapping tokens to IDs.
"""
vocab_counter = collections.Counter()
for text in texts:
tokens = self._regex_tokenize(text)
for token in tokens:
vocab_counter[token] += 1
vocab_freq = vocab_counter.most_common(vocab_size)
vocab_list = [self.PAD, self.START, self.UNKNOWN
] + [word for word, _ in vocab_freq]
return collections.OrderedDict(((v, i) for i, v in enumerate(vocab_list)))
def get_vocab(self) -> Mapping[str, int]:
"""Returns the vocab of the AverageWordEmbeddingClassifierPreprocessor."""
return self._vocab
# TODO: Align with MediaPipe's RegexTokenizer.
def _regex_tokenize(self, text: str) -> Sequence[str]:
"""Splits `text` by words but does not split on single quotes.
Args:
text: Text to be tokenized.
Returns:
List of tokens.
"""
text = tf.compat.as_text(text)
if self._do_lower_case:
text = text.lower()
tokens = re.compile(r"[^\w\']+").split(text.strip())
# Filters out any empty strings in `tokens`.
return list(filter(None, tokens))
def _tokenize_and_pad(self, text: str) -> Sequence[int]:
"""Tokenizes `text` and pads the tokens to `seq_len`.
Args:
text: Text to be tokenized and padded.
Returns:
List of token IDs padded to have length `seq_len`.
"""
tokens = self._regex_tokenize(text)
# Gets ids for START, PAD and UNKNOWN tokens.
start_id = self._vocab[self.START]
pad_id = self._vocab[self.PAD]
unknown_id = self._vocab[self.UNKNOWN]
token_ids = [self._vocab.get(token, unknown_id) for token in tokens]
token_ids = [start_id] + token_ids
if len(token_ids) < self._seq_len:
pad_length = self._seq_len - len(token_ids)
token_ids = token_ids + pad_length * [pad_id]
else:
token_ids = token_ids[:self._seq_len]
return token_ids
def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for an Average Word Embedding model.
Args:
dataset: Stores (text, label) data.
Returns:
Dataset containing (token IDs, label) data.
"""
token_ids_list = []
labels_list = []
for text, label in dataset.gen_tf_dataset():
_validate_text_and_label(text, label)
token_ids = self._tokenize_and_pad(text.numpy()[0].decode("utf-8"))
token_ids_list.append(token_ids)
labels_list.append(label.numpy()[0])
token_ids_ds = tf.data.Dataset.from_tensor_slices(token_ids_list)
labels_ds = tf.data.Dataset.from_tensor_slices(labels_list)
preprocessed_ds = tf.data.Dataset.zip((token_ids_ds, labels_ds))
return text_classifier_ds.Dataset(
dataset=preprocessed_ds,
size=dataset.size,
label_names=dataset.label_names)
class BertClassifierPreprocessor:
"""Preprocessor for a BERT-based classifier.
Attributes:
seq_len: Length of the input sequence to the model.
vocab_file: File containing the BERT vocab.
tokenizer: BERT tokenizer.
"""
def __init__(self, seq_len: int, do_lower_case: bool, uri: str):
self._seq_len = seq_len
# Vocab filepath is tied to the BERT module's URI.
self._vocab_file = os.path.join(
tensorflow_hub.resolve(uri), "assets", "vocab.txt")
self._tokenizer = tokenization.FullTokenizer(self._vocab_file,
do_lower_case)
def _get_name_to_features(self):
"""Gets the dictionary mapping record keys to feature types."""
return {
"input_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"input_mask": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"segment_ids": tf.io.FixedLenFeature([self._seq_len], tf.int64),
"label_ids": tf.io.FixedLenFeature([], tf.int64),
}
def get_vocab_file(self) -> str:
"""Returns the vocab file of the BertClassifierPreprocessor."""
return self._vocab_file
def preprocess(
self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset:
"""Preprocesses data into input for a BERT-based classifier.
Args:
dataset: Stores (text, label) data.
Returns:
Dataset containing (bert_features, label) data.
"""
examples = []
for index, (text, label) in enumerate(dataset.gen_tf_dataset()):
_validate_text_and_label(text, label)
examples.append(
classifier_data_lib.InputExample(
guid=str(index),
text_a=text.numpy()[0].decode("utf-8"),
text_b=None,
# InputExample expects the label name rather than the int ID
label=dataset.label_names[label.numpy()[0]]))
tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord")
classifier_data_lib.file_based_convert_examples_to_features(
examples=examples,
label_list=dataset.label_names,
max_seq_length=self._seq_len,
tokenizer=self._tokenizer,
output_file=tfrecord_file)
preprocessed_ds = _single_file_dataset(tfrecord_file,
self._get_name_to_features())
return text_classifier_ds.Dataset(
dataset=preprocessed_ds,
size=dataset.size,
label_names=dataset.label_names)
TextClassifierPreprocessor = (
Union[BertClassifierPreprocessor,
AverageWordEmbeddingClassifierPreprocessor])

View File

@ -0,0 +1,96 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os
import numpy as np
import numpy.testing as npt
import tensorflow as tf
from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor
class PreprocessorTest(tf.test.TestCase):
CSV_PARAMS_ = text_classifier_ds.CSVParameters(
text_column='text', label_column='label')
def _get_csv_file(self):
labels_and_text = (('pos', 'super super super super good'),
(('neg', 'really bad')))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
if os.path.exists(csv_file):
return csv_file
fieldnames = ['text', 'label']
with open(csv_file, 'w') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
return csv_file
def test_average_word_embedding_preprocessor(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
average_word_embedding_preprocessor = (
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
seq_len=5,
do_lower_case=True,
texts=['super super super super good', 'really bad'],
vocab_size=7))
preprocessed_dataset = (
average_word_embedding_preprocessor.preprocess(dataset))
labels = []
features_list = []
for features, label in preprocessed_dataset.gen_tf_dataset():
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertEqual(features.shape, [1, 5])
features_list.append(features.numpy()[0])
self.assertEqual(labels, [1, 0])
npt.assert_array_equal(
np.stack(features_list), np.array([[1, 3, 3, 3, 3], [1, 5, 6, 0, 0]]))
def test_bert_preprocessor(self):
csv_file = self._get_csv_file()
dataset = text_classifier_ds.Dataset.from_csv(
filename=csv_file, csv_params=self.CSV_PARAMS_)
bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.uri)
preprocessed_dataset = bert_preprocessor.preprocess(dataset)
labels = []
input_masks = []
for features, label in preprocessed_dataset.gen_tf_dataset():
self.assertEqual(label.shape, [1])
labels.append(label.numpy()[0])
self.assertSameElements(
features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'])
for feature in features.values():
self.assertEqual(feature.shape, [1, 5])
input_masks.append(features['input_mask'].numpy()[0])
npt.assert_array_equal(features['input_type_ids'].numpy()[0],
[0, 0, 0, 0, 0])
npt.assert_array_equal(
np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]))
self.assertEqual(labels, [1, 0])
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/text/text_classifier:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["average_word_embedding_metadata.json"],
)

View File

@ -0,0 +1,63 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "input_text",
"description": "Embedding vectors representing the input text to be processed.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "RegexTokenizerOptions",
"options": {
"delim_regex_pattern": "[^\\w\\']+",
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
],
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.2.1"
}

View File

@ -0,0 +1,437 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API for text classification."""
import abc
import os
import tempfile
from typing import Any, Optional, Sequence, Tuple
import tensorflow as tf
import tensorflow_hub as hub
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
from official.nlp import optimization
def _validate(options: text_classifier_options.TextClassifierOptions):
"""Validates that `model_options` and `supported_model` are compatible.
Args:
options: Options for creating and training a text classifier.
Raises:
ValueError if there is a mismatch between `model_options` and
`supported_model`.
"""
if options.model_options is None:
return
if (isinstance(options.model_options,
mo.AverageWordEmbeddingClassifierOptions) and
(options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertClassifierOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
class TextClassifier(classifier.Classifier):
"""API for creating and training a text classification model."""
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
self._model_spec = model_spec
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
@classmethod
def create(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions
) -> "TextClassifier":
"""Factory function that creates and trains a text classifier.
Note that `train_data` and `validation_data` are expected to share the same
`label_names` since they should be split from the same dataset.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
Returns:
A text classifier.
Raises:
ValueError if `train_data` and `validation_data` do not have the
same label_names or `options` contains an unknown `supported_model`
"""
if train_data.label_names != validation_data.label_names:
raise ValueError(
f"Training data label names {train_data.label_names} not equal to "
f"validation data label names {validation_data.label_names}")
_validate(options)
if options.model_options is None:
options.model_options = options.supported_model.value().model_options
if options.hparams is None:
options.hparams = options.supported_model.value().hparams
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
train_data.label_names))
elif (options.supported_model ==
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
text_classifier = (
_AverageWordEmbeddingClassifier
.create_average_word_embedding_classifier(train_data, validation_data,
options,
train_data.label_names))
else:
raise ValueError(f"Unknown model {options.supported_model}")
return text_classifier
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
"""Overrides Classifier.evaluate().
Args:
data: Evaluation dataset. Must be a TextClassifier Dataset.
batch_size: Number of samples per evaluation step.
Returns:
The loss value and accuracy.
Raises:
ValueError if `data` is not a TextClassifier Dataset.
"""
# This override is needed because TextClassifier preprocesses its data
# outside of the `gen_tf_dataset()` method. The preprocess call also
# requires a TextClassifier Dataset instead of a core Dataset.
if not isinstance(data, text_ds.Dataset):
raise ValueError("Need a TextClassifier Dataset.")
processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
return self._model.evaluate(dataset)
def export_model(
self,
model_name: str = "model.tflite",
quantization_config: Optional[quantization.QuantizationConfig] = None):
"""Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also
saves a metadata.json file to the same directory as the TFLite file which
can be used to interpret the metadata content in the TFLite file.
Args:
model_name: File name to save TFLite model with metadata. The full export
path is {self._hparams.export_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
tflite_model = model_util.convert_to_tflite(
model=self._model, quantization_config=quantization_config)
vocab_filepath = os.path.join(tempfile.mkdtemp(), "vocab.txt")
self._save_vocab(vocab_filepath)
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, "w") as f:
f.write(metadata_json)
@abc.abstractmethod
def _save_vocab(self, vocab_filepath: str):
"""Saves the preprocessor's vocab to `vocab_filepath`."""
@abc.abstractmethod
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
"""Gets the metadata writer for the text classifier TFLite model."""
class _AverageWordEmbeddingClassifier(TextClassifier):
"""APIs to help create and train an Average Word Embedding text classifier."""
_DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingClassifierOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
self._loss_function = "sparse_categorical_crossentropy"
self._metric_function = "accuracy"
self._text_preprocessor: (
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
@classmethod
def create_average_word_embedding_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier":
"""Creates, trains, and returns an Average Word Embedding classifier.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns:
An Average Word Embedding classifier.
"""
average_word_embedding_classifier = _AverageWordEmbeddingClassifier(
model_spec=options.supported_model.value(),
model_options=options.model_options,
hparams=options.hparams,
label_names=train_data.label_names)
average_word_embedding_classifier._create_and_train_model(
train_data, validation_data)
return average_word_embedding_classifier
def _create_and_train_model(self, train_data: text_ds.Dataset,
validation_data: text_ds.Dataset):
"""Creates the Average Word Embedding classifier keras model and trains it.
Args:
train_data: Training data.
validation_data: Validation data.
"""
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
self._create_model()
self._optimizer = "rmsprop"
self._train_model(processed_train_data, processed_validation_data)
def _load_and_run_preprocessor(
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
"""Runs an AverageWordEmbeddingClassifierPreprocessor on the data.
Args:
train_data: Training data.
validation_data: Validation data.
Returns:
Preprocessed training data and preprocessed validation data.
"""
train_texts = [text.numpy()[0] for text, _ in train_data.gen_tf_dataset()]
validation_texts = [
text.numpy()[0] for text, _ in validation_data.gen_tf_dataset()
]
self._text_preprocessor = (
preprocessor.AverageWordEmbeddingClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_options.do_lower_case,
texts=train_texts + validation_texts,
vocab_size=self._model_options.vocab_size))
return self._text_preprocessor.preprocess(
train_data), self._text_preprocessor.preprocess(validation_data)
def _create_model(self):
"""Creates an Average Word Embedding model."""
self._model = tf.keras.Sequential([
tf.keras.layers.InputLayer(
input_shape=[self._model_options.seq_len], dtype=tf.int32),
tf.keras.layers.Embedding(
len(self._text_preprocessor.get_vocab()),
self._model_options.wordvec_dim,
input_length=self._model_options.seq_len),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(
self._model_options.wordvec_dim, activation=tf.nn.relu),
tf.keras.layers.Dropout(self._model_options.dropout_rate),
tf.keras.layers.Dense(self._num_classes, activation="softmax")
])
def _save_vocab(self, vocab_filepath: str):
with tf.io.gfile.GFile(vocab_filepath, "w") as f:
for token, index in self._text_preprocessor.get_vocab().items():
f.write(f"{token} {index}\n")
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
return text_classifier_writer.MetadataWriter.create_for_regex_model(
model_buffer=tflite_model,
regex_tokenizer=metadata_writer.RegexTokenizer(
# TODO: Align with MediaPipe's RegexTokenizer.
delim_regex_pattern=self._DELIM_REGEX_PATTERN,
vocab_file_path=vocab_filepath),
labels=metadata_writer.Labels().add(list(self._label_names)))
class _BertClassifier(TextClassifier):
"""APIs to help create and train a BERT-based text classifier."""
_INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32)
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod
def create_bert_classifier(
cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset,
options: text_classifier_options.TextClassifierOptions,
label_names: Sequence[str]) -> "_BertClassifier":
"""Creates, trains, and returns a BERT-based classifier.
Args:
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.
label_names: Label names used in the data.
Returns:
A BERT-based classifier.
"""
bert_classifier = _BertClassifier(
model_spec=options.supported_model.value(),
model_options=options.model_options,
hparams=options.hparams,
label_names=train_data.label_names)
bert_classifier._create_and_train_model(train_data, validation_data)
return bert_classifier
def _create_and_train_model(self, train_data: text_ds.Dataset,
validation_data: text_ds.Dataset):
"""Creates the BERT-based classifier keras model and trains it.
Args:
train_data: Training data.
validation_data: Validation data.
"""
(processed_train_data, processed_validation_data) = (
self._load_and_run_preprocessor(train_data, validation_data))
self._create_model()
self._create_optimizer(processed_train_data)
self._train_model(processed_train_data, processed_validation_data)
def _load_and_run_preprocessor(
self, train_data: text_ds.Dataset, validation_data: text_ds.Dataset
) -> Tuple[text_ds.Dataset, text_ds.Dataset]:
"""Loads a BertClassifierPreprocessor and runs it on the data.
Args:
train_data: Training data.
validation_data: Validation data.
Returns:
Preprocessed training data and preprocessed validation data.
"""
self._text_preprocessor = preprocessor.BertClassifierPreprocessor(
seq_len=self._model_options.seq_len,
do_lower_case=self._model_spec.do_lower_case,
uri=self._model_spec.uri)
return (self._text_preprocessor.preprocess(train_data),
self._text_preprocessor.preprocess(validation_data))
def _create_model(self):
"""Creates a BERT-based classifier model.
The model architecture consists of stacking a dense classification layer and
dropout layer on top of the BERT encoder outputs.
"""
encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
input_mask=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
input_type_ids=tf.keras.layers.Input(
shape=(self._model_options.seq_len,), dtype=tf.int32),
)
encoder = hub.KerasLayer(
self._model_spec.uri, trainable=self._model_options.do_fine_tuning)
encoder_outputs = encoder(encoder_inputs)
pooled_output = encoder_outputs["pooled_output"]
output = tf.keras.layers.Dropout(rate=self._model_options.dropout_rate)(
pooled_output)
initializer = tf.keras.initializers.TruncatedNormal(
stddev=self._INITIALIZER_RANGE)
output = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=initializer,
name="output",
activation="softmax",
dtype=tf.float32)(
output)
self._model = tf.keras.Model(inputs=encoder_inputs, outputs=output)
def _create_optimizer(self, train_data: text_ds.Dataset):
"""Loads an optimizer with a learning rate schedule.
The decay steps in the learning rate schedule depend on the
`steps_per_epoch` which may depend on the size of the training data.
Args:
train_data: Training data.
"""
self._hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=self._hparams.steps_per_epoch,
batch_size=self._hparams.batch_size,
train_data=train_data)
total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
warmup_steps = int(total_steps * 0.1)
initial_lr = self._hparams.learning_rate
self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
warmup_steps)
def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy(
self._text_preprocessor.get_vocab_file(),
vocab_filepath,
overwrite=True)
def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
return text_classifier_writer.MetadataWriter.create_for_bert_model(
model_buffer=tflite_model,
tokenizer=metadata_writer.BertTokenizer(vocab_filepath),
labels=metadata_writer.Labels().add(list(self._label_names)),
ids_name=self._model_spec.tflite_input_name["ids"],
mask_name=self._model_spec.tflite_input_name["mask"],
segment_name=self._model_spec.tflite_input_name["segment_ids"])

View File

@ -0,0 +1,108 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making a text classifier model by MediaPipe Model Maker."""
import os
import tempfile
# Dependency imports
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
FLAGS = flags.FLAGS
def define_flags():
flags.DEFINE_string('export_dir', None,
'The directory to save exported files.')
flags.DEFINE_enum('supported_model', 'average_word_embedding',
['average_word_embedding', 'bert'],
'The text classifier to run.')
flags.mark_flag_as_required('export_dir')
def download_demo_data():
"""Downloads demo data, and returns directory path."""
data_path = tf.keras.utils.get_file(
fname='SST-2.zip',
origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
extract=True)
return os.path.join(os.path.dirname(data_path), 'SST-2') # folder name
def run(data_dir,
export_dir=tempfile.mkdtemp(),
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
"""Runs demo."""
# Gets training data and validation data.
csv_params = text_ds.CSVParameters(
text_column='sentence', label_column='label', delimiter='\t')
train_data = text_ds.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
csv_params=csv_params)
validation_data = text_ds.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
csv_params=csv_params)
quantization_config = None
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
hparams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
# Warning: This takes extremely long to run on CPU
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = hp.BaseHParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
# Fine-tunes the model.
options = text_classifier_options.TextClassifierOptions(
supported_model=supported_model, hparams=hparams)
model = text_classifier.TextClassifier.create(train_data, validation_data,
options)
# Gets evaluation results.
_, acc = model.evaluate(validation_data)
print('Eval accuracy: %f' % acc)
model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir=options.hparams.export_dir)
def main(_):
logging.set_verbosity(logging.INFO)
data_dir = download_demo_data()
export_dir = os.path.expanduser(FLAGS.export_dir)
if FLAGS.supported_model == 'average_word_embedding':
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
elif FLAGS.supported_model == 'bert':
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
run(data_dir, export_dir, supported_model)
if __name__ == '__main__':
define_flags()
app.run(main)

View File

@ -0,0 +1,38 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""User-facing customization options to create and train a text classifier."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@dataclasses.dataclass
class TextClassifierOptions:
"""User-facing options for creating the text classifier.
Attributes:
supported_model: A preconfigured model spec.
hparams: Training hyperparameters the user can set to override the ones in
`supported_model`.
model_options: Model options the user can set to override the ones in
`supported_model`. The model options type should be consistent with the
architecture of the `supported_model`.
"""
supported_model: ms.SupportedModels
hparams: Optional[hp.BaseHParams] = None
model_options: Optional[mo.TextClassifierModelOptions] = None

View File

@ -0,0 +1,138 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import filecmp
import os
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.test import test_utils
class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
test_utils.get_test_data_path('average_word_embedding_metadata.json'))
def _get_data(self):
labels_and_text = (('pos', 'super good'), (('neg', 'really bad')))
csv_file = os.path.join(self.get_temp_dir(), 'data.csv')
if os.path.exists(csv_file):
return csv_file
fieldnames = ['text', 'label']
with open(csv_file, 'w') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
all_data = dataset.Dataset.from_csv(
filename=csv_file, csv_params=csv_params)
return all_data.split(0.5)
def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
average_word_embedding_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
# Test export_model
average_word_embedding_classifier.export_model()
output_metadata_file = os.path.join(options.hparams.export_dir,
'metadata.json')
output_tflite_file = os.path.join(options.hparams.export_dir,
'model.tflite')
self.assertTrue(os.path.exists(output_tflite_file))
self.assertGreater(os.path.getsize(output_tflite_file), 0)
self.assertTrue(os.path.exists(output_metadata_file))
self.assertGreater(os.path.getsize(output_metadata_file), 0)
self.assertTrue(
filecmp.cmp(output_metadata_file,
self._AVERAGE_WORD_EMBEDDING_JSON_FILE))
def test_create_and_train_bert(self):
train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
hparams=hp.BaseHParams(
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off'))
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)
_, accuracy = bert_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
# TODO: Add a unit test that does not run OOM.
def test_label_mismatch(self):
options = (
text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
with self.assertRaisesRegex(
ValueError,
'Training data label names .* not equal to validation data label names'
):
text_classifier.TextClassifier.create(train_data, validation_data,
options)
def test_options_mismatch(self):
train_data, validation_data = self._get_data()
avg_options = (
text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.AverageWordEmbeddingClassifierOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
avg_options)
bert_options = (
text_classifier_options.TextClassifierOptions(
supported_model=(
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=mo.BertClassifierOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
text_classifier.TextClassifier.create(train_data, validation_data,
bert_options)
if __name__ == '__main__':
# Load compressed models from tensorflow_hub
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
tf.test.main()

View File

@ -0,0 +1,165 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
# TODO: Remove the unncessary test data once the demo data are moved to an open-sourced
# directory.
filegroup(
name = "test_data",
srcs = glob([
"test_data/**",
]),
)
py_library(
name = "constants",
srcs = ["constants.py"],
)
# TODO: Change to py_library after migrating the MediaPipe hand solution
# library to MediaPipe hand task library.
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
":constants",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/data:data_util",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/python/solutions:hands",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":test_data",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
deps = [
":dataset",
"//mediapipe/python/solutions:hands",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = [
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "model_options",
srcs = ["model_options.py"],
)
py_library(
name = "gesture_recognizer_options",
srcs = ["gesture_recognizer_options.py"],
deps = [
":hyperparameters",
":model_options",
],
)
py_library(
name = "gesture_recognizer",
srcs = ["gesture_recognizer.py"],
data = ["//mediapipe/model_maker/models/gesture_recognizer:models"],
deps = [
":constants",
":gesture_recognizer_options",
":hyperparameters",
":metadata_writer",
":model_options",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:loss_functions",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
],
)
py_library(
name = "gesture_recognizer_import",
srcs = ["__init__.py"],
deps = [
":dataset",
":gesture_recognizer",
":gesture_recognizer_options",
":hyperparameters",
":model_options",
],
)
py_library(
name = "metadata_writer",
srcs = ["metadata_writer.py"],
deps = [
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:writer_utils",
],
)
py_test(
name = "gesture_recognizer_test",
size = "large",
srcs = ["gesture_recognizer_test.py"],
data = [
":test_data",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
shard_count = 2,
deps = [
":gesture_recognizer_import",
"//mediapipe/model_maker/python/core/utils:test_util",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_test(
name = "metadata_writer_test",
srcs = ["metadata_writer_test.py"],
data = [
":test_data",
],
deps = [
":metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/test:test_utils",
],
)
py_binary(
name = "gesture_recognizer_demo",
srcs = ["gesture_recognizer_demo.py"],
data = [
":test_data",
"//mediapipe/model_maker/models/gesture_recognizer:models",
],
python_version = "PY3",
deps = [":gesture_recognizer_import"],
)

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Gesture Recognizer."""
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
GestureRecognizer = gesture_recognizer.GestureRecognizer
ModelOptions = model_options.GestureRecognizerModelOptions
HParams = hyperparameters.HParams
Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions

View File

@ -0,0 +1,20 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition constants."""
GESTURE_EMBEDDER_KERAS_MODEL_PATH = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder'
GESTURE_EMBEDDER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/gesture_embedder.tflite'
HAND_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/palm_detection_full.tflite'
HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/hand_landmark_full.tflite'
CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE = 'mediapipe/model_maker/models/gesture_recognizer/canned_gesture_classifier.tflite'

View File

@ -0,0 +1,238 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture recognition dataset library."""
import dataclasses
import os
import random
from typing import List, NamedTuple, Optional
import cv2
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.core.data import data_util
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.python.solutions import hands as mp_hands
@dataclasses.dataclass
class HandDataPreprocessingParams:
"""A dataclass wraps the hand data preprocessing hyperparameters.
Attributes:
shuffle: A boolean controlling if shuffle the dataset. Default to true.
min_detection_confidence: confidence threshold for hand detection.
"""
shuffle: bool = True
min_detection_confidence: float = 0.7
@dataclasses.dataclass
class HandData:
"""A dataclass represents hand data for training gesture recognizer model.
See https://google.github.io/mediapipe/solutions/hands#mediapipe-hands for
more details of the hand gesture data API.
Attributes:
hand: normalized hand landmarks of shape 21x3 from the screen based
hand-landmark model.
world_hand: hand landmarks of shape 21x3 in world coordinates.
handedness: Collection of handedness confidence of the detected hands (i.e.
is it a left or right hand).
"""
hand: List[List[float]]
world_hand: List[List[float]]
handedness: List[float]
def _validate_data_sample(data: NamedTuple) -> bool:
"""Validates the input hand data sample.
Args:
data: input hand data sample.
Returns:
False if the input data namedtuple does not contain the fields including
'multi_hand_landmarks' or 'multi_hand_world_landmarks' or 'multi_handedness'
or any of these attributes' values are none. Otherwise, True.
"""
if (not hasattr(data, 'multi_hand_landmarks') or
data.multi_hand_landmarks is None):
return False
if (not hasattr(data, 'multi_hand_world_landmarks') or
data.multi_hand_world_landmarks is None):
return False
if not hasattr(data, 'multi_handedness') or data.multi_handedness is None:
return False
return True
def _get_hand_data(all_image_paths: List[str],
min_detection_confidence: float) -> Optional[HandData]:
"""Computes hand data (landmarks and handedness) in the input image.
Args:
all_image_paths: all input image paths.
min_detection_confidence: hand detection confidence threshold
Returns:
A HandData object. Returns None if no hand is detected.
"""
hand_data_result = []
with mp_hands.Hands(
static_image_mode=True,
max_num_hands=1,
min_detection_confidence=min_detection_confidence) as hands:
for path in all_image_paths:
tf.compat.v1.logging.info('Loading image %s', path)
image = data_util.load_image(path)
# Flip image around y-axis for correct handedness output
image = cv2.flip(image, 1)
data = hands.process(image)
if not _validate_data_sample(data):
hand_data_result.append(None)
continue
hand_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_landmarks[0].landmark]
hand_world_landmarks = [[
hand_landmark.x, hand_landmark.y, hand_landmark.z
] for hand_landmark in data.multi_hand_world_landmarks[0].landmark]
handedness_scores = [
handedness.score
for handedness in data.multi_handedness[0].classification
]
hand_data_result.append(
HandData(
hand=hand_landmarks,
world_hand=hand_world_landmarks,
handedness=handedness_scores))
return hand_data_result
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for hand gesture recognizer."""
@classmethod
def from_folder(
cls,
dirname: str,
hparams: Optional[HandDataPreprocessingParams] = None
) -> classification_dataset.ClassificationDataset:
"""Loads images and labels from the given directory.
Directory contents are expected to be in the format:
<root_dir>/<gesture_name>/*.jpg". One of the `gesture_name` must be `none`
(case insensitive). The `none` sub-directory is expected to contain images
of hands that don't belong to other gesture classes in <root_dir>. Assumes
the image data of the same label are in the same subdirectory.
Args:
dirname: Name of the directory containing the data files.
hparams: Optional hyperparameters for processing input hand gesture
images.
Returns:
Dataset containing landmarks, labels, and other related info.
Raises:
ValueError: if the input data directory is empty or the label set does not
contain label 'none' (case insensitive).
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
if not all_image_paths:
raise ValueError('Image dataset directory is empty.')
if not hparams:
hparams = HandDataPreprocessingParams()
if hparams.shuffle:
# Random shuffle data.
random.shuffle(all_image_paths)
label_names = sorted(
name for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name)))
if 'none' not in [v.lower() for v in label_names]:
raise ValueError('Label set does not contain label "None".')
# Move label 'none' to the front of label list.
none_idx = [v.lower() for v in label_names].index('none')
none_value = label_names.pop(none_idx)
label_names.insert(0, none_value)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names))
all_gesture_indices = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
# Compute hand data (including local hand landmark, world hand landmark, and
# handedness) for all the input images.
hand_data = _get_hand_data(
all_image_paths=all_image_paths,
min_detection_confidence=hparams.min_detection_confidence)
# Get a list of the valid hand landmark sample in the hand data list.
valid_indices = [
i for i in range(len(hand_data)) if hand_data[i] is not None
]
# Remove 'None' element from the hand data and label list.
valid_hand_data = [dataclasses.asdict(hand_data[i]) for i in valid_indices]
if not valid_hand_data:
raise ValueError('No valid hand is detected.')
valid_label = [all_gesture_indices[i] for i in valid_indices]
# Convert list of dictionaries to dictionary of lists.
hand_data_dict = {
k: [lm[k] for lm in valid_hand_data] for k in valid_hand_data[0]
}
hand_ds = tf.data.Dataset.from_tensor_slices(hand_data_dict)
embedder_model = model_util.load_keras_model(
constants.GESTURE_EMBEDDER_KERAS_MODEL_PATH)
hand_ds = hand_ds.batch(batch_size=1)
hand_embedding_ds = hand_ds.map(
map_func=lambda feature: embedder_model(dict(feature)),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
hand_embedding_ds = hand_embedding_ds.unbatch()
# Create label dataset
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(valid_label, tf.int64))
label_one_hot_ds = label_ds.map(
map_func=lambda index: tf.one_hot(index, len(label_names)),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Create a dataset with (hand_embedding, one_hot_label) pairs
hand_embedding_label_ds = tf.data.Dataset.zip(
(hand_embedding_ds, label_one_hot_ds))
tf.compat.v1.logging.info(
'Load valid hands with size: {}, num_label: {}, labels: {}.'.format(
len(valid_hand_data), len(label_names), ','.join(label_names)))
return Dataset(
dataset=hand_embedding_label_ds,
size=len(valid_hand_data),
label_names=label_names)

View File

@ -0,0 +1,161 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.s
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import shutil
from typing import NamedTuple
import unittest
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.gesture_recognizer import dataset
from mediapipe.python.solutions import hands as mp_hands
from mediapipe.tasks.python.test import test_utils
FLAGS = flags.FLAGS
_TEST_DATA_DIRNAME = 'raw_data'
class DatasetTest(tf.test.TestCase, parameterized.TestCase):
def test_split(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
data = dataset.Dataset.from_folder(
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 17)
for _, elem in enumerate(train_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertEqual(train_data.num_classes, 4)
self.assertEqual(train_data.label_names, ['none', 'call', 'four', 'rock'])
self.assertLen(test_data, 18)
for _, elem in enumerate(test_data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertEqual(test_data.num_classes, 4)
self.assertEqual(test_data.label_names, ['none', 'call', 'four', 'rock'])
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
data = dataset.Dataset.from_folder(
dirname=input_data_dir, hparams=dataset.HandDataPreprocessingParams())
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['none', 'call', 'four', 'rock'])
def test_create_dataset_from_empty_folder_raise_value_error(self):
with self.assertRaisesRegex(ValueError, 'Image dataset directory is empty'):
dataset.Dataset.from_folder(
dirname=self.get_temp_dir(),
hparams=dataset.HandDataPreprocessingParams())
def test_create_dataset_from_folder_without_none_raise_value_error(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
tmp_dir = self.create_tempdir()
# Copy input dataset to a temporary directory and skip 'None' directory.
for name in os.listdir(input_data_dir):
if name == 'none':
continue
src_dir = os.path.join(input_data_dir, name)
dst_dir = os.path.join(tmp_dir, name)
shutil.copytree(src_dir, dst_dir)
with self.assertRaisesRegex(ValueError,
'Label set does not contain label "None"'):
dataset.Dataset.from_folder(
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
def test_create_dataset_from_folder_with_capital_letter_in_folder_name(self):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
tmp_dir = self.create_tempdir()
# Copy input dataset to a temporary directory and change the base folder
# name to upper case letter, e.g. 'none' -> 'NONE'
for name in os.listdir(input_data_dir):
src_dir = os.path.join(input_data_dir, name)
dst_dir = os.path.join(tmp_dir, name.upper())
shutil.copytree(src_dir, dst_dir)
upper_base_folder_name = list(os.listdir(tmp_dir))
self.assertCountEqual(upper_base_folder_name,
['CALL', 'FOUR', 'NONE', 'ROCK'])
data = dataset.Dataset.from_folder(
dirname=tmp_dir, hparams=dataset.HandDataPreprocessingParams())
for _, elem in enumerate(data.gen_tf_dataset(is_training=True)):
self.assertEqual(elem[0].shape, (1, 128))
self.assertEqual(elem[1].shape, ([1, 4]))
self.assertLen(data, 35)
self.assertEqual(data.num_classes, 4)
self.assertEqual(data.label_names, ['NONE', 'CALL', 'FOUR', 'ROCK'])
@parameterized.named_parameters(
dict(
testcase_name='invalid_field_name_multi_hand_landmark',
hand=collections.namedtuple('Hand', [
'multi_hand_landmark', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, 3)),
dict(
testcase_name='invalid_field_name_multi_hand_world_landmarks',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmark',
'multi_handedness'
])(1, 2, 3)),
dict(
testcase_name='invalid_field_name_multi_handed',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handed'
])(1, 2, 3)),
dict(
testcase_name='multi_hand_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(None, 2, 3)),
dict(
testcase_name='multi_hand_world_landmarks_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, None, 3)),
dict(
testcase_name='multi_handedness_is_none',
hand=collections.namedtuple('Hand', [
'multi_hand_landmarks', 'multi_hand_world_landmarks',
'multi_handedness'
])(1, 2, None)),
)
def test_create_dataset_from_invalid_hand_data(self, hand: NamedTuple):
with unittest.mock.patch.object(
mp_hands.Hands, 'process', return_value=hand):
input_data_dir = test_utils.get_test_data_path(_TEST_DATA_DIRNAME)
with self.assertRaisesRegex(ValueError, 'No valid hand is detected'):
dataset.Dataset.from_folder(
dirname=input_data_dir,
hparams=dataset.HandDataPreprocessingParams())
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,239 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train gesture recognizer model."""
import os
from typing import List
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import loss_functions
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.vision.gesture_recognizer import constants
from mediapipe.model_maker.python.vision.gesture_recognizer import gesture_recognizer_options
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters as hp
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
_EMBEDDING_SIZE = 128
class GestureRecognizer(classifier.Classifier):
"""GestureRecognizer for building hand gesture recognizer model.
Attributes:
embedding_size: Size of the input gesture embedding vector.
"""
def __init__(self, label_names: List[str],
model_options: model_opt.GestureRecognizerModelOptions,
hparams: hp.HParams):
"""Initializes GestureRecognizer class.
Args:
label_names: A list of label names for the classes.
model_options: options to create gesture recognizer model.
hparams: The hyperparameters for training hand gesture recognizer model.
"""
super().__init__(
model_spec=None, label_names=label_names, shuffle=hparams.shuffle)
self._model_options = model_options
self._hparams = hparams
self._history = None
self.embedding_size = _EMBEDDING_SIZE
@classmethod
def create(
cls,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
options: gesture_recognizer_options.GestureRecognizerOptions,
) -> 'GestureRecognizer':
"""Creates and trains a hand gesture recognizer with input datasets.
If a checkpoint file exists in the {options.hparams.export_dir}/checkpoint/
directory, the training process will load the weight from the checkpoint
file for continual training.
Args:
train_data: Training data.
validation_data: Validation data. If None, skips validation process.
options: options for creating and training gesture recognizer model.
Returns:
An instance of GestureRecognizer.
"""
if options.model_options is None:
options.model_options = model_opt.GestureRecognizerModelOptions()
if options.hparams is None:
options.hparams = hp.HParams()
gesture_recognizer = cls(
label_names=train_data.label_names,
model_options=options.model_options,
hparams=options.hparams)
gesture_recognizer._create_model()
train_dataset = train_data.gen_tf_dataset(
batch_size=options.hparams.batch_size,
is_training=True,
shuffle=options.hparams.shuffle)
options.hparams.steps_per_epoch = model_util.get_steps_per_epoch(
steps_per_epoch=options.hparams.steps_per_epoch,
batch_size=options.hparams.batch_size,
train_data=train_data)
train_dataset = train_dataset.take(count=options.hparams.steps_per_epoch)
validation_dataset = validation_data.gen_tf_dataset(
batch_size=options.hparams.batch_size, is_training=False)
tf.compat.v1.logging.info('Training the gesture recognizer model...')
gesture_recognizer._train(
train_data=train_dataset, validation_data=validation_dataset)
return gesture_recognizer
def _train(self, train_data: tf.data.Dataset,
validation_data: tf.data.Dataset):
"""Trains the model with input train_data.
The training results are recorded by a self.History object returned by
tf.keras.Model.fit().
Args:
train_data: Training data.
validation_data: Validation data.
"""
hparams = self._hparams
scheduler = lambda epoch: hparams.learning_rate * (hparams.lr_decay**epoch)
scheduler_callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
job_dir = hparams.export_dir
checkpoint_path = os.path.join(job_dir, 'epoch_models')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True)
best_model_path = os.path.join(job_dir, 'best_model_weights')
best_model_callback = tf.keras.callbacks.ModelCheckpoint(
best_model_path,
monitor='val_loss',
mode='min',
save_best_only=True,
save_weights_only=True)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(job_dir, 'logs'))
self._model.compile(
optimizer='adam',
loss=loss_functions.FocalLoss(gamma=self._hparams.gamma),
metrics=['categorical_accuracy'])
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
if latest_checkpoint:
print(f'Resuming from {latest_checkpoint}')
self._model.load_weights(latest_checkpoint)
self._history = self._model.fit(
x=train_data,
epochs=hparams.epochs,
validation_data=validation_data,
validation_freq=1,
callbacks=[
checkpoint_callback, best_model_callback, scheduler_callback,
tensorboard_callback
],
)
def _create_model(self):
"""Creates the hand gesture recognizer model.
The gesture embedding model is pretrained and loaded from a tf.saved_model.
"""
inputs = tf.keras.Input(
shape=[self.embedding_size],
batch_size=None,
dtype=tf.float32,
name='hand_embedding')
x = tf.keras.layers.BatchNormalization()(inputs)
x = tf.keras.layers.ReLU()(x)
dropout_rate = self._model_options.dropout_rate
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
outputs = tf.keras.layers.Dense(
self._num_classes,
activation='softmax',
name='custom_gesture_recognizer')(
x)
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(self._model.summary())
def export_model(self, model_name: str = 'gesture_recognizer.task'):
"""Converts the model to TFLite and exports as a model bundle file.
Saves a model bundle file and metadata json file to hparams.export_dir. The
resulting model bundle file will contain necessary models for hand
detection, canned gesture classification, and customized gesture
classification. Only the model bundle file is needed for the downstream
gesture recognition task. The metadata.json file is saved only to
interpret the contents of the model bundle file.
The customized gesture model is in float without quantization. The model is
lightweight and there is no need to balance performance and efficiency by
quantization. The default score_thresholding is set to 0.5 as it can be
adjusted during inference.
Args:
model_name: File name to save model bundle file. The full export path is
{export_dir}/{model_name}.
"""
# TODO: Convert keras embedder model instead of using tflite
gesture_embedding_model_buffer = model_util.load_tflite_model_buffer(
constants.GESTURE_EMBEDDER_TFLITE_MODEL_FILE)
hand_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_DETECTOR_TFLITE_MODEL_FILE)
hand_landmarks_detector_model_buffer = model_util.load_tflite_model_buffer(
constants.HAND_LANDMARKS_DETECTOR_TFLITE_MODEL_FILE)
canned_gesture_model_buffer = model_util.load_tflite_model_buffer(
constants.CANNED_GESTURE_CLASSIFIER_TFLITE_MODEL_FILE)
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
model_bundle_file = os.path.join(self._hparams.export_dir, model_name)
metadata_file = os.path.join(self._hparams.export_dir, 'metadata.json')
gesture_classifier_options = metadata_writer.GestureClassifierOptions(
model_buffer=model_util.convert_to_tflite(self._model),
labels=base_metadata_writer.Labels().add(list(self._label_names)),
score_thresholding=base_metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
writer = metadata_writer.MetadataWriter.create(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
gesture_embedding_model_buffer, canned_gesture_model_buffer,
gesture_classifier_options)
model_bundle_content, metadata_json = writer.populate()
with open(model_bundle_file, 'wb') as f:
f.write(model_bundle_content)
with open(metadata_file, 'w') as f:
f.write(metadata_json)

View File

@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Demo for making an gesture recognizer model by Mediapipe Model Maker."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import app
from absl import flags
from absl import logging
from mediapipe.model_maker.python.vision import gesture_recognizer
FLAGS = flags.FLAGS
# TODO: Move hand gesture recognizer demo dataset to an
# open-sourced directory.
TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data'
def define_flags():
flags.DEFINE_string('export_dir', None,
'The directory to save exported files.')
flags.DEFINE_string('input_data_dir', None,
'The directory with input training data.')
flags.mark_flag_as_required('export_dir')
def run(data_dir: str, export_dir: str):
"""Runs demo."""
data = gesture_recognizer.Dataset.from_folder(dirname=data_dir)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
model = gesture_recognizer.GestureRecognizer.create(
train_data=train_data,
validation_data=validation_data,
options=gesture_recognizer.GestureRecognizerOptions(
hparams=gesture_recognizer.HParams(export_dir=export_dir)))
metric = model.evaluate(test_data, batch_size=2)
print('Evaluation metric')
print(metric)
model.export_model()
def main(_):
logging.set_verbosity(logging.INFO)
if FLAGS.input_data_dir is None:
data_dir = os.path.join(FLAGS.test_srcdir, TEST_DATA_DIR)
else:
data_dir = FLAGS.input_data_dir
export_dir = os.path.expanduser(FLAGS.export_dir)
run(data_dir=data_dir, export_dir=export_dir)
if __name__ == '__main__':
define_flags()
app.run(main)

View File

@ -0,0 +1,32 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Options for building gesture recognizer."""
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options as model_opt
@dataclasses.dataclass
class GestureRecognizerOptions:
"""Configurable options for building gesture recognizer.
Attributes:
model_options: A set of options for configuring the selected model.
hparams: A set of hyperparameters used to train the gesture recognizer.
"""
model_options: Optional[model_opt.GestureRecognizerModelOptions] = None
hparams: Optional[hyperparameters.HParams] = None

View File

@ -0,0 +1,132 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from unittest import mock as unittest_mock
import zipfile
import mock
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import test_util
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data'
class GestureRecognizerTest(tf.test.TestCase):
def _load_data(self):
input_data_dir = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'raw_data'))
data = gesture_recognizer.Dataset.from_folder(
dirname=input_data_dir,
hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True))
return data
def setUp(self):
super().setUp()
self._model_options = gesture_recognizer.ModelOptions()
self._hparams = gesture_recognizer.HParams(epochs=2)
self._gesture_recognizer_options = (
gesture_recognizer.GestureRecognizerOptions(
model_options=self._model_options, hparams=self._hparams))
all_data = self._load_data()
# Splits data, 90% data for training, 10% for testing
self._train_data, self._test_data = all_data.split(0.9)
def test_gesture_recognizer_model(self):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
self._test_accuracy(model)
def test_export_gesture_recognizer_model(self):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
model.export_model()
model_bundle_file = os.path.join(self._hparams.export_dir,
'gesture_recognizer.task')
with zipfile.ZipFile(model_bundle_file) as zf:
self.assertEqual(
set(zf.namelist()),
set(['hand_landmarker.task', 'hand_gesture_recognizer.task']))
zf.extractall(self.get_temp_dir())
hand_gesture_recognizer_bundle_file = os.path.join(
self.get_temp_dir(), 'hand_gesture_recognizer.task')
with zipfile.ZipFile(hand_gesture_recognizer_bundle_file) as zf:
self.assertEqual(
set(zf.namelist()),
set([
'canned_gesture_classifier.tflite',
'custom_gesture_classifier.tflite', 'gesture_embedder.tflite'
]))
zf.extractall(self.get_temp_dir())
gesture_classifier_tflite_file = os.path.join(
self.get_temp_dir(), 'custom_gesture_classifier.tflite')
test_util.test_tflite_file(
keras_model=model._model,
tflite_file=gesture_classifier_tflite_file,
size=[1, model.embedding_size])
def _test_accuracy(self, model, threshold=0.5):
_, accuracy = model.evaluate(self._test_data)
tf.compat.v1.logging.info(f'accuracy: {accuracy}')
self.assertGreaterEqual(accuracy, threshold)
@unittest_mock.patch.object(
gesture_recognizer.hyperparameters,
'HParams',
autospec=True,
return_value=gesture_recognizer.HParams(epochs=1))
@unittest_mock.patch.object(
gesture_recognizer.model_options,
'GestureRecognizerModelOptions',
autospec=True,
return_value=gesture_recognizer.ModelOptions())
def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
self, mock_hparams, mock_model_options):
options = gesture_recognizer.GestureRecognizerOptions()
gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=options)
mock_hparams.assert_called_once()
mock_model_options.assert_called_once()
def test_continual_training_by_loading_checkpoint(self):
mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout):
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data,
validation_data=self._test_data,
options=self._gesture_recognizer_options)
self._test_accuracy(model)
self.assertRegex(mock_stdout.getvalue(), 'Resuming from')
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training customized gesture recognizer models."""
import dataclasses
from mediapipe.model_maker.python.core import hyperparameters as hp
@dataclasses.dataclass
class HParams(hp.BaseHParams):
"""The hyperparameters for training gesture recognizer.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
lr_decay: Learning rate decay to use for gradient descent training.
gamma: Gamma parameter for focal loss.
"""
# Parameters from BaseHParams class.
learning_rate: float = 0.001
batch_size: int = 2
epochs: int = 10
# Parameters about training configuration
# TODO: Move lr_decay to hp.baseHParams.
lr_decay: float = 0.99
gamma: int = 2

View File

@ -0,0 +1,193 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and creates model asset bundle for gesture recognizer."""
import dataclasses
import os
import tempfile
from typing import Union
import tensorflow as tf
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
_HAND_DETECTOR_TFLITE_NAME = "hand_detector.tflite"
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME = "hand_landmarks_detector.tflite"
_HAND_LANDMARKER_BUNDLE_NAME = "hand_landmarker.task"
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME = "hand_gesture_recognizer.task"
_GESTURE_EMBEDDER_TFLITE_NAME = "gesture_embedder.tflite"
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME = "canned_gesture_classifier.tflite"
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME = "custom_gesture_classifier.tflite"
_MODEL_NAME = "HandGestureRecognition"
_MODEL_DESCRIPTION = "Recognize the hand gesture in the image."
_INPUT_NAME = "embedding"
_INPUT_DESCRIPTION = "Embedding feature vector from gesture embedder."
_OUTPUT_NAME = "scores"
_OUTPUT_DESCRIPTION = "Hand gesture category scores."
@dataclasses.dataclass
class GestureClassifierOptions:
"""Options to write metadata for gesture classifier.
Attributes:
model_buffer: Gesture classifier TFLite model buffer.
labels: Labels for the gesture classifier.
score_thresholding: Parameters to performs thresholding on output tensor
values [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L468
"""
model_buffer: bytearray
labels: metadata_writer.Labels
score_thresholding: metadata_writer.ScoreThresholding
def read_file(file_path: str, mode: str = "rb") -> Union[str, bytes]:
with tf.io.gfile.GFile(file_path, mode) as f:
return f.read()
class MetadataWriter:
"""MetadataWriter to write the metadata and the model asset bundle."""
def __init__(
self, hand_detector_model_buffer: bytearray,
hand_landmarks_detector_model_buffer: bytearray,
gesture_embedder_model_buffer: bytearray,
canned_gesture_classifier_model_buffer: bytearray,
custom_gesture_classifier_metadata_writer: metadata_writer.MetadataWriter
) -> None:
"""Initialize MetadataWriter to write the metadata and model asset bundle.
Args:
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
the TFLite hand detector model file.
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite hand landmarks detector model file.
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
from the TFLite gesture embedder model file.
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite canned gesture classifier model file.
custom_gesture_classifier_metadata_writer: Metadata writer to write custom
gesture classifier metadata into the TFLite file.
"""
self._hand_detector_model_buffer = hand_detector_model_buffer
self._hand_landmarks_detector_model_buffer = hand_landmarks_detector_model_buffer
self._gesture_embedder_model_buffer = gesture_embedder_model_buffer
self._canned_gesture_classifier_model_buffer = canned_gesture_classifier_model_buffer
self._custom_gesture_classifier_metadata_writer = custom_gesture_classifier_metadata_writer
self._temp_folder = tempfile.TemporaryDirectory()
def __del__(self):
if os.path.exists(self._temp_folder.name):
self._temp_folder.cleanup()
@classmethod
def create(
cls,
hand_detector_model_buffer: bytearray,
hand_landmarks_detector_model_buffer: bytearray,
gesture_embedder_model_buffer: bytearray,
canned_gesture_classifier_model_buffer: bytearray,
custom_gesture_classifier_options: GestureClassifierOptions,
) -> "MetadataWriter":
"""Creates MetadataWriter to write the metadata for gesture recognizer.
Args:
hand_detector_model_buffer: A valid flatbuffer *with* metadata loaded from
the TFLite hand detector model file.
hand_landmarks_detector_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite hand landmarks detector model file.
gesture_embedder_model_buffer: A valid flatbuffer *with* metadata loaded
from the TFLite gesture embedder model file.
canned_gesture_classifier_model_buffer: A valid flatbuffer *with* metadata
loaded from the TFLite canned gesture classifier model file.
custom_gesture_classifier_options: Custom gesture classifier options to
write custom gesture classifier metadata into the TFLite file.
Returns:
An MetadataWrite object.
"""
writer = metadata_writer.MetadataWriter.create(
custom_gesture_classifier_options.model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_feature_input(name=_INPUT_NAME, description=_INPUT_DESCRIPTION)
writer.add_classification_output(
labels=custom_gesture_classifier_options.labels,
score_thresholding=custom_gesture_classifier_options.score_thresholding,
name=_OUTPUT_NAME,
description=_OUTPUT_DESCRIPTION)
return cls(hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
gesture_embedder_model_buffer,
canned_gesture_classifier_model_buffer, writer)
def populate(self):
"""Populates the metadata and creates model asset bundle.
Note that only the output model asset bundle is used for deployment.
The output JSON content is used to interpret the custom gesture classifier
metadata content.
Returns:
A tuple of (model_asset_bundle_in_bytes, metadata_json_content)
"""
# Creates the model asset bundle for hand landmarker task.
landmark_models = {
_HAND_DETECTOR_TFLITE_NAME:
self._hand_detector_model_buffer,
_HAND_LANDMARKS_DETECTOR_TFLITE_NAME:
self._hand_landmarks_detector_model_buffer
}
output_hand_landmarker_path = os.path.join(self._temp_folder.name,
_HAND_LANDMARKER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(landmark_models,
output_hand_landmarker_path)
# Write metadata into custom gesture classifier model.
self._custom_gesture_classifier_model_buffer, custom_gesture_classifier_metadata_json = self._custom_gesture_classifier_metadata_writer.populate(
)
# Creates the model asset bundle for hand gesture recognizer sub graph.
hand_gesture_recognizer_models = {
_GESTURE_EMBEDDER_TFLITE_NAME:
self._gesture_embedder_model_buffer,
_CANNED_GESTURE_CLASSIFIER_TFLITE_NAME:
self._canned_gesture_classifier_model_buffer,
_CUSTOM_GESTURE_CLASSIFIER_TFLITE_NAME:
self._custom_gesture_classifier_model_buffer
}
output_hand_gesture_recognizer_path = os.path.join(
self._temp_folder.name, _HAND_GESTURE_RECOGNIZER_BUNDLE_NAME)
writer_utils.create_model_asset_bundle(hand_gesture_recognizer_models,
output_hand_gesture_recognizer_path)
# Creates the model asset bundle for end-to-end hand gesture recognizer
# graph.
gesture_recognizer_models = {
_HAND_LANDMARKER_BUNDLE_NAME:
read_file(output_hand_landmarker_path),
_HAND_GESTURE_RECOGNIZER_BUNDLE_NAME:
read_file(output_hand_gesture_recognizer_path),
}
output_file_path = os.path.join(self._temp_folder.name,
"gesture_recognizer.task")
writer_utils.create_model_asset_bundle(gesture_recognizer_models,
output_file_path)
with open(output_file_path, "rb") as f:
gesture_recognizer_model_buffer = f.read()
return gesture_recognizer_model_buffer, custom_gesture_classifier_metadata_json

View File

@ -0,0 +1,90 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer."""
import os
import zipfile
import tensorflow as tf
from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata"
_EXPECTED_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json"))
_CUSTOM_GESTURE_CLASSIFIER_PATH = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier.tflite"))
class MetadataWriterTest(tf.test.TestCase):
def test_write_metadata_and_create_model_asset_bundle_successful(self):
# Use dummy model buffer for unit test only.
hand_detector_model_buffer = b"\x11\x12"
hand_landmarks_detector_model_buffer = b"\x22"
gesture_embedder_model_buffer = b"\x33"
canned_gesture_classifier_model_buffer = b"\x44"
custom_gesture_classifier_metadata_writer = metadata_writer.GestureClassifierOptions(
model_buffer=metadata_writer.read_file(_CUSTOM_GESTURE_CLASSIFIER_PATH),
labels=base_metadata_writer.Labels().add(
["None", "Paper", "Rock", "Scissors"]),
score_thresholding=base_metadata_writer.ScoreThresholding(
global_score_threshold=0.5))
writer = metadata_writer.MetadataWriter.create(
hand_detector_model_buffer, hand_landmarks_detector_model_buffer,
gesture_embedder_model_buffer, canned_gesture_classifier_model_buffer,
custom_gesture_classifier_metadata_writer)
model_bundle_content, metadata_json = writer.populate()
with open(_EXPECTED_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
# Checks the top-level model bundle can be extracted successfully.
model_bundle_filepath = os.path.join(self.get_temp_dir(),
"gesture_recognition.task")
with open(model_bundle_filepath, "wb") as f:
f.write(model_bundle_content)
with zipfile.ZipFile(model_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set(["hand_landmarker.task", "hand_gesture_recognizer.task"]))
zf.extractall(self.get_temp_dir())
# Checks the model bundles for sub-task can be extracted successfully.
hand_landmarker_bundle_filepath = os.path.join(self.get_temp_dir(),
"hand_landmarker.task")
with zipfile.ZipFile(hand_landmarker_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set(["hand_landmarks_detector.tflite", "hand_detector.tflite"]))
hand_gesture_recognizer_bundle_filepath = os.path.join(
self.get_temp_dir(), "hand_gesture_recognizer.task")
with zipfile.ZipFile(hand_gesture_recognizer_bundle_filepath) as zf:
self.assertEqual(
set(zf.namelist()),
set([
"canned_gesture_classifier.tflite",
"custom_gesture_classifier.tflite", "gesture_embedder.tflite"
]))
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurable model options for gesture recognizer models."""
import dataclasses
@dataclasses.dataclass
class GestureRecognizerModelOptions:
"""Configurable options for gesture recognizer model.
Attributes:
dropout_rate: The fraction of the input units to drop, used in dropout
layer.
"""
dropout_rate: float = 0.05

View File

@ -0,0 +1,56 @@
{
"name": "HandGestureRecognition",
"description": "Recognize the hand gesture in the image.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "embedding",
"description": "Embedding feature vector from gesture embedder.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "scores",
"description": "Hand gesture category scores.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"process_units": [
{
"options_type": "ScoreThresholdingOptions",
"options": {
"global_score_threshold": 0.5
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}

View File

@ -121,6 +121,34 @@ py_library(
],
)
py_library(
name = "solution_base",
srcs = ["solution_base.py"],
srcs_version = "PY3",
visibility = [
"//mediapipe/python:__subpackages__",
],
deps = [
":_framework_bindings",
":packet_creator",
":packet_getter",
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
"//mediapipe/calculators/image:image_transformation_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_py_pb2",
"//mediapipe/calculators/util:logic_calculator_py_pb2",
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:detection_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/framework/formats:rect_py_pb2",
"//mediapipe/modules/objectron/calculators:annotation_py_pb2",
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)
py_test(
name = "calculator_graph_test",
srcs = ["calculator_graph_test.py"],
@ -176,3 +204,24 @@ py_test(
":_framework_bindings",
],
)
py_test(
name = "solution_base_test",
srcs = ["solution_base_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":solution_base",
"//file/google_src",
"//file/localfile",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/util:detection_unique_id_calculator",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/framework/formats:detection_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)

View File

@ -40,7 +40,6 @@ from mediapipe.calculators.util import landmarks_smoothing_calculator_pb2
from mediapipe.calculators.util import logic_calculator_pb2
from mediapipe.calculators.util import thresholding_calculator_pb2
from mediapipe.framework import calculator_pb2
from mediapipe.framework.formats import body_rig_pb2
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import landmark_pb2

View File

@ -0,0 +1,105 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
py_library(
name = "hands",
srcs = [
"hands.py",
"hands_connections.py",
],
data = [
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_graph",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/palm_detection:palm_detection_full.tflite",
"//mediapipe/modules/palm_detection:palm_detection_lite.tflite",
],
srcs_version = "PY3",
deps = [
"//mediapipe/calculators/core:constant_side_packet_calculator_py_pb2",
"//mediapipe/calculators/core:gate_calculator_py_pb2",
"//mediapipe/calculators/core:split_vector_calculator_py_pb2",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_py_pb2",
"//mediapipe/calculators/tensor:inference_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_py_pb2",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_py_pb2",
"//mediapipe/calculators/tflite:ssd_anchors_calculator_py_pb2",
"//mediapipe/calculators/util:association_calculator_py_pb2",
"//mediapipe/calculators/util:detections_to_rects_calculator_py_pb2",
"//mediapipe/calculators/util:logic_calculator_py_pb2",
"//mediapipe/calculators/util:non_max_suppression_calculator_py_pb2",
"//mediapipe/calculators/util:rect_transformation_calculator_py_pb2",
"//mediapipe/calculators/util:thresholding_calculator_py_pb2",
"//mediapipe/python:solution_base",
],
)
py_library(
name = "drawing_styles",
srcs = ["drawing_styles.py"],
srcs_version = "PY3",
deps = [
"drawing_utils",
"face_mesh",
"hands",
"pose",
],
)
py_library(
name = "drawing_utils",
srcs = ["drawing_utils.py"],
srcs_version = "PY3",
deps = [
"//mediapipe/framework/formats:detection_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/framework/formats:location_data_py_pb2",
],
)
py_test(
name = "drawing_utils_test",
srcs = ["drawing_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":drawing_utils",
"//mediapipe/framework/formats:detection_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"@com_google_protobuf//:protobuf_python",
],
)
py_test(
name = "hands_test",
srcs = ["hands_test.py"],
data = [
":testdata/asl_hand.25fps.mp4",
":testdata/asl_hand.full.npz",
":testdata/hands.jpg",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":drawing_styles",
":drawing_utils",
":hands",
],
)

View File

@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
option java_outer_classname = "AudioClassifierGraphOptionsProto";
message AudioClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional AudioClassifierGraphOptions ext = 451755788;

View File

@ -21,15 +21,6 @@ cc_library(
hdrs = ["rect.h"],
)
cc_library(
name = "gesture_recognition_result",
hdrs = ["gesture_recognition_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library(
name = "category",
srcs = ["category.cc"],

View File

@ -124,12 +124,22 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "gesture_recognizer_result",
hdrs = ["gesture_recognizer_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library(
name = "gesture_recognizer",
srcs = ["gesture_recognizer.cc"],
hdrs = ["gesture_recognizer.h"],
deps = [
":gesture_recognizer_graph",
":gesture_recognizer_result",
":hand_gesture_recognizer_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:builder",
@ -140,7 +150,6 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",

View File

@ -58,8 +58,6 @@ namespace {
using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision::
gesture_recognizer::proto::GestureRecognizerGraphOptions;
using ::mediapipe::tasks::components::containers::GestureRecognitionResult;
constexpr char kHandGestureSubgraphTypeName[] =
"mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";
@ -214,7 +212,7 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
std::move(packets_callback));
}
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
absl::StatusOr<GestureRecognizerResult> GestureRecognizer::Recognize(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -250,7 +248,7 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
};
}
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
absl::StatusOr<GestureRecognizerResult> GestureRecognizer::RecognizeForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {

View File

@ -24,12 +24,12 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_result.h"
namespace mediapipe {
namespace tasks {
@ -81,9 +81,8 @@ struct GestureRecognizerOptions {
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(
absl::StatusOr<components::containers::GestureRecognitionResult>,
const Image&, int64)>
std::function<void(absl::StatusOr<GestureRecognizerResult>, const Image&,
int64)>
result_callback = nullptr;
};
@ -104,7 +103,7 @@ struct GestureRecognizerOptions {
// 'width' and 'height' fields is NOT supported and will result in an
// invalid argument error being returned.
// Outputs:
// GestureRecognitionResult
// GestureRecognizerResult
// - The hand gesture recognition results.
class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
public:
@ -139,7 +138,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented.
absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
absl::StatusOr<GestureRecognizerResult> Recognize(
Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -157,10 +156,10 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
absl::StatusOr<components::containers::GestureRecognitionResult>
RecognizeForVideo(Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
absl::StatusOr<GestureRecognizerResult> RecognizeForVideo(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform gesture recognition, and the results will
// be available via the "result_callback" provided in the
@ -179,7 +178,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
// and will result in an invalid argument error being returned.
//
// The "result_callback" provides
// - A vector of GestureRecognitionResult, each is the recognized results
// - A vector of GestureRecognizerResult, each is the recognized results
// for a input frame.
// - The const reference to the corresponding input image that the gesture
// recognizer runs on. Note that the const reference to the image will no

View File

@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace containers {
namespace vision {
namespace gesture_recognizer {
// The gesture recognition result from GestureRecognizer, where each vector
// element represents a single hand detected in the image.
struct GestureRecognitionResult {
struct GestureRecognizerResult {
// Recognized hand gestures with sorted order such that the winning label is
// the first item in the list.
std::vector<mediapipe::ClassificationList> gestures;
@ -38,9 +38,9 @@ struct GestureRecognitionResult {
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
};
} // namespace containers
} // namespace components
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_RESULT_H_

View File

@ -276,33 +276,44 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
.set_min_size(max_num_hands);
auto has_enough_hands = min_size_node.Out("").Cast<bool>();
auto image_for_hand_detector =
DisallowIf(image_in, has_enough_hands, graph);
auto norm_rect_in_for_hand_detector =
DisallowIf(norm_rect_in, has_enough_hands, graph);
auto& hand_detector =
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
tasks_options.hand_detector_graph_options());
auto& clip_hand_rects =
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
.set_max_vec_size(max_num_hands);
if (tasks_options.base_options().use_stream_mode()) {
// While in stream mode, skip hand detector graph when we successfully
// track the hands from the last frame.
auto image_for_hand_detector =
DisallowIf(image_in, has_enough_hands, graph);
auto norm_rect_in_for_hand_detector =
DisallowIf(norm_rect_in, has_enough_hands, graph);
image_for_hand_detector >> hand_detector.In("IMAGE");
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT");
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
auto& hand_association = graph.AddNode("HandAssociationCalculator");
hand_association.GetOptions<HandAssociationCalculatorOptions>()
.set_min_similarity_threshold(tasks_options.min_tracking_confidence());
.set_min_similarity_threshold(
tasks_options.min_tracking_confidence());
prev_hand_rects_from_landmarks >>
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
hand_rects_from_hand_detector >>
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
auto hand_rects = hand_association.Out("");
auto& clip_hand_rects =
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
clip_hand_rects.GetOptions<ClipVectorSizeCalculatorOptions>()
.set_max_vec_size(max_num_hands);
hand_rects >> clip_hand_rects.In("");
} else {
// While not in stream mode, the input images are not guaranteed to be in
// series, and we don't want to enable the tracking and hand associations
// between input images. Always use the hand detector graph.
image_in >> hand_detector.In("IMAGE");
norm_rect_in >> hand_detector.In("NORM_RECT");
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
hand_rects_from_hand_detector >> clip_hand_rects.In("");
}
auto clipped_hand_rects = clip_hand_rects.Out("");
auto& hand_landmarks_detector_graph = graph.AddNode(

View File

@ -0,0 +1,84 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["core/*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
":libmediapipe_tasks_audio_jni_lib",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"@maven//:com_google_guava_guava",
],
)
# The native library of all MediaPipe audio tasks.
cc_binary(
name = "libmediapipe_tasks_audio_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_audio_jni_lib",
srcs = [":libmediapipe_tasks_audio_jni.so"],
alwayslink = 1,
)
android_library(
name = "audioclassifier",
srcs = [
"audioclassifier/AudioClassifier.java",
"audioclassifier/AudioClassifierResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "audioclassifier/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_audio_aar")
mediapipe_tasks_audio_aar(
name = "tasks_audio",
srcs = glob(["**/*.java"]),
native_library = ":libmediapipe_tasks_audio_jni_lib",
)

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.audio.audioclassifier">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,399 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.audioclassifier;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs audio classification on audio clips or audio stream.
*
* <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
* mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
* items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
*
* <p>Input tensor: (kTfLiteFloat32)
*
* <ul>
* <li>input audio buffer of size `[batch * samples]`.
* <li>batch inference is not supported (`batch` is required to be 1).
* <li>for multi-channel models, the channels need be interleaved.
* </ul>
*
* <p>At least one output tensor with: (kTfLiteFloat32)
*
* <ul>
* <li>`[1 x N]` array with `N` represents the number of categories.
* <li>optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
* `category_name` field of the results. The `display_name` field is filled from the
* AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
* `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
* these are available, only the `index` field of the results will be filled.
* </ul>
*/
public final class AudioClassifier extends BaseAudioTaskApi {
private static final String TAG = AudioClassifier.class.getSimpleName();
private static final String AUDIO_IN_STREAM_NAME = "audio_in";
private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"CLASSIFICATIONS:classifications_out",
"TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
static {
ProtoUtil.registerTypeName(
ClassificationsProto.ClassificationResult.class,
"mediapipe.tasks.components.containers.proto.ClassificationResult");
}
/**
* Creates an {@link AudioClassifier} instance from a model file and default {@link
* AudioClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the classification model in the assets.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link AudioClassifier} instance from a model file and default {@link
* AudioClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the classification model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates an {@link AudioClassifier} instance from a model buffer and default {@link
* AudioClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
* classification model.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
*
* @param context an Android {@link Context}.
* @param options an {@link AudioClassifierOptions} instance.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
@Override
public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
try {
if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
// For audio stream mode.
return AudioClassifierResult.createFromProto(
PacketGetter.getProto(
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.getDefaultInstance()),
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
/ MICROSECONDS_PER_MILLISECOND);
} else {
// For audio clips mode.
return AudioClassifierResult.createFromProtoList(
PacketGetter.getProtoVector(
packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.parser()),
-1);
}
} catch (IOException e) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
}
}
@Override
public Void convertToTaskInput(List<Packet> packets) {
return null;
}
});
if (options.resultListener().isPresent()) {
ResultListener<AudioClassifierResult, Void> resultListener =
new ResultListener<AudioClassifierResult, Void>() {
@Override
public void run(AudioClassifierResult audioClassifierResult, Void input) {
options.resultListener().get().run(audioClassifierResult);
}
};
handler.setResultListener(resultListener);
}
options.errorListener().ifPresent(handler::setErrorListener);
// Audio tasks should not drop input audio due to flow limiting, which may cause data
// inconsistency.
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<AudioClassifierOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(options)
.setEnableFlowLimiting(false)
.build(),
handler);
return new AudioClassifier(runner, options.runningMode());
}
/**
* Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe audio task {@link RunningMode}.
*/
private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
}
/*
* Performs audio classification on the provided audio clip. Only use this method when the
* AudioClassifier is created with the audio clips mode.
*
* <p>The audio clip is represented as a MediaPipe {@link AudioData} object The method accepts
* audio clips with various length and audio sample rate. It's required to provide the
* corresponding audio sample rate within the {@link AudioData} object.
*
* <p>The input audio clip may be longer than what the model is able to process in a single
* inference. When this occurs, the input audio clip is split into multiple chunks starting at
* different timestamps. For this reason, this function returns a vector of ClassificationResult
* objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
* chunk data that was classified, e.g:
*
* ClassificationResult #0 (first chunk of data):
* timestamp_ms: 0 (starts at 0ms)
* classifications #0 (single head model):
* category #0:
* category_name: "Speech"
* score: 0.6
* category #1:
* category_name: "Music"
* score: 0.2
* ClassificationResult #1 (second chunk of data):
* timestamp_ms: 800 (starts at 800ms)
* classifications #0 (single head model):
* category #0:
* category_name: "Speech"
* score: 0.5
* category #1:
* category_name: "Silence"
* score: 0.1
*
* @param audioClip a MediaPipe {@link AudioData} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public AudioClassifierResult classify(AudioData audioClip) {
return (AudioClassifierResult) processAudioClip(audioClip);
}
/*
* Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
* use this method when the AudioClassifier is created with the audio stream mode.
*
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
* It's required to provide the corresponding audio sample rate within {@link AudioData} object as
* well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
* timestamps must be monotonically increasing. This method will return immediately after
* the input audio data is accepted. The results will be available in the `resultListener`
* provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
* auido stream data such as microphone input.
*
* <p>The input audio block may be longer than what the model is able to process in a single
* inference. When this occurs, the input audio block is split into multiple chunks. For this
* reason, the callback may be called multiple times (once per chunk) for each call to this
* function.
*
* @param audioBlock a MediaPipe {@link AudioData} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void classifyAsync(AudioData audioBlock, long timestampMs) {
checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
sendAudioStreamData(audioBlock, timestampMs);
}
/** Options for setting up and {@link AudioClassifier}. */
@AutoValue
public abstract static class AudioClassifierOptions extends TaskOptions {
/** Builder for {@link AudioClassifierOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the {@link BaseOptions} for the audio classifier task. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
/**
* Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
* mode. Image classifier has two modes:
*
* <ul>
* <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
* audio clips to the `classify` method, and will receive the classification results as
* the return value.
* <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
* from microphone. Users call `classifyAsync` to push the audio data into the
* AudioClassifier, the classification results will be available in the result callback
* when the audio classifier finishes the work.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode runningMode);
/**
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
* score threshold, number of results, etc.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
/**
* Sets the {@link ResultListener} to receive the classification results asynchronously when
* the audio classifier is in the audio stream mode.
*/
public abstract Builder setResultListener(
PureResultListener<AudioClassifierResult> resultListener);
/** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener errorListener);
abstract AudioClassifierOptions autoBuild();
/**
* Validates and builds the {@link AudioClassifierOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the audio classifier
* is in the audio stream mode.
*/
public final AudioClassifierOptions build() {
AudioClassifierOptions options = autoBuild();
if (options.runningMode() == RunningMode.AUDIO_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The audio classifier is in the audio stream mode, a user-defined result listener"
+ " must be provided in the AudioClassifierOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The audio classifier is in the audio clips mode, a user-defined result listener"
+ " shouldn't be provided in AudioClassifierOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
.setRunningMode(RunningMode.AUDIO_CLIPS);
}
/**
* Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
*/
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder()
.setExtension(
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
}

View File

@ -0,0 +1,76 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.audioclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/** Represents the classification results generated by {@link AudioClassifier}. */
@AutoValue
public abstract class AudioClassifierResult implements TaskResult {
/**
* Creates an {@link AudioClassifierResult} instance from a list of {@link
* ClassificationsProto.ClassificationResult} protobuf messages.
*
* @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf message
* to convert.
* @param timestampMs a timestamp for this result.
*/
static AudioClassifierResult createFromProtoList(
List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
List<ClassificationResult> classificationResultList = new ArrayList<>();
for (ClassificationsProto.ClassificationResult proto : protoList) {
classificationResultList.add(ClassificationResult.createFromProto(proto));
}
return new AutoValue_AudioClassifierResult(
Optional.of(classificationResultList), Optional.empty(), timestampMs);
}
/**
* Creates an {@link AudioClassifierResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
static AudioClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return new AutoValue_AudioClassifierResult(
Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
}
/**
* A list of of timpstamed {@link ClassificationResult} objects, each contains one set of results
* per classifier head. The list represents the audio classification result of an audio clip, and
* is only available when running with the audio clips mode.
*/
public abstract Optional<List<ClassificationResult>> classificationResultList();
/**
* Contains one set of results per classifier head. A {@link ClassificationResult} usually
* represents one audio classification result in an audio stream, and s only available when
* running with the audio stream mode.
*/
public abstract Optional<ClassificationResult> classificationResult();
@Override
public abstract long timestampMs();
}

View File

@ -0,0 +1,151 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.core;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
/** The base class of MediaPipe audio tasks. */
public class BaseAudioTaskApi implements AutoCloseable {
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;
private final TaskRunner runner;
private final RunningMode runningMode;
private final String audioStreamName;
private final String sampleRateStreamName;
private double defaultSampleRate;
static {
System.loadLibrary("mediapipe_tasks_audio_jni");
}
/**
* Constructor to initialize a {@link BaseAudioTaskApi}.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe audio task {@link RunningMode}.
* @param audioStreamName the name of the input audio stream.
* @param sampleRateStreamName the name of the audio sample rate stream.
*/
public BaseAudioTaskApi(
TaskRunner runner,
RunningMode runningMode,
String audioStreamName,
String sampleRateStreamName) {
this.runner = runner;
this.runningMode = runningMode;
this.audioStreamName = audioStreamName;
this.sampleRateStreamName = sampleRateStreamName;
this.defaultSampleRate = -1.0;
}
/**
* A synchronous method to process audio clips. The call blocks the current thread until a failure
* status or a successful result is returned.
*
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
* @throws MediaPipeException if the task is not in the audio clips mode.
*/
protected TaskResult processAudioClip(AudioData audioClip) {
if (runningMode != RunningMode.AUDIO_CLIPS) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the audio clips mode. Current running mode:"
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(
audioStreamName,
runner
.getPacketCreator()
.createMatrix(
audioClip.getFormat().getNumOfChannels(),
audioClip.getBufferLength(),
audioClip.getBuffer()));
inputPackets.put(
sampleRateStreamName,
runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
return runner.process(inputPackets);
}
/**
* Checks or sets the audio sample rate in the audio stream mode.
*
* @param sampleRate the audio sample rate.
* @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
* rate is inconsisent with the previously recevied.
*/
protected void checkOrSetSampleRate(double sampleRate) {
if (runningMode != RunningMode.AUDIO_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the audio stream mode. Current running mode:"
+ runningMode.name());
}
if (defaultSampleRate > 0) {
if (Double.compare(sampleRate, defaultSampleRate) != 0) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
"The input audio sample rate: "
+ sampleRate
+ " is inconsistent with the previously provided: "
+ defaultSampleRate);
}
} else {
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
runner.send(inputPackets, PRESTREAM_TIMESTAMP);
defaultSampleRate = sampleRate;
}
}
/**
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the stream mode.
*/
protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
if (runningMode != RunningMode.AUDIO_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the audio stream mode. Current running mode:"
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(
audioStreamName,
runner
.getPacketCreator()
.createMatrix(
audioClip.getFormat().getNumOfChannels(),
audioClip.getBufferLength(),
audioClip.getBuffer()));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/** Closes and cleans up the MediaPipe audio task. */
@Override
public void close() {
runner.close();
}
}

Some files were not shown because too many files have changed in this diff Show More