Renames model options in TextClassifier.
PiperOrigin-RevId: 488063797
This commit is contained in:
parent
bf6c8a0b63
commit
aafbf73c0a
|
@ -23,9 +23,9 @@ from mediapipe.model_maker.python.text.text_classifier import text_classifier_op
|
|||
HParams = hyperparameters.BaseHParams
|
||||
CSVParams = dataset.CSVParameters
|
||||
Dataset = dataset.Dataset
|
||||
AverageWordEmbeddingClassifierModelOptions = (
|
||||
model_options.AverageWordEmbeddingClassifierModelOptions)
|
||||
BertClassifierModelOptions = model_options.BertClassifierModelOptions
|
||||
AverageWordEmbeddingModelOptions = (
|
||||
model_options.AverageWordEmbeddingModelOptions)
|
||||
BertModelOptions = model_options.BertModelOptions
|
||||
SupportedModels = model_spec.SupportedModels
|
||||
TextClassifier = text_classifier.TextClassifier
|
||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||
|
|
|
@ -19,11 +19,11 @@ from typing import Union
|
|||
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||
|
||||
# BERT text classifier model options inherited from BertModelOptions.
|
||||
BertClassifierModelOptions = bert_model_options.BertModelOptions
|
||||
BertModelOptions = bert_model_options.BertModelOptions
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AverageWordEmbeddingClassifierModelOptions:
|
||||
class AverageWordEmbeddingModelOptions:
|
||||
"""Configurable model options for an Average Word Embedding classifier.
|
||||
|
||||
Attributes:
|
||||
|
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierModelOptions:
|
|||
dropout_rate: float = 0.2
|
||||
|
||||
|
||||
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
|
||||
BertClassifierModelOptions]
|
||||
TextClassifierModelOptions = Union[AverageWordEmbeddingModelOptions,
|
||||
BertModelOptions]
|
||||
|
|
|
@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
|
|||
# `learning_rate` is unused for the average word embedding model
|
||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0)
|
||||
model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
|
||||
mo.AverageWordEmbeddingClassifierModelOptions())
|
||||
model_options: mo.AverageWordEmbeddingModelOptions = (
|
||||
mo.AverageWordEmbeddingModelOptions())
|
||||
name: str = 'AverageWordEmbedding'
|
||||
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
})
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.BertClassifierModelOptions(
|
||||
classifier_model_options.BertModelOptions(
|
||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||
self.assertEqual(
|
||||
model_spec_obj.hparams,
|
||||
|
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
||||
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||
seq_len=256,
|
||||
wordvec_dim=16,
|
||||
do_lower_case=True,
|
||||
|
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
|
||||
def test_custom_bert_spec(self):
|
||||
custom_bert_classifier_options = (
|
||||
classifier_model_options.BertClassifierModelOptions(
|
||||
classifier_model_options.BertModelOptions(
|
||||
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||
model_spec_obj = (
|
||||
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||
|
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
num_gpus=3,
|
||||
tpu='tpu/address')
|
||||
custom_average_word_embedding_model_options = (
|
||||
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
||||
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||
seq_len=512,
|
||||
wordvec_dim=32,
|
||||
do_lower_case=False,
|
||||
|
|
|
@ -50,13 +50,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
|||
if options.model_options is None:
|
||||
return
|
||||
|
||||
if (isinstance(options.model_options,
|
||||
mo.AverageWordEmbeddingClassifierModelOptions) and
|
||||
if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
|
||||
(options.supported_model !=
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||
f" got {options.supported_model}")
|
||||
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
|
||||
if (isinstance(options.model_options, mo.BertModelOptions) and
|
||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||
raise ValueError(
|
||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||
|
@ -194,7 +193,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
|||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||
|
||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||
model_options: mo.AverageWordEmbeddingClassifierModelOptions,
|
||||
model_options: mo.AverageWordEmbeddingModelOptions,
|
||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
|
@ -304,8 +303,8 @@ class _BertClassifier(TextClassifier):
|
|||
_INITIALIZER_RANGE = 0.02
|
||||
|
||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||
model_options: mo.BertClassifierModelOptions,
|
||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
|
||||
label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
|
|
|
@ -79,7 +79,7 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
train_data, validation_data = self._get_data()
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=text_classifier.BertClassifierModelOptions(
|
||||
model_options=text_classifier.BertModelOptions(
|
||||
do_fine_tuning=False, seq_len=2),
|
||||
hparams=text_classifier.HParams(
|
||||
epochs=1,
|
||||
|
@ -116,8 +116,7 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
||||
model_options=(
|
||||
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
|
||||
model_options=text_classifier.AverageWordEmbeddingModelOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||
|
@ -128,7 +127,7 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels
|
||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
model_options=text_classifier.BertClassifierModelOptions()))
|
||||
model_options=text_classifier.BertModelOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||
|
|
Loading…
Reference in New Issue
Block a user