Add support for customizing gesture recognizer layers
PiperOrigin-RevId: 496456160
This commit is contained in:
parent
4822476974
commit
3e6cd5d2bf
|
@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
name='hand_embedding')
|
name='hand_embedding')
|
||||||
|
x = inputs
|
||||||
x = tf.keras.layers.BatchNormalization()(inputs)
|
|
||||||
x = tf.keras.layers.ReLU()(x)
|
|
||||||
dropout_rate = self._model_options.dropout_rate
|
dropout_rate = self._model_options.dropout_rate
|
||||||
x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
|
for i, width in enumerate(self._model_options.layer_widths):
|
||||||
|
x = tf.keras.layers.BatchNormalization()(x)
|
||||||
|
x = tf.keras.layers.ReLU()(x)
|
||||||
|
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
|
||||||
|
x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
|
||||||
|
x = tf.keras.layers.BatchNormalization()(x)
|
||||||
|
x = tf.keras.layers.ReLU()(x)
|
||||||
|
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
|
||||||
outputs = tf.keras.layers.Dense(
|
outputs = tf.keras.layers.Dense(
|
||||||
self._num_classes,
|
self._num_classes,
|
||||||
activation='softmax',
|
activation='softmax',
|
||||||
name='custom_gesture_recognizer')(
|
name='custom_gesture_recognizer_out')(
|
||||||
x)
|
x)
|
||||||
|
|
||||||
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
self._model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
|
@ -60,6 +60,32 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
@unittest_mock.patch.object(
|
||||||
|
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
|
||||||
|
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
|
||||||
|
layer_widths = [64, 32]
|
||||||
|
model_options = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
|
||||||
|
hparams = gesture_recognizer.HParams(
|
||||||
|
export_dir=tempfile.mkdtemp(), epochs=2)
|
||||||
|
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
|
||||||
|
model_options=model_options, hparams=hparams)
|
||||||
|
model = gesture_recognizer.GestureRecognizer.create(
|
||||||
|
train_data=self._train_data,
|
||||||
|
validation_data=self._validation_data,
|
||||||
|
options=gesture_recognizer_options)
|
||||||
|
expected_calls = [
|
||||||
|
unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}')
|
||||||
|
for i, w in enumerate(layer_widths)
|
||||||
|
]
|
||||||
|
expected_calls.append(
|
||||||
|
unittest_mock.call(
|
||||||
|
len(self._train_data.label_names),
|
||||||
|
activation='softmax',
|
||||||
|
name='custom_gesture_recognizer_out'))
|
||||||
|
self.assertLen(mock_dense.call_args_list, len(expected_calls))
|
||||||
|
mock_dense.assert_has_calls(expected_calls)
|
||||||
|
self._test_accuracy(model)
|
||||||
|
|
||||||
def test_export_gesture_recognizer_model(self):
|
def test_export_gesture_recognizer_model(self):
|
||||||
model_options = gesture_recognizer.ModelOptions()
|
model_options = gesture_recognizer.ModelOptions()
|
||||||
hparams = gesture_recognizer.HParams(
|
hparams = gesture_recognizer.HParams(
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
"""Configurable model options for gesture recognizer models."""
|
"""Configurable model options for gesture recognizer models."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -23,5 +24,10 @@ class GestureRecognizerModelOptions:
|
||||||
Attributes:
|
Attributes:
|
||||||
dropout_rate: The fraction of the input units to drop, used in dropout
|
dropout_rate: The fraction of the input units to drop, used in dropout
|
||||||
layer.
|
layer.
|
||||||
|
layer_widths: A list of hidden layer widths for the gesture model. Each
|
||||||
|
element in the list will create a new hidden layer with the specified
|
||||||
|
width. The hidden layers are separated with BatchNorm, Dropout, and ReLU.
|
||||||
|
Defaults to an empty list(no hidden layers).
|
||||||
"""
|
"""
|
||||||
dropout_rate: float = 0.05
|
dropout_rate: float = 0.05
|
||||||
|
layer_widths: List[int] = dataclasses.field(default_factory=list)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user