Renames model options in TextClassifier.
PiperOrigin-RevId: 488063797
This commit is contained in:
parent
bf6c8a0b63
commit
aafbf73c0a
|
@ -23,9 +23,9 @@ from mediapipe.model_maker.python.text.text_classifier import text_classifier_op
|
||||||
HParams = hyperparameters.BaseHParams
|
HParams = hyperparameters.BaseHParams
|
||||||
CSVParams = dataset.CSVParameters
|
CSVParams = dataset.CSVParameters
|
||||||
Dataset = dataset.Dataset
|
Dataset = dataset.Dataset
|
||||||
AverageWordEmbeddingClassifierModelOptions = (
|
AverageWordEmbeddingModelOptions = (
|
||||||
model_options.AverageWordEmbeddingClassifierModelOptions)
|
model_options.AverageWordEmbeddingModelOptions)
|
||||||
BertClassifierModelOptions = model_options.BertClassifierModelOptions
|
BertModelOptions = model_options.BertModelOptions
|
||||||
SupportedModels = model_spec.SupportedModels
|
SupportedModels = model_spec.SupportedModels
|
||||||
TextClassifier = text_classifier.TextClassifier
|
TextClassifier = text_classifier.TextClassifier
|
||||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||||
|
|
|
@ -19,11 +19,11 @@ from typing import Union
|
||||||
from mediapipe.model_maker.python.text.core import bert_model_options
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
# BERT text classifier model options inherited from BertModelOptions.
|
# BERT text classifier model options inherited from BertModelOptions.
|
||||||
BertClassifierModelOptions = bert_model_options.BertModelOptions
|
BertModelOptions = bert_model_options.BertModelOptions
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AverageWordEmbeddingClassifierModelOptions:
|
class AverageWordEmbeddingModelOptions:
|
||||||
"""Configurable model options for an Average Word Embedding classifier.
|
"""Configurable model options for an Average Word Embedding classifier.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierModelOptions:
|
||||||
dropout_rate: float = 0.2
|
dropout_rate: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
|
TextClassifierModelOptions = Union[AverageWordEmbeddingModelOptions,
|
||||||
BertClassifierModelOptions]
|
BertModelOptions]
|
||||||
|
|
|
@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
|
||||||
# `learning_rate` is unused for the average word embedding model
|
# `learning_rate` is unused for the average word embedding model
|
||||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0)
|
epochs=10, batch_size=32, learning_rate=0)
|
||||||
model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
|
model_options: mo.AverageWordEmbeddingModelOptions = (
|
||||||
mo.AverageWordEmbeddingClassifierModelOptions())
|
mo.AverageWordEmbeddingModelOptions())
|
||||||
name: str = 'AverageWordEmbedding'
|
name: str = 'AverageWordEmbedding'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
})
|
})
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.model_options,
|
model_spec_obj.model_options,
|
||||||
classifier_model_options.BertClassifierModelOptions(
|
classifier_model_options.BertModelOptions(
|
||||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.hparams,
|
model_spec_obj.hparams,
|
||||||
|
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.model_options,
|
model_spec_obj.model_options,
|
||||||
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||||
seq_len=256,
|
seq_len=256,
|
||||||
wordvec_dim=16,
|
wordvec_dim=16,
|
||||||
do_lower_case=True,
|
do_lower_case=True,
|
||||||
|
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_custom_bert_spec(self):
|
def test_custom_bert_spec(self):
|
||||||
custom_bert_classifier_options = (
|
custom_bert_classifier_options = (
|
||||||
classifier_model_options.BertClassifierModelOptions(
|
classifier_model_options.BertModelOptions(
|
||||||
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||||
model_spec_obj = (
|
model_spec_obj = (
|
||||||
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||||
|
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
num_gpus=3,
|
num_gpus=3,
|
||||||
tpu='tpu/address')
|
tpu='tpu/address')
|
||||||
custom_average_word_embedding_model_options = (
|
custom_average_word_embedding_model_options = (
|
||||||
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||||
seq_len=512,
|
seq_len=512,
|
||||||
wordvec_dim=32,
|
wordvec_dim=32,
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
|
|
|
@ -50,13 +50,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||||
if options.model_options is None:
|
if options.model_options is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (isinstance(options.model_options,
|
if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
|
||||||
mo.AverageWordEmbeddingClassifierModelOptions) and
|
|
||||||
(options.supported_model !=
|
(options.supported_model !=
|
||||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||||
f" got {options.supported_model}")
|
f" got {options.supported_model}")
|
||||||
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
|
if (isinstance(options.model_options, mo.BertModelOptions) and
|
||||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||||
|
@ -194,7 +193,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||||
model_options: mo.AverageWordEmbeddingClassifierModelOptions,
|
model_options: mo.AverageWordEmbeddingModelOptions,
|
||||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||||
super().__init__(model_spec, hparams, label_names)
|
super().__init__(model_spec, hparams, label_names)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
|
@ -304,8 +303,8 @@ class _BertClassifier(TextClassifier):
|
||||||
_INITIALIZER_RANGE = 0.02
|
_INITIALIZER_RANGE = 0.02
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||||
model_options: mo.BertClassifierModelOptions,
|
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
|
||||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
label_names: Sequence[str]):
|
||||||
super().__init__(model_spec, hparams, label_names)
|
super().__init__(model_spec, hparams, label_names)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
|
|
|
@ -79,7 +79,7 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
train_data, validation_data = self._get_data()
|
train_data, validation_data = self._get_data()
|
||||||
options = text_classifier.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
model_options=text_classifier.BertClassifierModelOptions(
|
model_options=text_classifier.BertModelOptions(
|
||||||
do_fine_tuning=False, seq_len=2),
|
do_fine_tuning=False, seq_len=2),
|
||||||
hparams=text_classifier.HParams(
|
hparams=text_classifier.HParams(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
|
@ -116,8 +116,7 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
text_classifier.TextClassifierOptions(
|
text_classifier.TextClassifierOptions(
|
||||||
supported_model=(
|
supported_model=(
|
||||||
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
||||||
model_options=(
|
model_options=text_classifier.AverageWordEmbeddingModelOptions()))
|
||||||
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||||
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||||
|
@ -128,7 +127,7 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
text_classifier.TextClassifierOptions(
|
text_classifier.TextClassifierOptions(
|
||||||
supported_model=(text_classifier.SupportedModels
|
supported_model=(text_classifier.SupportedModels
|
||||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||||
model_options=text_classifier.BertClassifierModelOptions()))
|
model_options=text_classifier.BertModelOptions()))
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user