Adds a public import API for TextClassifier
.
PiperOrigin-RevId: 487949023
This commit is contained in:
parent
8ec83d2aa0
commit
c2ac040a6c
|
@ -21,6 +21,19 @@ package(
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier_import",
|
||||||
|
srcs = ["__init__.py"],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
":model_options",
|
||||||
|
":model_spec",
|
||||||
|
":text_classifier",
|
||||||
|
":text_classifier_options",
|
||||||
|
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "model_options",
|
name = "model_options",
|
||||||
srcs = ["model_options.py"],
|
srcs = ["model_options.py"],
|
||||||
|
@ -114,12 +127,7 @@ py_test(
|
||||||
],
|
],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":text_classifier_import",
|
||||||
":model_options",
|
|
||||||
":model_spec",
|
|
||||||
":text_classifier",
|
|
||||||
":text_classifier_options",
|
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -128,11 +136,7 @@ py_library(
|
||||||
name = "text_classifier_demo_lib",
|
name = "text_classifier_demo_lib",
|
||||||
srcs = ["text_classifier_demo.py"],
|
srcs = ["text_classifier_demo.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":text_classifier_import",
|
||||||
":model_spec",
|
|
||||||
":text_classifier",
|
|
||||||
":text_classifier_options",
|
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,3 +11,21 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""MediaPipe Public Python API for Text Classifier."""
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
|
||||||
|
HParams = hyperparameters.BaseHParams
|
||||||
|
CSVParams = dataset.CSVParameters
|
||||||
|
Dataset = dataset.Dataset
|
||||||
|
AverageWordEmbeddingClassifierModelOptions = (
|
||||||
|
model_options.AverageWordEmbeddingClassifierModelOptions)
|
||||||
|
BertClassifierModelOptions = model_options.BertClassifierModelOptions
|
||||||
|
SupportedModels = model_spec.SupportedModels
|
||||||
|
TextClassifier = text_classifier.TextClassifier
|
||||||
|
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||||
|
|
|
@ -18,12 +18,12 @@ from typing import Union
|
||||||
|
|
||||||
from mediapipe.model_maker.python.text.core import bert_model_options
|
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||||
|
|
||||||
# BERT text classifier options inherited from BertModelOptions.
|
# BERT text classifier model options inherited from BertModelOptions.
|
||||||
BertClassifierOptions = bert_model_options.BertModelOptions
|
BertClassifierModelOptions = bert_model_options.BertModelOptions
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AverageWordEmbeddingClassifierOptions:
|
class AverageWordEmbeddingClassifierModelOptions:
|
||||||
"""Configurable model options for an Average Word Embedding classifier.
|
"""Configurable model options for an Average Word Embedding classifier.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierOptions:
|
||||||
dropout_rate: float = 0.2
|
dropout_rate: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
|
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
|
||||||
BertClassifierOptions]
|
BertClassifierModelOptions]
|
||||||
|
|
|
@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
|
||||||
# `learning_rate` is unused for the average word embedding model
|
# `learning_rate` is unused for the average word embedding model
|
||||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0)
|
epochs=10, batch_size=32, learning_rate=0)
|
||||||
model_options: mo.AverageWordEmbeddingClassifierOptions = (
|
model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
|
||||||
mo.AverageWordEmbeddingClassifierOptions())
|
mo.AverageWordEmbeddingClassifierModelOptions())
|
||||||
name: str = 'AverageWordEmbedding'
|
name: str = 'AverageWordEmbedding'
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
})
|
})
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.model_options,
|
model_spec_obj.model_options,
|
||||||
classifier_model_options.BertClassifierOptions(
|
classifier_model_options.BertClassifierModelOptions(
|
||||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.hparams,
|
model_spec_obj.hparams,
|
||||||
|
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.model_options,
|
model_spec_obj.model_options,
|
||||||
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
||||||
seq_len=256,
|
seq_len=256,
|
||||||
wordvec_dim=16,
|
wordvec_dim=16,
|
||||||
do_lower_case=True,
|
do_lower_case=True,
|
||||||
|
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_custom_bert_spec(self):
|
def test_custom_bert_spec(self):
|
||||||
custom_bert_classifier_options = (
|
custom_bert_classifier_options = (
|
||||||
classifier_model_options.BertClassifierOptions(
|
classifier_model_options.BertClassifierModelOptions(
|
||||||
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||||
model_spec_obj = (
|
model_spec_obj = (
|
||||||
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||||
|
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
num_gpus=3,
|
num_gpus=3,
|
||||||
tpu='tpu/address')
|
tpu='tpu/address')
|
||||||
custom_average_word_embedding_model_options = (
|
custom_average_word_embedding_model_options = (
|
||||||
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
||||||
seq_len=512,
|
seq_len=512,
|
||||||
wordvec_dim=32,
|
wordvec_dim=32,
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
|
|
|
@ -51,12 +51,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||||
return
|
return
|
||||||
|
|
||||||
if (isinstance(options.model_options,
|
if (isinstance(options.model_options,
|
||||||
mo.AverageWordEmbeddingClassifierOptions) and
|
mo.AverageWordEmbeddingClassifierModelOptions) and
|
||||||
(options.supported_model !=
|
(options.supported_model !=
|
||||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||||
f" got {options.supported_model}")
|
f" got {options.supported_model}")
|
||||||
if (isinstance(options.model_options, mo.BertClassifierOptions) and
|
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
|
||||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||||
|
@ -194,7 +194,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||||
model_options: mo.AverageWordEmbeddingClassifierOptions,
|
model_options: mo.AverageWordEmbeddingClassifierModelOptions,
|
||||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||||
super().__init__(model_spec, hparams, label_names)
|
super().__init__(model_spec, hparams, label_names)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
|
@ -304,8 +304,8 @@ class _BertClassifier(TextClassifier):
|
||||||
_INITIALIZER_RANGE = 0.02
|
_INITIALIZER_RANGE = 0.02
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||||
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
|
model_options: mo.BertClassifierModelOptions,
|
||||||
label_names: Sequence[str]):
|
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||||
super().__init__(model_spec, hparams, label_names)
|
super().__init__(model_spec, hparams, label_names)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
|
|
|
@ -23,12 +23,8 @@ from absl import flags
|
||||||
from absl import logging
|
from absl import logging
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
from mediapipe.model_maker.python.text import text_classifier
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
@ -53,31 +49,34 @@ def download_demo_data():
|
||||||
|
|
||||||
def run(data_dir,
|
def run(data_dir,
|
||||||
export_dir=tempfile.mkdtemp(),
|
export_dir=tempfile.mkdtemp(),
|
||||||
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
supported_model=(
|
||||||
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
"""Runs demo."""
|
"""Runs demo."""
|
||||||
|
|
||||||
# Gets training data and validation data.
|
# Gets training data and validation data.
|
||||||
csv_params = text_ds.CSVParameters(
|
csv_params = text_classifier.CSVParams(
|
||||||
text_column='sentence', label_column='label', delimiter='\t')
|
text_column='sentence', label_column='label', delimiter='\t')
|
||||||
train_data = text_ds.Dataset.from_csv(
|
train_data = text_classifier.Dataset.from_csv(
|
||||||
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
|
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
|
||||||
csv_params=csv_params)
|
csv_params=csv_params)
|
||||||
validation_data = text_ds.Dataset.from_csv(
|
validation_data = text_classifier.Dataset.from_csv(
|
||||||
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
|
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
|
||||||
csv_params=csv_params)
|
csv_params=csv_params)
|
||||||
|
|
||||||
quantization_config = None
|
quantization_config = None
|
||||||
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
|
if (supported_model ==
|
||||||
hparams = hp.BaseHParams(
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
|
hparams = text_classifier.HParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
||||||
# Warning: This takes extremely long to run on CPU
|
# Warning: This takes extremely long to run on CPU
|
||||||
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
elif (
|
||||||
|
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
||||||
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||||
hparams = hp.BaseHParams(
|
hparams = text_classifier.HParams(
|
||||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
||||||
|
|
||||||
# Fine-tunes the model.
|
# Fine-tunes the model.
|
||||||
options = text_classifier_options.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
supported_model=supported_model, hparams=hparams)
|
supported_model=supported_model, hparams=hparams)
|
||||||
model = text_classifier.TextClassifier.create(train_data, validation_data,
|
model = text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
options)
|
options)
|
||||||
|
@ -96,9 +95,10 @@ def main(_):
|
||||||
export_dir = os.path.expanduser(FLAGS.export_dir)
|
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||||
|
|
||||||
if FLAGS.supported_model == 'average_word_embedding':
|
if FLAGS.supported_model == 'average_word_embedding':
|
||||||
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
supported_model = (
|
||||||
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)
|
||||||
elif FLAGS.supported_model == 'bert':
|
elif FLAGS.supported_model == 'bert':
|
||||||
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
|
||||||
run(data_dir, export_dir, supported_model)
|
run(data_dir, export_dir, supported_model)
|
||||||
|
|
||||||
|
|
|
@ -18,12 +18,7 @@ import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.text import text_classifier
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,18 +38,23 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
for label, text in labels_and_text:
|
for label, text in labels_and_text:
|
||||||
writer.writerow({'text': text, 'label': label})
|
writer.writerow({'text': text, 'label': label})
|
||||||
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
csv_params = text_classifier.CSVParams(
|
||||||
all_data = dataset.Dataset.from_csv(
|
text_column='text', label_column='label')
|
||||||
|
all_data = text_classifier.Dataset.from_csv(
|
||||||
filename=csv_file, csv_params=csv_params)
|
filename=csv_file, csv_params=csv_params)
|
||||||
return all_data.split(0.5)
|
return all_data.split(0.5)
|
||||||
|
|
||||||
def test_create_and_train_average_word_embedding_model(self):
|
def test_create_and_train_average_word_embedding_model(self):
|
||||||
train_data, validation_data = self._get_data()
|
train_data, validation_data = self._get_data()
|
||||||
options = text_classifier_options.TextClassifierOptions(
|
options = (
|
||||||
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
|
text_classifier.TextClassifierOptions(
|
||||||
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
|
supported_model=(text_classifier.SupportedModels
|
||||||
average_word_embedding_classifier = text_classifier.TextClassifier.create(
|
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||||
train_data, validation_data, options)
|
hparams=text_classifier.HParams(
|
||||||
|
epochs=1, batch_size=1, learning_rate=0)))
|
||||||
|
average_word_embedding_classifier = (
|
||||||
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
|
options))
|
||||||
|
|
||||||
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
||||||
self.assertGreaterEqual(accuracy, 0.0)
|
self.assertGreaterEqual(accuracy, 0.0)
|
||||||
|
@ -77,10 +77,11 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_create_and_train_bert(self):
|
def test_create_and_train_bert(self):
|
||||||
train_data, validation_data = self._get_data()
|
train_data, validation_data = self._get_data()
|
||||||
options = text_classifier_options.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
|
model_options=text_classifier.BertClassifierModelOptions(
|
||||||
hparams=hp.BaseHParams(
|
do_fine_tuning=False, seq_len=2),
|
||||||
|
hparams=text_classifier.HParams(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
learning_rate=3e-5,
|
learning_rate=3e-5,
|
||||||
|
@ -94,12 +95,13 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_label_mismatch(self):
|
def test_label_mismatch(self):
|
||||||
options = (
|
options = (
|
||||||
text_classifier_options.TextClassifierOptions(
|
text_classifier.TextClassifierOptions(
|
||||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
|
supported_model=(
|
||||||
|
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
|
||||||
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||||
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
|
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
|
||||||
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||||
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
|
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
'Training data label names .* not equal to validation data label names'
|
'Training data label names .* not equal to validation data label names'
|
||||||
|
@ -111,9 +113,11 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
train_data, validation_data = self._get_data()
|
train_data, validation_data = self._get_data()
|
||||||
|
|
||||||
avg_options = (
|
avg_options = (
|
||||||
text_classifier_options.TextClassifierOptions(
|
text_classifier.TextClassifierOptions(
|
||||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
supported_model=(
|
||||||
model_options=mo.AverageWordEmbeddingClassifierOptions()))
|
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
||||||
|
model_options=(
|
||||||
|
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||||
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||||
|
@ -121,10 +125,10 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
avg_options)
|
avg_options)
|
||||||
|
|
||||||
bert_options = (
|
bert_options = (
|
||||||
text_classifier_options.TextClassifierOptions(
|
text_classifier.TextClassifierOptions(
|
||||||
supported_model=(
|
supported_model=(text_classifier.SupportedModels
|
||||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||||
model_options=mo.BertClassifierOptions()))
|
model_options=text_classifier.BertClassifierModelOptions()))
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user