Adds a public import API for TextClassifier.

PiperOrigin-RevId: 487949023
This commit is contained in:
MediaPipe Team 2022-11-11 16:50:38 -08:00 committed by Copybara-Service
parent 8ec83d2aa0
commit c2ac040a6c
8 changed files with 97 additions and 71 deletions

View File

@ -21,6 +21,19 @@ package(
licenses(["notice"])
# Public import target for the text classifier package. It bundles the
# package `__init__.py` with every module that file imports, so dependents
# (the test and demo targets in this BUILD file) can list this single target
# instead of each module individually.
py_library(
name = "text_classifier_import",
srcs = ["__init__.py"],
deps = [
":dataset",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library(
name = "model_options",
srcs = ["model_options.py"],
@ -114,12 +127,7 @@ py_test(
],
tags = ["requires-net:external"],
deps = [
":dataset",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
":text_classifier_import",
"//mediapipe/tasks/python/test:test_utils",
],
)
@ -128,11 +136,7 @@ py_library(
name = "text_classifier_demo_lib",
srcs = ["text_classifier_demo.py"],
deps = [
":dataset",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
":text_classifier_import",
"//mediapipe/model_maker/python/core/utils:quantization",
],
)

View File

@ -11,3 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Public Python API for Text Classifier."""
from mediapipe.model_maker.python.core import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
# Public aliases re-exported at package level, grouped by the module that
# defines them. Callers reach them as `text_classifier.<Name>`.

# Training hyperparameters (from the shared model_maker core).
HParams = hyperparameters.BaseHParams

# Dataset construction, e.g. `Dataset.from_csv(..., csv_params=CSVParams(...))`.
Dataset = dataset.Dataset
CSVParams = dataset.CSVParameters

# Model selection and the per-model option types.
SupportedModels = model_spec.SupportedModels
BertClassifierModelOptions = model_options.BertClassifierModelOptions
AverageWordEmbeddingClassifierModelOptions = (
    model_options.AverageWordEmbeddingClassifierModelOptions
)

# Classifier entry point and its top-level options container.
TextClassifierOptions = text_classifier_options.TextClassifierOptions
TextClassifier = text_classifier.TextClassifier

View File

@ -18,12 +18,12 @@ from typing import Union
from mediapipe.model_maker.python.text.core import bert_model_options
# BERT text classifier options inherited from BertModelOptions.
BertClassifierOptions = bert_model_options.BertModelOptions
# BERT text classifier model options; an alias of BertModelOptions.
BertClassifierModelOptions = bert_model_options.BertModelOptions
@dataclasses.dataclass
class AverageWordEmbeddingClassifierOptions:
class AverageWordEmbeddingClassifierModelOptions:
"""Configurable model options for an Average Word Embedding classifier.
Attributes:
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierOptions:
dropout_rate: float = 0.2
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions,
BertClassifierOptions]
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
BertClassifierModelOptions]

View File

@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
# `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0)
model_options: mo.AverageWordEmbeddingClassifierOptions = (
mo.AverageWordEmbeddingClassifierOptions())
model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
mo.AverageWordEmbeddingClassifierModelOptions())
name: str = 'AverageWordEmbedding'

View File

