Adds a public import API for TextClassifier.
PiperOrigin-RevId: 487949023
This commit is contained in:
parent
8ec83d2aa0
commit
c2ac040a6c
|
@ -21,6 +21,19 @@ package(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "text_classifier_import",
|
||||
srcs = ["__init__.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_options",
|
||||
srcs = ["model_options.py"],
|
||||
|
@ -114,12 +127,7 @@ py_test(
|
|||
],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_options",
|
||||
":model_spec",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
":text_classifier_import",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
@ -128,11 +136,7 @@ py_library(
|
|||
name = "text_classifier_demo_lib",
|
||||
srcs = ["text_classifier_demo.py"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":model_spec",
|
||||
":text_classifier",
|
||||
":text_classifier_options",
|
||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
||||
":text_classifier_import",
|
||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -11,3 +11,21 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe Public Python API for Text Classifier."""
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
|
||||
HParams = hyperparameters.BaseHParams
|
||||
CSVParams = dataset.CSVParameters
|
||||
Dataset = dataset.Dataset
|
||||
AverageWordEmbeddingClassifierModelOptions = (
|
||||
model_options.AverageWordEmbeddingClassifierModelOptions)
|
||||
BertClassifierModelOptions = model_options.BertClassifierModelOptions
|
||||
SupportedModels = model_spec.SupportedModels
|
||||
TextClassifier = text_classifier.TextClassifier
|
||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||
|
|
|
@ -18,12 +18,12 @@ from typing import Union
|
|||
|
||||
from mediapipe.model_maker.python.text.core import bert_model_options
|
||||
|
||||
# BERT text classifier options inherited from BertModelOptions.
|
||||
BertClassifierOptions = bert_model_options.BertModelOptions
|
||||
# BERT text classifier model options inherited from BertModelOptions.
|
||||
BertClassifierModelOptions = bert_model_options.BertModelOptions
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AverageWordEmbeddingClassifierOptions:
|
||||
class AverageWordEmbeddingClassifierModelOptions:
|
||||
"""Configurable model options for an Average Word Embedding classifier.
|
||||
|
||||
Attributes:
|
||||
|
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierOptions:
|
|||
dropout_rate: float = 0.2
|
||||
|
||||
|
||||
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
|
||||
BertClassifierOptions]
|
||||
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
|
||||
BertClassifierModelOptions]
|
||||
|
|
|
@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
|
|||
# `learning_rate` is unused for the average word embedding model
|
||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
||||
epochs=10, batch_size=32, learning_rate=0)
|
||||
model_options: mo.AverageWordEmbeddingClassifierOptions = (
|
||||
mo.AverageWordEmbeddingClassifierOptions())
|
||||
model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
|
||||
mo.AverageWordEmbeddingClassifierModelOptions())
|
||||
name: str = 'AverageWordEmbedding'
|
||||
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
})
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.BertClassifierOptions(
|
||||
classifier_model_options.BertClassifierModelOptions(
|
||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||
self.assertEqual(
|
||||
model_spec_obj.hparams,
|
||||
|
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
|
||||
self.assertEqual(
|
||||
model_spec_obj.model_options,
|
||||
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
||||
seq_len=256,
|
||||
wordvec_dim=16,
|
||||
do_lower_case=True,
|
||||
|
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
|
||||
def test_custom_bert_spec(self):
|
||||
custom_bert_classifier_options = (
|
||||
classifier_model_options.BertClassifierOptions(
|
||||
classifier_model_options.BertClassifierModelOptions(
|
||||
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
|
||||
model_spec_obj = (
|
||||
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
|
||||
|
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
|
|||
num_gpus=3,
|
||||
tpu='tpu/address')
|
||||
custom_average_word_embedding_model_options = (
|
||||
classifier_model_options.AverageWordEmbeddingClassifierOptions(
|
||||
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
|
||||
seq_len=512,
|
||||
wordvec_dim=32,
|
||||
do_lower_case=False,
|
||||
|
|
|
@ -51,12 +51,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
|||
return
|
||||
|
||||
if (isinstance(options.model_options,
|
||||
mo.AverageWordEmbeddingClassifierOptions) and
|
||||
mo.AverageWordEmbeddingClassifierModelOptions) and
|
||||
(options.supported_model !=
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||
f" got {options.supported_model}")
|
||||
if (isinstance(options.model_options, mo.BertClassifierOptions) and
|
||||
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
|
||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
||||
raise ValueError(
|
||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
||||
|
@ -194,7 +194,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
|||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||
|
||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||
model_options: mo.AverageWordEmbeddingClassifierOptions,
|
||||
model_options: mo.AverageWordEmbeddingClassifierModelOptions,
|
||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
|
@ -304,8 +304,8 @@ class _BertClassifier(TextClassifier):
|
|||
_INITIALIZER_RANGE = 0.02
|
||||
|
||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
||||
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
|
||||
label_names: Sequence[str]):
|
||||
model_options: mo.BertClassifierModelOptions,
|
||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
||||
super().__init__(model_spec, hparams, label_names)
|
||||
self._model_options = model_options
|
||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
|
|
|
@ -23,12 +23,8 @@ from absl import flags
|
|||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.core.utils import quantization
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
from mediapipe.model_maker.python.text import text_classifier
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
@ -53,31 +49,34 @@ def download_demo_data():
|
|||
|
||||
def run(data_dir,
|
||||
export_dir=tempfile.mkdtemp(),
|
||||
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||
"""Runs demo."""
|
||||
|
||||
# Gets training data and validation data.
|
||||
csv_params = text_ds.CSVParameters(
|
||||
csv_params = text_classifier.CSVParams(
|
||||
text_column='sentence', label_column='label', delimiter='\t')
|
||||
train_data = text_ds.Dataset.from_csv(
|
||||
train_data = text_classifier.Dataset.from_csv(
|
||||
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
|
||||
csv_params=csv_params)
|
||||
validation_data = text_ds.Dataset.from_csv(
|
||||
validation_data = text_classifier.Dataset.from_csv(
|
||||
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
|
||||
csv_params=csv_params)
|
||||
|
||||
quantization_config = None
|
||||
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
|
||||
hparams = hp.BaseHParams(
|
||||
if (supported_model ==
|
||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||
hparams = text_classifier.HParams(
|
||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
||||
# Warning: This takes extremely long to run on CPU
|
||||
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
||||
elif (
|
||||
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
||||
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||
hparams = hp.BaseHParams(
|
||||
hparams = text_classifier.HParams(
|
||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
||||
|
||||
# Fine-tunes the model.
|
||||
options = text_classifier_options.TextClassifierOptions(
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=supported_model, hparams=hparams)
|
||||
model = text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options)
|
||||
|
@ -96,9 +95,10 @@ def main(_):
|
|||
export_dir = os.path.expanduser(FLAGS.export_dir)
|
||||
|
||||
if FLAGS.supported_model == 'average_word_embedding':
|
||||
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||
supported_model = (
|
||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)
|
||||
elif FLAGS.supported_model == 'bert':
|
||||
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||
|
||||
run(data_dir, export_dir, supported_model)
|
||||
|
||||
|
|
|
@ -18,12 +18,7 @@ import os
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||
from mediapipe.model_maker.python.text import text_classifier
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
|
@ -43,18 +38,23 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
writer.writeheader()
|
||||
for label, text in labels_and_text:
|
||||
writer.writerow({'text': text, 'label': label})
|
||||
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
|
||||
all_data = dataset.Dataset.from_csv(
|
||||
csv_params = text_classifier.CSVParams(
|
||||
text_column='text', label_column='label')
|
||||
all_data = text_classifier.Dataset.from_csv(
|
||||
filename=csv_file, csv_params=csv_params)
|
||||
return all_data.split(0.5)
|
||||
|
||||
def test_create_and_train_average_word_embedding_model(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
options = text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
|
||||
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
|
||||
average_word_embedding_classifier = text_classifier.TextClassifier.create(
|
||||
train_data, validation_data, options)
|
||||
options = (
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels
|
||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
hparams=text_classifier.HParams(
|
||||
epochs=1, batch_size=1, learning_rate=0)))
|
||||
average_word_embedding_classifier = (
|
||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||
options))
|
||||
|
||||
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
|
||||
self.assertGreaterEqual(accuracy, 0.0)
|
||||
|
@ -77,10 +77,11 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
|
||||
def test_create_and_train_bert(self):
|
||||
train_data, validation_data = self._get_data()
|
||||
options = text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
|
||||
hparams=hp.BaseHParams(
|
||||
options = text_classifier.TextClassifierOptions(
|
||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=text_classifier.BertClassifierModelOptions(
|
||||
do_fine_tuning=False, seq_len=2),
|
||||
hparams=text_classifier.HParams(
|
||||
epochs=1,
|
||||
batch_size=1,
|
||||
learning_rate=3e-5,
|
||||
|
@ -94,12 +95,13 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
|
||||
def test_label_mismatch(self):
|
||||
options = (
|
||||
text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
|
||||
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
|
||||
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
|
||||
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
|
||||
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Training data label names .* not equal to validation data label names'
|
||||
|
@ -111,9 +113,11 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
train_data, validation_data = self._get_data()
|
||||
|
||||
avg_options = (
|
||||
text_classifier_options.TextClassifierOptions(
|
||||
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||
model_options=mo.AverageWordEmbeddingClassifierOptions()))
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(
|
||||
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
|
||||
model_options=(
|
||||
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
|
||||
' SupportedModels.MOBILEBERT_CLASSIFIER'):
|
||||
|
@ -121,10 +125,10 @@ class TextClassifierTest(tf.test.TestCase):
|
|||
avg_options)
|
||||
|
||||
bert_options = (
|
||||
text_classifier_options.TextClassifierOptions(
|
||||
supported_model=(
|
||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
model_options=mo.BertClassifierOptions()))
|
||||
text_classifier.TextClassifierOptions(
|
||||
supported_model=(text_classifier.SupportedModels
|
||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
||||
model_options=text_classifier.BertClassifierModelOptions()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
|
||||
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):
|
||||
|
|
Loading…
Reference in New Issue
Block a user