Remove duplicate and non-public api for model_maker

PiperOrigin-RevId: 497251246
This commit is contained in:
MediaPipe Team 2022-12-22 15:29:18 -08:00 committed by Copybara-Service
parent 36f054dfbe
commit 5a71b551e5
8 changed files with 52 additions and 14 deletions

View File

@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.text import text_classifier from mediapipe.model_maker.python.text import text_classifier
# Remove duplicated and non-public API
del python

View File

@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions TextClassifierOptions = text_classifier_options.TextClassifierOptions
# Remove duplicated and non-public API
del hyperparameters
del dataset
del model_options
del model_spec
del preprocessor # pylint: disable=undefined-variable
del text_classifier
del text_classifier_options

View File

@ -146,6 +146,8 @@ py_test(
tags = ["notsan"], tags = ["notsan"],
deps = [ deps = [
":gesture_recognizer_import", ":gesture_recognizer_import",
":hyperparameters",
":model_options",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],

View File

@ -25,3 +25,12 @@ HParams = hyperparameters.HParams
Dataset = dataset.Dataset Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions
# Remove duplicated and non-public API
del constants # pylint: disable=undefined-variable
del dataset
del gesture_recognizer
del gesture_recognizer_options
del hyperparameters
del metadata_writer # pylint: disable=undefined-variable
del model_options

View File

@ -23,6 +23,8 @@ import tensorflow as tf
from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.core.utils import test_util
from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata'
@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase):
self._train_data, self._validation_data = all_data.split(0.9) self._train_data, self._validation_data = all_data.split(0.9)
def test_gesture_recognizer_model(self): def test_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions() mo = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams( hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2) export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams) model_options=mo, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._validation_data, validation_data=self._validation_data,
@ -64,11 +66,11 @@ class GestureRecognizerTest(tf.test.TestCase):
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
def test_gesture_recognizer_model_layer_widths(self, mock_dense): def test_gesture_recognizer_model_layer_widths(self, mock_dense):
layer_widths = [64, 32] layer_widths = [64, 32]
model_options = gesture_recognizer.ModelOptions(layer_widths=layer_widths) mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
hparams = gesture_recognizer.HParams( hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2) export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams) model_options=mo, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._validation_data, validation_data=self._validation_data,
@ -87,11 +89,11 @@ class GestureRecognizerTest(tf.test.TestCase):
self._test_accuracy(model) self._test_accuracy(model)
def test_export_gesture_recognizer_model(self): def test_export_gesture_recognizer_model(self):
model_options = gesture_recognizer.ModelOptions() mo = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams( hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2) export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams) model_options=mo, hparams=hparams)
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(
train_data=self._train_data, train_data=self._train_data,
validation_data=self._validation_data, validation_data=self._validation_data,
@ -128,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase):
self.assertGreater(accuracy, threshold) self.assertGreater(accuracy, threshold)
@unittest_mock.patch.object( @unittest_mock.patch.object(
gesture_recognizer.hyperparameters, hyperparameters,
'HParams', 'HParams',
autospec=True, autospec=True,
return_value=gesture_recognizer.HParams(epochs=1)) return_value=gesture_recognizer.HParams(epochs=1))
@unittest_mock.patch.object( @unittest_mock.patch.object(
gesture_recognizer.model_options, model_options,
'GestureRecognizerModelOptions', 'GestureRecognizerModelOptions',
autospec=True, autospec=True,
return_value=gesture_recognizer.ModelOptions()) return_value=gesture_recognizer.ModelOptions())
@ -148,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase):
mock_model_options.assert_called_once() mock_model_options.assert_called_once()
def test_continual_training_by_loading_checkpoint(self): def test_continual_training_by_loading_checkpoint(self):
model_options = gesture_recognizer.ModelOptions() mo = gesture_recognizer.ModelOptions()
hparams = gesture_recognizer.HParams( hparams = gesture_recognizer.HParams(
export_dir=tempfile.mkdtemp(), epochs=2) export_dir=tempfile.mkdtemp(), epochs=2)
gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
model_options=model_options, hparams=hparams) model_options=mo, hparams=hparams)
mock_stdout = io.StringIO() mock_stdout = io.StringIO()
with mock.patch('sys.stdout', mock_stdout): with mock.patch('sys.stdout', mock_stdout):
model = gesture_recognizer.GestureRecognizer.create( model = gesture_recognizer.GestureRecognizer.create(

View File

@ -121,7 +121,9 @@ py_library(
srcs = ["image_classifier_test.py"], srcs = ["image_classifier_test.py"],
data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"], data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
deps = [ deps = [
":hyperparameters",
":image_classifier_import", ":image_classifier_import",
":model_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions
ModelSpec = model_spec.ModelSpec ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels SupportedModels = model_spec.SupportedModels
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions ImageClassifierOptions = image_classifier_options.ImageClassifierOptions
# Remove duplicated and non-public API
del dataset
del hyperparameters
del image_classifier
del image_classifier_options
del model_options
del model_spec
del train_image_classifier_lib # pylint: disable=undefined-variable

View File

@ -24,6 +24,8 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision import image_classifier from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import model_options
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -159,15 +161,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreaterEqual(accuracy, threshold) self.assertGreaterEqual(accuracy, threshold)
@unittest_mock.patch.object( @unittest_mock.patch.object(
image_classifier.hyperparameters, hyperparameters,
'HParams', 'HParams',
autospec=True, autospec=True,
return_value=image_classifier.HParams(epochs=1)) return_value=hyperparameters.HParams(epochs=1))
@unittest_mock.patch.object( @unittest_mock.patch.object(
image_classifier.model_options, model_options,
'ImageClassifierModelOptions', 'ImageClassifierModelOptions',
autospec=True, autospec=True,
return_value=image_classifier.ModelOptions()) return_value=model_options.ImageClassifierModelOptions())
def test_create_hparams_and_model_options_if_none_in_image_classifier_options( def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
self, mock_hparams, mock_model_options): self, mock_hparams, mock_model_options):
options = image_classifier.ImageClassifierOptions( options = image_classifier.ImageClassifierOptions(