@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
})
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.BertClassifierOptions(
classifier_model_options.BertClassifierModelOptions(
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual(
model_spec_obj.hparams,
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
self.assertEqual(
model_spec_obj.model_options,
classifier_model_options.AverageWordEmbeddingClassifierOptions(
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
seq_len=256,
wordvec_dim=16,
do_lower_case=True,
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
def test_custom_bert_spec(self):
custom_bert_classifier_options = (
classifier_model_options.BertClassifierOptions(
classifier_model_options.BertClassifierModelOptions(
seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
model_spec_obj = (
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
num_gpus=3,
tpu='tpu/address')
custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingClassifierOptions(
classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
seq_len=512,
wordvec_dim=32,
do_lower_case=False,

View File

@ -51,12 +51,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
return
if (isinstance(options.model_options,
mo.AverageWordEmbeddingClassifierOptions) and
mo.AverageWordEmbeddingClassifierModelOptions) and
(options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertClassifierOptions) and
if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
@ -194,7 +194,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingClassifierOptions,
model_options: mo.AverageWordEmbeddingClassifierModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
@ -304,8 +304,8 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams,
label_names: Sequence[str]):
model_options: mo.BertClassifierModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()

View File

@ -23,12 +23,8 @@ from absl import flags
from absl import logging
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.model_maker.python.text import text_classifier
FLAGS = flags.FLAGS
@ -53,31 +49,34 @@ def download_demo_data():
def run(data_dir,
export_dir=tempfile.mkdtemp(),
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
supported_model=(
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
"""Runs demo."""
# Gets training data and validation data.
csv_params = text_ds.CSVParameters(
csv_params = text_classifier.CSVParams(
text_column='sentence', label_column='label', delimiter='\t')
train_data = text_ds.Dataset.from_csv(
train_data = text_classifier.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
csv_params=csv_params)
validation_data = text_ds.Dataset.from_csv(
validation_data = text_classifier.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
csv_params=csv_params)
quantization_config = None
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER:
hparams = hp.BaseHParams(
if (supported_model ==
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
hparams = text_classifier.HParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
# Warning: This takes extremely long to run on CPU
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
elif (
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = hp.BaseHParams(
hparams = text_classifier.HParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
# Fine-tunes the model.
options = text_classifier_options.TextClassifierOptions(
options = text_classifier.TextClassifierOptions(
supported_model=supported_model, hparams=hparams)
model = text_classifier.TextClassifier.create(train_data, validation_data,
options)
@ -96,9 +95,10 @@ def main(_):
export_dir = os.path.expanduser(FLAGS.export_dir)
if FLAGS.supported_model == 'average_word_embedding':
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
supported_model = (
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)
elif FLAGS.supported_model == 'bert':
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER
supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
run(data_dir, export_dir, supported_model)

View File

@ -18,12 +18,7 @@ import os
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils
@ -43,18 +38,23 @@ class TextClassifierTest(tf.test.TestCase):
writer.writeheader()
for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label})
csv_params = dataset.CSVParameters(text_column='text', label_column='label')
all_data = dataset.Dataset.from_csv(
csv_params = text_classifier.CSVParams(
text_column='text', label_column='label')
all_data = text_classifier.Dataset.from_csv(
filename=csv_file, csv_params=csv_params)
return all_data.split(0.5)
def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER,
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0))
average_word_embedding_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)
options = (
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
hparams=text_classifier.HParams(
epochs=1, batch_size=1, learning_rate=0)))
average_word_embedding_classifier = (
text_classifier.TextClassifier.create(train_data, validation_data,
options))
_, accuracy = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0)
@ -77,10 +77,11 @@ class TextClassifierTest(tf.test.TestCase):
def test_create_and_train_bert(self):
train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2),
hparams=hp.BaseHParams(
options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=text_classifier.BertClassifierModelOptions(
do_fine_tuning=False, seq_len=2),
hparams=text_classifier.HParams(
epochs=1,
batch_size=1,
learning_rate=3e-5,
@ -94,12 +95,13 @@ class TextClassifierTest(tf.test.TestCase):
def test_label_mismatch(self):
options = (
text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER))
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo'])
train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar'])
validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
with self.assertRaisesRegex(
ValueError,
'Training data label names .* not equal to validation data label names'
@ -111,9 +113,11 @@ class TextClassifierTest(tf.test.TestCase):
train_data, validation_data = self._get_data()
avg_options = (
text_classifier_options.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.AverageWordEmbeddingClassifierOptions()))
text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=(
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'):
@ -121,10 +125,10 @@ class TextClassifierTest(tf.test.TestCase):
avg_options)
bert_options = (
text_classifier_options.TextClassifierOptions(
supported_model=(
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=mo.BertClassifierOptions()))
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=text_classifier.BertClassifierModelOptions()))
with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):