Adds a public import API for TextClassifier.

PiperOrigin-RevId: 487949023
This commit is contained in:
MediaPipe Team 2022-11-11 16:50:38 -08:00 committed by Copybara-Service
parent 8ec83d2aa0
commit c2ac040a6c
8 changed files with 97 additions and 71 deletions

View File

@ -21,6 +21,19 @@ package(
licenses(["notice"]) licenses(["notice"])
py_library(
name = "text_classifier_import",
srcs = ["__init__.py"],
deps = [
":dataset",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
py_library( py_library(
name = "model_options", name = "model_options",
srcs = ["model_options.py"], srcs = ["model_options.py"],
@ -114,12 +127,7 @@ py_test(
], ],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":text_classifier_import",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
@ -128,11 +136,7 @@ py_library(
name = "text_classifier_demo_lib", name = "text_classifier_demo_lib",
srcs = ["text_classifier_demo.py"], srcs = ["text_classifier_demo.py"],
deps = [ deps = [
":dataset", ":text_classifier_import",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:quantization", "//mediapipe/model_maker/python/core/utils:quantization",
], ],
) )

View File

@ -11,3 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""MediaPipe Public Python API for Text Classifier."""
from mediapipe.model_maker.python.core import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
HParams = hyperparameters.BaseHParams
CSVParams = dataset.CSVParameters
Dataset = dataset.Dataset
AverageWordEmbeddingClassifierModelOptions = (
model_options.AverageWordEmbeddingClassifierModelOptions)
BertClassifierModelOptions = model_options.BertClassifierModelOptions
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions

View File

@ -18,12 +18,12 @@ from typing import Union
from mediapipe.model_maker.python.text.core import bert_model_options from mediapipe.model_maker.python.text.core import bert_model_options
# BERT text classifier options inherited from BertModelOptions. # BERT text classifier model options inherited from BertModelOptions.
BertClassifierOptions = bert_model_options.BertModelOptions BertClassifierModelOptions = bert_model_options.BertModelOptions
@dataclasses.dataclass @dataclasses.dataclass
class AverageWordEmbeddingClassifierOptions: class AverageWordEmbeddingClassifierModelOptions:
"""Configurable model options for an Average Word Embedding classifier. """Configurable model options for an Average Word Embedding classifier.
Attributes: Attributes:
@ -41,5 +41,5 @@ class AverageWordEmbeddingClassifierOptions:
dropout_rate: float = 0.2 dropout_rate: float = 0.2
TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierOptions, TextClassifierModelOptions = Union[AverageWordEmbeddingClassifierModelOptions,
BertClassifierOptions] BertClassifierModelOptions]

View File

@ -38,8 +38,8 @@ class AverageWordEmbeddingClassifierSpec:
# `learning_rate` is unused for the average word embedding model # `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams( hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0) epochs=10, batch_size=32, learning_rate=0)
model_options: mo.AverageWordEmbeddingClassifierOptions = ( model_options: mo.AverageWordEmbeddingClassifierModelOptions = (
mo.AverageWordEmbeddingClassifierOptions()) mo.AverageWordEmbeddingClassifierModelOptions())
name: str = 'AverageWordEmbedding' name: str = 'AverageWordEmbedding'

View File

@ -40,7 +40,7 @@ class ModelSpecTest(tf.test.TestCase):
}) })
self.assertEqual( self.assertEqual(
model_spec_obj.model_options, model_spec_obj.model_options,
classifier_model_options.BertClassifierOptions( classifier_model_options.BertClassifierModelOptions(
seq_len=128, do_fine_tuning=True, dropout_rate=0.1)) seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual( self.assertEqual(
model_spec_obj.hparams, model_spec_obj.hparams,
@ -57,7 +57,7 @@ class ModelSpecTest(tf.test.TestCase):
self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding') self.assertEqual(model_spec_obj.name, 'AverageWordEmbedding')
self.assertEqual( self.assertEqual(
model_spec_obj.model_options, model_spec_obj.model_options,
classifier_model_options.AverageWordEmbeddingClassifierOptions( classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
seq_len=256, seq_len=256,
wordvec_dim=16, wordvec_dim=16,
do_lower_case=True, do_lower_case=True,
@ -77,7 +77,7 @@ class ModelSpecTest(tf.test.TestCase):
def test_custom_bert_spec(self): def test_custom_bert_spec(self):
custom_bert_classifier_options = ( custom_bert_classifier_options = (
classifier_model_options.BertClassifierOptions( classifier_model_options.BertClassifierModelOptions(
seq_len=512, do_fine_tuning=False, dropout_rate=0.3)) seq_len=512, do_fine_tuning=False, dropout_rate=0.3))
model_spec_obj = ( model_spec_obj = (
ms.SupportedModels.MOBILEBERT_CLASSIFIER.value( ms.SupportedModels.MOBILEBERT_CLASSIFIER.value(
@ -97,7 +97,7 @@ class ModelSpecTest(tf.test.TestCase):
num_gpus=3, num_gpus=3,
tpu='tpu/address') tpu='tpu/address')
custom_average_word_embedding_model_options = ( custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingClassifierOptions( classifier_model_options.AverageWordEmbeddingClassifierModelOptions(
seq_len=512, seq_len=512,
wordvec_dim=32, wordvec_dim=32,
do_lower_case=False, do_lower_case=False,

View File

@ -51,12 +51,12 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
return return
if (isinstance(options.model_options, if (isinstance(options.model_options,
mo.AverageWordEmbeddingClassifierOptions) and mo.AverageWordEmbeddingClassifierModelOptions) and
(options.supported_model != (options.supported_model !=
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)): ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER," raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}") f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertClassifierOptions) and if (isinstance(options.model_options, mo.BertClassifierModelOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)): (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
raise ValueError( raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}") f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
@ -194,7 +194,7 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+" _DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec, def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingClassifierOptions, model_options: mo.AverageWordEmbeddingClassifierModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]): hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names) super().__init__(model_spec, hparams, label_names)
self._model_options = model_options self._model_options = model_options
@ -304,8 +304,8 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02 _INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec, def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertClassifierOptions, hparams: hp.BaseHParams, model_options: mo.BertClassifierModelOptions,
label_names: Sequence[str]): hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names) super().__init__(model_spec, hparams, label_names)
self._model_options = model_options self._model_options = model_options
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()

View File

@ -23,12 +23,8 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds from mediapipe.model_maker.python.text import text_classifier
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@ -53,31 +49,34 @@ def download_demo_data():
def run(data_dir, def run(data_dir,
export_dir=tempfile.mkdtemp(), export_dir=tempfile.mkdtemp(),
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): supported_model=(
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
"""Runs demo.""" """Runs demo."""
# Gets training data and validation data. # Gets training data and validation data.
csv_params = text_ds.CSVParameters( csv_params = text_classifier.CSVParams(
text_column='sentence', label_column='label', delimiter='\t') text_column='sentence', label_column='label', delimiter='\t')
train_data = text_ds.Dataset.from_csv( train_data = text_classifier.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'train.tsv')), filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
csv_params=csv_params) csv_params=csv_params)
validation_data = text_ds.Dataset.from_csv( validation_data = text_classifier.Dataset.from_csv(
filename=os.path.join(os.path.join(data_dir, 'dev.tsv')), filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
csv_params=csv_params) csv_params=csv_params)
quantization_config = None quantization_config = None
if supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER: if (supported_model ==
hparams = hp.BaseHParams( text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
hparams = text_classifier.HParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir) epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
# Warning: This takes extremely long to run on CPU # Warning: This takes extremely long to run on CPU
elif supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER: elif (
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
quantization_config = quantization.QuantizationConfig.for_dynamic() quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = hp.BaseHParams( hparams = text_classifier.HParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir) epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
# Fine-tunes the model. # Fine-tunes the model.
options = text_classifier_options.TextClassifierOptions( options = text_classifier.TextClassifierOptions(
supported_model=supported_model, hparams=hparams) supported_model=supported_model, hparams=hparams)
model = text_classifier.TextClassifier.create(train_data, validation_data, model = text_classifier.TextClassifier.create(train_data, validation_data,
options) options)
@ -96,9 +95,10 @@ def main(_):
export_dir = os.path.expanduser(FLAGS.export_dir) export_dir = os.path.expanduser(FLAGS.export_dir)
if FLAGS.supported_model == 'average_word_embedding': if FLAGS.supported_model == 'average_word_embedding':
supported_model = ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER supported_model = (
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)
elif FLAGS.supported_model == 'bert': elif FLAGS.supported_model == 'bert':
supported_model = ms.SupportedModels.MOBILEBERT_CLASSIFIER supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
run(data_dir, export_dir, supported_model) run(data_dir, export_dir, supported_model)

View File

@ -18,12 +18,7 @@ import os
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.text import text_classifier
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -43,18 +38,23 @@ class TextClassifierTest(tf.test.TestCase):
writer.writeheader() writer.writeheader()
for label, text in labels_and_text: for label, text in labels_and_text:
writer.writerow({'text': text, 'label': label}) writer.writerow({'text': text, 'label': label})
csv_params = dataset.CSVParameters(text_column='text', label_column='label') csv_params = text_classifier.CSVParams(
all_data = dataset.Dataset.from_csv( text_column='text', label_column='label')
all_data = text_classifier.Dataset.from_csv(
filename=csv_file, csv_params=csv_params) filename=csv_file, csv_params=csv_params)
return all_data.split(0.5) return all_data.split(0.5)
def test_create_and_train_average_word_embedding_model(self): def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions( options = (
supported_model=ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER, text_classifier.TextClassifierOptions(
hparams=hp.BaseHParams(epochs=1, batch_size=1, learning_rate=0)) supported_model=(text_classifier.SupportedModels
average_word_embedding_classifier = text_classifier.TextClassifier.create( .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
train_data, validation_data, options) hparams=text_classifier.HParams(
epochs=1, batch_size=1, learning_rate=0)))
average_word_embedding_classifier = (
text_classifier.TextClassifier.create(train_data, validation_data,
options))
_, accuracy = average_word_embedding_classifier.evaluate(validation_data) _, accuracy = average_word_embedding_classifier.evaluate(validation_data)
self.assertGreaterEqual(accuracy, 0.0) self.assertGreaterEqual(accuracy, 0.0)
@ -77,10 +77,11 @@ class TextClassifierTest(tf.test.TestCase):
def test_create_and_train_bert(self): def test_create_and_train_bert(self):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
options = text_classifier_options.TextClassifierOptions( options = text_classifier.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER, supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=mo.BertClassifierOptions(do_fine_tuning=False, seq_len=2), model_options=text_classifier.BertClassifierModelOptions(
hparams=hp.BaseHParams( do_fine_tuning=False, seq_len=2),
hparams=text_classifier.HParams(
epochs=1, epochs=1,
batch_size=1, batch_size=1,
learning_rate=3e-5, learning_rate=3e-5,
@ -94,12 +95,13 @@ class TextClassifierTest(tf.test.TestCase):
def test_label_mismatch(self): def test_label_mismatch(self):
options = ( options = (
text_classifier_options.TextClassifierOptions( text_classifier.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER)) supported_model=(
text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER)))
train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
train_data = dataset.Dataset(train_tf_dataset, 1, ['foo']) train_data = text_classifier.Dataset(train_tf_dataset, 1, ['foo'])
validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]])
validation_data = dataset.Dataset(validation_tf_dataset, 1, ['bar']) validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar'])
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
'Training data label names .* not equal to validation data label names' 'Training data label names .* not equal to validation data label names'
@ -111,9 +113,11 @@ class TextClassifierTest(tf.test.TestCase):
train_data, validation_data = self._get_data() train_data, validation_data = self._get_data()
avg_options = ( avg_options = (
text_classifier_options.TextClassifierOptions( text_classifier.TextClassifierOptions(
supported_model=ms.SupportedModels.MOBILEBERT_CLASSIFIER, supported_model=(
model_options=mo.AverageWordEmbeddingClassifierOptions())) text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER),
model_options=(
text_classifier.AverageWordEmbeddingClassifierModelOptions())))
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got'
' SupportedModels.MOBILEBERT_CLASSIFIER'): ' SupportedModels.MOBILEBERT_CLASSIFIER'):
@ -121,10 +125,10 @@ class TextClassifierTest(tf.test.TestCase):
avg_options) avg_options)
bert_options = ( bert_options = (
text_classifier_options.TextClassifierOptions( text_classifier.TextClassifierOptions(
supported_model=( supported_model=(text_classifier.SupportedModels
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER), .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
model_options=mo.BertClassifierOptions())) model_options=text_classifier.BertClassifierModelOptions()))
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' ValueError, 'Expected MOBILEBERT_CLASSIFIER, got'
' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'):