Support ExBert training and option to select between AdamW and LAMB optimizers for BertClassifier
PiperOrigin-RevId: 543905014
This commit is contained in:
		
							parent
							
								
									bed624f3b6
								
							
						
					
					
						commit
						1ee55d1f1b
					
				| 
						 | 
					@ -31,11 +31,11 @@ py_library(
 | 
				
			||||||
    visibility = ["//visibility:public"],
 | 
					    visibility = ["//visibility:public"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":dataset",
 | 
					        ":dataset",
 | 
				
			||||||
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
        ":text_classifier",
 | 
					        ":text_classifier",
 | 
				
			||||||
        ":text_classifier_options",
 | 
					        ":text_classifier_options",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,12 +45,18 @@ py_library(
 | 
				
			||||||
    deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
 | 
					    deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "hyperparameters",
 | 
				
			||||||
 | 
					    srcs = ["hyperparameters.py"],
 | 
				
			||||||
 | 
					    deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
    name = "model_spec",
 | 
					    name = "model_spec",
 | 
				
			||||||
    srcs = ["model_spec.py"],
 | 
					    srcs = ["model_spec.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
					 | 
				
			||||||
        "//mediapipe/model_maker/python/core/utils:file_util",
 | 
					        "//mediapipe/model_maker/python/core/utils:file_util",
 | 
				
			||||||
        "//mediapipe/model_maker/python/text/core:bert_model_spec",
 | 
					        "//mediapipe/model_maker/python/text/core:bert_model_spec",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					@ -61,9 +67,9 @@ py_test(
 | 
				
			||||||
    srcs = ["model_spec_test.py"],
 | 
					    srcs = ["model_spec_test.py"],
 | 
				
			||||||
    tags = ["requires-net:external"],
 | 
					    tags = ["requires-net:external"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -100,9 +106,9 @@ py_library(
 | 
				
			||||||
    name = "text_classifier_options",
 | 
					    name = "text_classifier_options",
 | 
				
			||||||
    srcs = ["text_classifier_options.py"],
 | 
					    srcs = ["text_classifier_options.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -111,11 +117,11 @@ py_library(
 | 
				
			||||||
    srcs = ["text_classifier.py"],
 | 
					    srcs = ["text_classifier.py"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":dataset",
 | 
					        ":dataset",
 | 
				
			||||||
 | 
					        ":hyperparameters",
 | 
				
			||||||
        ":model_options",
 | 
					        ":model_options",
 | 
				
			||||||
        ":model_spec",
 | 
					        ":model_spec",
 | 
				
			||||||
        ":preprocessor",
 | 
					        ":preprocessor",
 | 
				
			||||||
        ":text_classifier_options",
 | 
					        ":text_classifier_options",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
					 | 
				
			||||||
        "//mediapipe/model_maker/python/core/data:dataset",
 | 
					        "//mediapipe/model_maker/python/core/data:dataset",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
					        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
				
			||||||
        "//mediapipe/model_maker/python/core/utils:metrics",
 | 
					        "//mediapipe/model_maker/python/core/utils:metrics",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,19 +13,23 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
"""MediaPipe Public Python API for Text Classifier."""
 | 
					"""MediaPipe Public Python API for Text Classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.core import hyperparameters
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import dataset
 | 
					from mediapipe.model_maker.python.text.text_classifier import dataset
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
 | 
					from mediapipe.model_maker.python.text.text_classifier import text_classifier
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 | 
					from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HParams = hyperparameters.BaseHParams
 | 
					
 | 
				
			||||||
 | 
					AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
 | 
				
			||||||
 | 
					AverageWordEmbeddingModelOptions = (
 | 
				
			||||||
 | 
					    model_options.AverageWordEmbeddingModelOptions
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					BertOptimizer = hyperparameters.BertOptimizer
 | 
				
			||||||
 | 
					BertHParams = hyperparameters.BertHParams
 | 
				
			||||||
 | 
					BertModelOptions = model_options.BertModelOptions
 | 
				
			||||||
CSVParams = dataset.CSVParameters
 | 
					CSVParams = dataset.CSVParameters
 | 
				
			||||||
Dataset = dataset.Dataset
 | 
					Dataset = dataset.Dataset
 | 
				
			||||||
AverageWordEmbeddingModelOptions = (
 | 
					 | 
				
			||||||
    model_options.AverageWordEmbeddingModelOptions)
 | 
					 | 
				
			||||||
BertModelOptions = model_options.BertModelOptions
 | 
					 | 
				
			||||||
SupportedModels = model_spec.SupportedModels
 | 
					SupportedModels = model_spec.SupportedModels
 | 
				
			||||||
TextClassifier = text_classifier.TextClassifier
 | 
					TextClassifier = text_classifier.TextClassifier
 | 
				
			||||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
 | 
					TextClassifierOptions = text_classifier_options.TextClassifierOptions
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,54 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Hyperparameters for training text classification models."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import dataclasses
 | 
				
			||||||
 | 
					import enum
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class AverageWordEmbeddingHParams(hp.BaseHParams):
 | 
				
			||||||
 | 
					  """The hyperparameters for an AverageWordEmbeddingClassifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@enum.unique
 | 
				
			||||||
 | 
					class BertOptimizer(enum.Enum):
 | 
				
			||||||
 | 
					  """Supported Optimizers for Bert Text Classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ADAMW = "adamw"
 | 
				
			||||||
 | 
					  LAMB = "lamb"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class BertHParams(hp.BaseHParams):
 | 
				
			||||||
 | 
					  """The hyperparameters for a Bert Classifier.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Attributes:
 | 
				
			||||||
 | 
					    learning_rate: Learning rate to use for gradient descent training.
 | 
				
			||||||
 | 
					    batch_size: Batch size for training.
 | 
				
			||||||
 | 
					    epochs: Number of training iterations over the dataset.
 | 
				
			||||||
 | 
					    optimizer: Optimizer to use for training. Only supported values are "adamw"
 | 
				
			||||||
 | 
					      and "lamb".
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  learning_rate: float = 3e-5
 | 
				
			||||||
 | 
					  batch_size: int = 48
 | 
				
			||||||
 | 
					  epochs: int = 2
 | 
				
			||||||
 | 
					  optimizer: BertOptimizer = BertOptimizer.ADAMW
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					HParams = Union[BertHParams, AverageWordEmbeddingHParams]
 | 
				
			||||||
| 
						 | 
					@ -17,13 +17,11 @@ import dataclasses
 | 
				
			||||||
import enum
 | 
					import enum
 | 
				
			||||||
import functools
 | 
					import functools
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import file_util
 | 
					from mediapipe.model_maker.python.core.utils import file_util
 | 
				
			||||||
from mediapipe.model_maker.python.text.core import bert_model_spec
 | 
					from mediapipe.model_maker.python.text.core import bert_model_spec
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# BERT-based text classifier spec inherited from BertModelSpec
 | 
					 | 
				
			||||||
BertClassifierSpec = bert_model_spec.BertModelSpec
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
					MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
    'text_classifier/mobilebert_tiny',
 | 
					    'text_classifier/mobilebert_tiny',
 | 
				
			||||||
| 
						 | 
					@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
    is_folder=True,
 | 
					    is_folder=True,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					EXBERT_FILES = file_util.DownloadedFiles(
 | 
				
			||||||
 | 
					    'text_classifier/exbert',
 | 
				
			||||||
 | 
					    'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
 | 
				
			||||||
 | 
					    is_folder=True,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclasses.dataclass
 | 
					@dataclasses.dataclass
 | 
				
			||||||
class AverageWordEmbeddingClassifierSpec:
 | 
					class AverageWordEmbeddingClassifierSpec:
 | 
				
			||||||
| 
						 | 
					@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # `learning_rate` is unused for the average word embedding model
 | 
					  # `learning_rate` is unused for the average word embedding model
 | 
				
			||||||
  hparams: hp.BaseHParams = hp.BaseHParams(
 | 
					  hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
 | 
				
			||||||
      epochs=10, batch_size=32, learning_rate=0)
 | 
					      epochs=10, batch_size=32, learning_rate=0
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
  model_options: mo.AverageWordEmbeddingModelOptions = (
 | 
					  model_options: mo.AverageWordEmbeddingModelOptions = (
 | 
				
			||||||
      mo.AverageWordEmbeddingModelOptions())
 | 
					      mo.AverageWordEmbeddingModelOptions())
 | 
				
			||||||
  name: str = 'AverageWordEmbedding'
 | 
					  name: str = 'AverageWordEmbedding'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
average_word_embedding_classifier_spec = functools.partial(
 | 
					average_word_embedding_classifier_spec = functools.partial(
 | 
				
			||||||
    AverageWordEmbeddingClassifierSpec)
 | 
					    AverageWordEmbeddingClassifierSpec)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclasses.dataclass
 | 
				
			||||||
 | 
					class BertClassifierSpec(bert_model_spec.BertModelSpec):
 | 
				
			||||||
 | 
					  """Specification for a Bert classifier model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Only overrides the hparams attribute since the rest of the attributes are
 | 
				
			||||||
 | 
					  inherited from the BertModelSpec.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  hparams: hp.BertHParams = hp.BertHParams()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mobilebert_classifier_spec = functools.partial(
 | 
					mobilebert_classifier_spec = functools.partial(
 | 
				
			||||||
    BertClassifierSpec,
 | 
					    BertClassifierSpec,
 | 
				
			||||||
    downloaded_files=MOBILEBERT_TINY_FILES,
 | 
					    downloaded_files=MOBILEBERT_TINY_FILES,
 | 
				
			||||||
    hparams=hp.BaseHParams(
 | 
					    hparams=hp.BertHParams(
 | 
				
			||||||
        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
 | 
					        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
    name='MobileBert',
 | 
					    name='MobileBert',
 | 
				
			||||||
    tflite_input_name={
 | 
					    tflite_input_name={
 | 
				
			||||||
        'ids': 'serving_default_input_1:0',
 | 
					        'ids': 'serving_default_input_1:0',
 | 
				
			||||||
        'mask': 'serving_default_input_3:0',
 | 
					 | 
				
			||||||
        'segment_ids': 'serving_default_input_2:0',
 | 
					        'segment_ids': 'serving_default_input_2:0',
 | 
				
			||||||
 | 
					        'mask': 'serving_default_input_3:0',
 | 
				
			||||||
 | 
					    },
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					exbert_classifier_spec = functools.partial(
 | 
				
			||||||
 | 
					    BertClassifierSpec,
 | 
				
			||||||
 | 
					    downloaded_files=EXBERT_FILES,
 | 
				
			||||||
 | 
					    hparams=hp.BertHParams(
 | 
				
			||||||
 | 
					        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    name='ExBert',
 | 
				
			||||||
 | 
					    tflite_input_name={
 | 
				
			||||||
 | 
					        'ids': 'serving_default_input_1:0',
 | 
				
			||||||
 | 
					        'segment_ids': 'serving_default_input_2:0',
 | 
				
			||||||
 | 
					        'mask': 'serving_default_input_3:0',
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
 | 
				
			||||||
  """Predefined text classifier model specs supported by Model Maker."""
 | 
					  """Predefined text classifier model specs supported by Model Maker."""
 | 
				
			||||||
  AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
 | 
					  AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
 | 
				
			||||||
  MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
 | 
					  MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
 | 
				
			||||||
 | 
					  EXBERT_CLASSIFIER = exbert_classifier_spec
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
 | 
				
			||||||
            seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
 | 
					            seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
 | 
				
			||||||
    self.assertEqual(
 | 
					    self.assertEqual(
 | 
				
			||||||
        model_spec_obj.hparams,
 | 
					        model_spec_obj.hparams,
 | 
				
			||||||
        hp.BaseHParams(
 | 
					        hp.BertHParams(
 | 
				
			||||||
            epochs=3,
 | 
					            epochs=3,
 | 
				
			||||||
            batch_size=48,
 | 
					            batch_size=48,
 | 
				
			||||||
            learning_rate=3e-5,
 | 
					            learning_rate=3e-5,
 | 
				
			||||||
            distribution_strategy='off'))
 | 
					            distribution_strategy='off',
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_predefined_average_word_embedding_spec(self):
 | 
					  def test_predefined_average_word_embedding_spec(self):
 | 
				
			||||||
    model_spec_obj = (
 | 
					    model_spec_obj = (
 | 
				
			||||||
| 
						 | 
					@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
 | 
				
			||||||
            dropout_rate=0.2))
 | 
					            dropout_rate=0.2))
 | 
				
			||||||
    self.assertEqual(
 | 
					    self.assertEqual(
 | 
				
			||||||
        model_spec_obj.hparams,
 | 
					        model_spec_obj.hparams,
 | 
				
			||||||
        hp.BaseHParams(
 | 
					        hp.AverageWordEmbeddingHParams(
 | 
				
			||||||
            epochs=10,
 | 
					            epochs=10,
 | 
				
			||||||
            batch_size=32,
 | 
					            batch_size=32,
 | 
				
			||||||
            learning_rate=0,
 | 
					            learning_rate=0,
 | 
				
			||||||
| 
						 | 
					@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
 | 
				
			||||||
                     custom_bert_classifier_options)
 | 
					                     custom_bert_classifier_options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_custom_average_word_embedding_spec(self):
 | 
					  def test_custom_average_word_embedding_spec(self):
 | 
				
			||||||
    custom_hparams = hp.BaseHParams(
 | 
					    custom_hparams = hp.AverageWordEmbeddingHParams(
 | 
				
			||||||
        learning_rate=0.4,
 | 
					        learning_rate=0.4,
 | 
				
			||||||
        batch_size=64,
 | 
					        batch_size=64,
 | 
				
			||||||
        epochs=10,
 | 
					        epochs=10,
 | 
				
			||||||
| 
						 | 
					@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
 | 
				
			||||||
        export_dir='foo/bar',
 | 
					        export_dir='foo/bar',
 | 
				
			||||||
        distribution_strategy='mirrored',
 | 
					        distribution_strategy='mirrored',
 | 
				
			||||||
        num_gpus=3,
 | 
					        num_gpus=3,
 | 
				
			||||||
        tpu='tpu/address')
 | 
					        tpu='tpu/address',
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    custom_average_word_embedding_model_options = (
 | 
					    custom_average_word_embedding_model_options = (
 | 
				
			||||||
        classifier_model_options.AverageWordEmbeddingModelOptions(
 | 
					        classifier_model_options.AverageWordEmbeddingModelOptions(
 | 
				
			||||||
            seq_len=512,
 | 
					            seq_len=512,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,15 +19,16 @@ import tempfile
 | 
				
			||||||
from typing import Any, Optional, Sequence, Tuple
 | 
					from typing import Any, Optional, Sequence, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					from tensorflow_addons import optimizers as tfa_optimizers
 | 
				
			||||||
import tensorflow_hub as hub
 | 
					import tensorflow_hub as hub
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.core.data import dataset as ds
 | 
					from mediapipe.model_maker.python.core.data import dataset as ds
 | 
				
			||||||
from mediapipe.model_maker.python.core.tasks import classifier
 | 
					from mediapipe.model_maker.python.core.tasks import classifier
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import metrics
 | 
					from mediapipe.model_maker.python.core.utils import metrics
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import model_util
 | 
					from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
					from mediapipe.model_maker.python.core.utils import quantization
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
					from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
					from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
				
			||||||
| 
						 | 
					@ -55,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
 | 
				
			||||||
       ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
 | 
					       ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
 | 
				
			||||||
    raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
 | 
					    raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
 | 
				
			||||||
                     f" got {options.supported_model}")
 | 
					                     f" got {options.supported_model}")
 | 
				
			||||||
  if (isinstance(options.model_options, mo.BertModelOptions) and
 | 
					  if isinstance(options.model_options, mo.BertModelOptions) and (
 | 
				
			||||||
      (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
 | 
					      options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
				
			||||||
 | 
					      and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
    raise ValueError(
 | 
					    raise ValueError(
 | 
				
			||||||
        f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
 | 
					        "Expected a Bert Classifier(MobileBERT or EXBERT), got "
 | 
				
			||||||
 | 
					        f"{options.supported_model}"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TextClassifier(classifier.Classifier):
 | 
					class TextClassifier(classifier.Classifier):
 | 
				
			||||||
  """API for creating and training a text classification model."""
 | 
					  """API for creating and training a text classification model."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
 | 
					  def __init__(
 | 
				
			||||||
               label_names: Sequence[str]):
 | 
					      self, model_spec: Any, label_names: Sequence[str], shuffle: bool
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
    super().__init__(
 | 
					    super().__init__(
 | 
				
			||||||
        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
					        model_spec=model_spec, label_names=label_names, shuffle=shuffle
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    self._model_spec = model_spec
 | 
					    self._model_spec = model_spec
 | 
				
			||||||
    self._hparams = hparams
 | 
					 | 
				
			||||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
					 | 
				
			||||||
    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
 | 
					    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
| 
						 | 
					@ -107,7 +112,10 @@ class TextClassifier(classifier.Classifier):
 | 
				
			||||||
    if options.hparams is None:
 | 
					    if options.hparams is None:
 | 
				
			||||||
      options.hparams = options.supported_model.value().hparams
 | 
					      options.hparams = options.supported_model.value().hparams
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
 | 
					    if (
 | 
				
			||||||
 | 
					        options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
				
			||||||
 | 
					        or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
      text_classifier = (
 | 
					      text_classifier = (
 | 
				
			||||||
          _BertClassifier.create_bert_classifier(train_data, validation_data,
 | 
					          _BertClassifier.create_bert_classifier(train_data, validation_data,
 | 
				
			||||||
                                                 options,
 | 
					                                                 options,
 | 
				
			||||||
| 
						 | 
					@ -225,11 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  _DELIM_REGEX_PATTERN = r"[^\w\']+"
 | 
					  _DELIM_REGEX_PATTERN = r"[^\w\']+"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
 | 
					  def __init__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
 | 
					      model_spec: ms.AverageWordEmbeddingClassifierSpec,
 | 
				
			||||||
      model_options: mo.AverageWordEmbeddingModelOptions,
 | 
					      model_options: mo.AverageWordEmbeddingModelOptions,
 | 
				
			||||||
               hparams: hp.BaseHParams, label_names: Sequence[str]):
 | 
					      hparams: hp.AverageWordEmbeddingHParams,
 | 
				
			||||||
    super().__init__(model_spec, hparams, label_names)
 | 
					      label_names: Sequence[str],
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
				
			||||||
    self._model_options = model_options
 | 
					    self._model_options = model_options
 | 
				
			||||||
 | 
					    self._hparams = hparams
 | 
				
			||||||
 | 
					    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
				
			||||||
    self._loss_function = "sparse_categorical_crossentropy"
 | 
					    self._loss_function = "sparse_categorical_crossentropy"
 | 
				
			||||||
    self._metric_functions = [
 | 
					    self._metric_functions = [
 | 
				
			||||||
        "accuracy",
 | 
					        "accuracy",
 | 
				
			||||||
| 
						 | 
					@ -344,10 +358,16 @@ class _BertClassifier(TextClassifier):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  _INITIALIZER_RANGE = 0.02
 | 
					  _INITIALIZER_RANGE = 0.02
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, model_spec: ms.BertClassifierSpec,
 | 
					  def __init__(
 | 
				
			||||||
               model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
 | 
					      self,
 | 
				
			||||||
               label_names: Sequence[str]):
 | 
					      model_spec: ms.BertClassifierSpec,
 | 
				
			||||||
    super().__init__(model_spec, hparams, label_names)
 | 
					      model_options: mo.BertModelOptions,
 | 
				
			||||||
 | 
					      hparams: hp.BertHParams,
 | 
				
			||||||
 | 
					      label_names: Sequence[str],
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
 | 
					    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
				
			||||||
 | 
					    self._hparams = hparams
 | 
				
			||||||
 | 
					    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
				
			||||||
    self._model_options = model_options
 | 
					    self._model_options = model_options
 | 
				
			||||||
    with self._hparams.get_strategy().scope():
 | 
					    with self._hparams.get_strategy().scope():
 | 
				
			||||||
      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
 | 
					      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
 | 
				
			||||||
| 
						 | 
					@ -480,11 +500,26 @@ class _BertClassifier(TextClassifier):
 | 
				
			||||||
          initial_learning_rate=initial_lr,
 | 
					          initial_learning_rate=initial_lr,
 | 
				
			||||||
          decay_schedule_fn=lr_schedule,
 | 
					          decay_schedule_fn=lr_schedule,
 | 
				
			||||||
          warmup_steps=warmup_steps)
 | 
					          warmup_steps=warmup_steps)
 | 
				
			||||||
 | 
					    if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
 | 
				
			||||||
      self._optimizer = tf.keras.optimizers.experimental.AdamW(
 | 
					      self._optimizer = tf.keras.optimizers.experimental.AdamW(
 | 
				
			||||||
        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
 | 
					          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
      self._optimizer.exclude_from_weight_decay(
 | 
					      self._optimizer.exclude_from_weight_decay(
 | 
				
			||||||
        var_names=["LayerNorm", "layer_norm", "bias"])
 | 
					          var_names=["LayerNorm", "layer_norm", "bias"]
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
 | 
				
			||||||
 | 
					      self._optimizer = tfa_optimizers.LAMB(
 | 
				
			||||||
 | 
					          lr_schedule,
 | 
				
			||||||
 | 
					          weight_decay_rate=0.01,
 | 
				
			||||||
 | 
					          epsilon=1e-6,
 | 
				
			||||||
 | 
					          exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
 | 
				
			||||||
 | 
					          global_clipnorm=1.0,
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "BertHParams.optimizer must be set to ADAM or "
 | 
				
			||||||
 | 
					          f"LAMB. Got {self._hparams.optimizer}."
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def _save_vocab(self, vocab_filepath: str):
 | 
					  def _save_vocab(self, vocab_filepath: str):
 | 
				
			||||||
    tf.io.gfile.copy(
 | 
					    tf.io.gfile.copy(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -66,14 +66,16 @@ def run(data_dir,
 | 
				
			||||||
  quantization_config = None
 | 
					  quantization_config = None
 | 
				
			||||||
  if (supported_model ==
 | 
					  if (supported_model ==
 | 
				
			||||||
      text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
 | 
					      text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
 | 
				
			||||||
    hparams = text_classifier.HParams(
 | 
					    hparams = text_classifier.AverageWordEmbeddingHParams(
 | 
				
			||||||
        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
 | 
					        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
  # Warning: This takes extremely long to run on CPU
 | 
					  # Warning: This takes extremely long to run on CPU
 | 
				
			||||||
  elif (
 | 
					  elif (
 | 
				
			||||||
      supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
 | 
					      supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
 | 
				
			||||||
    quantization_config = quantization.QuantizationConfig.for_dynamic()
 | 
					    quantization_config = quantization.QuantizationConfig.for_dynamic()
 | 
				
			||||||
    hparams = text_classifier.HParams(
 | 
					    hparams = text_classifier.BertHParams(
 | 
				
			||||||
        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
 | 
					        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Fine-tunes the model.
 | 
					  # Fine-tunes the model.
 | 
				
			||||||
  options = text_classifier.TextClassifierOptions(
 | 
					  options = text_classifier.TextClassifierOptions(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,7 +16,7 @@
 | 
				
			||||||
import dataclasses
 | 
					import dataclasses
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,5 +34,5 @@ class TextClassifierOptions:
 | 
				
			||||||
      architecture of the `supported_model`.
 | 
					      architecture of the `supported_model`.
 | 
				
			||||||
  """
 | 
					  """
 | 
				
			||||||
  supported_model: ms.SupportedModels
 | 
					  supported_model: ms.SupportedModels
 | 
				
			||||||
  hparams: Optional[hp.BaseHParams] = None
 | 
					  hparams: Optional[hp.HParams] = None
 | 
				
			||||||
  model_options: Optional[mo.TextClassifierModelOptions] = None
 | 
					  model_options: Optional[mo.TextClassifierModelOptions] = None
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_and_train_average_word_embedding_model(self):
 | 
					  def test_create_and_train_average_word_embedding_model(self):
 | 
				
			||||||
    train_data, validation_data = self._get_data()
 | 
					    train_data, validation_data = self._get_data()
 | 
				
			||||||
    options = (
 | 
					    options = text_classifier.TextClassifierOptions(
 | 
				
			||||||
        text_classifier.TextClassifierOptions(
 | 
					        supported_model=(
 | 
				
			||||||
            supported_model=(text_classifier.SupportedModels
 | 
					            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
 | 
				
			||||||
                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
 | 
					        ),
 | 
				
			||||||
            hparams=text_classifier.HParams(
 | 
					        hparams=text_classifier.AverageWordEmbeddingHParams(
 | 
				
			||||||
                epochs=1, batch_size=1, learning_rate=0)))
 | 
					            epochs=1, batch_size=1, learning_rate=0
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    average_word_embedding_classifier = (
 | 
					    average_word_embedding_classifier = (
 | 
				
			||||||
        text_classifier.TextClassifier.create(train_data, validation_data,
 | 
					        text_classifier.TextClassifier.create(train_data, validation_data,
 | 
				
			||||||
                                              options))
 | 
					                                              options))
 | 
				
			||||||
| 
						 | 
					@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
 | 
				
			||||||
    options = text_classifier.TextClassifierOptions(
 | 
					    options = text_classifier.TextClassifierOptions(
 | 
				
			||||||
        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
 | 
					        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
 | 
				
			||||||
        model_options=text_classifier.BertModelOptions(
 | 
					        model_options=text_classifier.BertModelOptions(
 | 
				
			||||||
            do_fine_tuning=False, seq_len=2),
 | 
					            do_fine_tuning=False, seq_len=2
 | 
				
			||||||
        hparams=text_classifier.HParams(
 | 
					        ),
 | 
				
			||||||
 | 
					        hparams=text_classifier.BertHParams(
 | 
				
			||||||
            epochs=1,
 | 
					            epochs=1,
 | 
				
			||||||
            batch_size=1,
 | 
					            batch_size=1,
 | 
				
			||||||
            learning_rate=3e-5,
 | 
					            learning_rate=3e-5,
 | 
				
			||||||
            distribution_strategy='off'))
 | 
					            distribution_strategy='off',
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    bert_classifier = text_classifier.TextClassifier.create(
 | 
					    bert_classifier = text_classifier.TextClassifier.create(
 | 
				
			||||||
        train_data, validation_data, options)
 | 
					        train_data, validation_data, options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user