Renames model options in TextClassifier.

PiperOrigin-RevId: 488063797
This commit is contained in:
MediaPipe Team 2022-11-12 10:00:31 -08:00 committed by Copybara-Service
parent bf6c8a0b63
commit aafbf73c0a
6 changed files with 21 additions and 23 deletions

View File

@ -23,9 +23,9 @@ from mediapipe.model_maker.python.text.text_classifier import text_classifier_op
HParams = hyperparameters.BaseHParams
CSVParams = dataset.CSVParameters
Dataset = dataset.Dataset
AverageWordEmbeddingClassifierModelOptions = (
model_options.AverageWordEmbeddingClassifierModelOptions)
BertClassifierModelOptions = model_options.BertClassifierModelOptions
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions)
BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions

View File

@ -19,11 +19,11 @@ from typing import Union
from mediapipe.model_maker.python.text.core import bert_model_options
# BERT text classifier model options inherited from BertModelOptions.
BertClassifierModelOptions = bert_model_options.BertModelOptions
BertModelOptions = bert_model_options.BertModelOptions
@dataclasses.dataclass
class AverageWordEmbeddingClassifierModelOptions:
class AverageWordEmbeddingModelOptions:
"""Configurable model options for an Average Word Embedding classifier.
Attributes:
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierModelOptions:
dropout_rate: float = 0.2
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
BertClassifierModelOptions]
TextClassifierModelOptions = Union[AverageWordEmbeddingModelOptions,
BertModelOptions]

View File

@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
# `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0)
model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
mo.AverageWordEmbeddingClassifierModelOptions())
model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingModelOptions())
name: str = 'AverageWordEmbedding'

View File

@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
})
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.BertClassifierModelOptions(
classifier_model_options.BertModelOptions(
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual(
model_spec_obj.hparams,
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=256,
wordvec_dim=16,
do_lower_case=True,
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
def test_custom_bert_spec(self):
custom_bert_classifier_options = (
classifier_model_options.BertClassifierModelOptions(
classifier_model_options.BertModelOptions(
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
model_spec_obj = (
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
num_gpus=3,
tpu='tpu/address')
custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=512,
wordvec_dim=32,
do_lower_case=False,

View File

@ -50,13 +50,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
if options.model_options is None:
return
if (isinstance(options.model_options,
mo.AverageWordEmbeddingClassifierModelOptions) and
if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
(options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
if (isinstance(options.model_options, mo.BertModelOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
@ -194,7 +193,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingClassifierModelOptions,
model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
@ -304,8 +303,8 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertClassifierModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()

View File

@ -79,7 +79,7 @@ class TextClassifierTest(tf.test.TestCase):
train_data, validation_data = self._get_data()
options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=text_classifier.BertClassifierModelOptions(
model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2),
hparams=text_classifier.HParams(
epochs=1,
@ -116,8 +116,7 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=(
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
model_options=text_classifier.AverageWordEmbeddingModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'):
@ -128,7 +127,7 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=text_classifier.BertClassifierModelOptions()))
model_options=text_classifier.BertModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):