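Renames model options in TextClassifier: AverageWordEmbeddingClassifierModelOptions becomes AverageWordEmbeddingModelOptions and BertClassifierModelOptions becomes BertModelOptions.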

PiperOrigin-RevId: 488063797
This commit is contained in:
MediaPipe Team 2022-11-12 10:00:31 -08:00 committed by Copybara-Service
parent bf6c8a0b63
commit aafbf73c0a
6 changed files with 21 additions and 23 deletions

View File

@ -23,9 +23,9 @@ from mediapipe.model_maker.python.text.text_classifier import text_classifier_op
HParams = hyperparameters.BaseHParams HParams = hyperparameters.BaseHParams
CSVParams = dataset.CSVParameters CSVParams = dataset.CSVParameters
Dataset = dataset.Dataset Dataset = dataset.Dataset
AverageWordEmbeddingClassifierModelOptions = ( AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingClassifierModelOptions) model_options.AverageWordEmbeddingModelOptions)
BertClassifierModelOptions = model_options.BertClassifierModelOptions BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions TextClassifierOptions = text_classifier_options.TextClassifierOptions

View File

@ -19,11 +19,11 @@ from typing import Union
from mediapipe.model_maker.python.text.core import bert_model_options from mediapipe.model_maker.python.text.core import bert_model_options
# BERT text classifier model options inherited from BertModelOptions. # BERT text classifier model options inherited from BertModelOptions.
BertClassifierModelOptions = bert_model_options.BertModelOptions BertModelOptions = bert_model_options.BertModelOptions
@dataclasses.dataclass @dataclasses.dataclass
class AverageWordEmbeddingClassifierModelOptions: class AverageWordEmbeddingModelOptions:
"""Configurable model options for an Average Word Embedding classifier. """Configurable model options for an Average Word Embedding classifier.
Attributes: Attributes:
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierModelOptions:
dropout_rate: float = 0.2 dropout_rate: float = 0.2
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions, TextClassifierModelOptions = Union[AverageWordEmbeddingModelOptions,
BertClassifierModelOptions] BertModelOptions]

View File

@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
# `learning_rate` is unused for the average word embedding model # `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams( hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0) epochs=10, batch_size=32, learning_rate=0)
model_options: mo.AverageWordEmbeddingClassifierModelOptions = ( model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingClassifierModelOptions()) mo.AverageWordEmbeddingModelOptions())
name: str = 'AverageWordEmbedding' name: str = 'AverageWordEmbedding'

View File

@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
}) })
self.assertEqual( self.assertEqual(
model_spec_obj.model_options, model_spec_obj.model_options,
classifier_model_options.BertClassifierModelOptions( classifier_model_options.BertModelOptions(
seq_len=128, do_fine_tuning=True, dropout_rate=0.1)) seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual( self.assertEqual(
model_spec_obj.hparams, model_spec_obj.hparams,
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding') self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
self.assertEqual( self.assertEqual(
model_spec_obj.model_options, model_spec_obj.model_options,
classifier_model_options.AverageWordEmbeddingClassifierModelOptions( classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=256, seq_len=256,
wordvec_dim=16, wordvec_dim=16,
do_lower_case=True, do_lower_case=True,
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
def test_custom_bert_spec(self): def test_custom_bert_spec(self):
custom_bert_classifier_options = ( custom_bert_classifier_options = (
classifier_model_options.BertClassifierModelOptions( classifier_model_options.BertModelOptions(
seq_len=512, do_fine_tuning=False, dropout_rate=0.3)) seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
model_spec_obj = ( model_spec_obj = (
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value( ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
num_gpus=3, num_gpus=3,
tpu='tpu/address') tpu='tpu/address')
custom_average_word_embedding_model_options = ( custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingClassifierModelOptions( classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=512, seq_len=512,
wordvec_dim=32, wordvec_dim=32,
do_lower_case=False, do_lower_case=False,

View File

@ -50,13 +50,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
if options.model_options is None: if options.model_options is None:
return return
if (isinstance(options.model_options, if (isinstance(options.model_options, mo.AverageWordEmbeddingModelOptions) and
mo.AverageWordEmbeddingClassifierModelOptions) and
(options.supported_model != (options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}") f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and if (isinstance(options.model_options, mo.BertModelOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)): (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
raise ValueError( raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}") f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
@ -194,7 +193,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+" _DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec, def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingClassifierModelOptions, model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]): hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names) super().__init__(model_spec, hparams, label_names)
self._model_options = model_options self._model_options = model_options
@ -304,8 +303,8 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02 _INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec, def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertClassifierModelOptions, model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
hparams: hp.BaseHParams, label_names: Sequence[str]): label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names) super().__init__(model_spec, hparams, label_names)
self._model_options = model_options self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()

View File

@ -79,7 +79,7 @@ class TextClassifierTest(tf.test.TestCase):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
options = text_classifier.TextClassifierOptions( options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=text_classifier.BertClassifierModelOptions( model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2), do_fine_tuning=False, seq_len=2),
hparams=text_classifier.HParams( hparams=text_classifier.HParams(
epochs=1, epochs=1,
@ -116,8 +116,7 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifierOptions( text_classifier.TextClassifierOptions(
supported_model=( supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=( model_options=text_classifier.AverageWordEmbeddingModelOptions()))
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'): ' SupportedModels.MOBILEBERT_CLASSIFIER'):
@ -128,7 +127,7 @@ class TextClassifierTest(tf.test.TestCase):
text_classifier.TextClassifierOptions( text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER), .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=text_classifier.BertClassifierModelOptions())) model_options=text_classifier.BertModelOptions()))
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):