Remove unnecessary architectures from image classifier ModelSpec
PiperOrigin-RevId: 481974529
This commit is contained in:
parent
bc47589c9b
commit
51879ae81a
|
@ -60,11 +60,6 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
|
model_spec=image_classifier.SupportedModels.MOBILENET_V2,
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
dict(
|
|
||||||
testcase_name='resnet_50',
|
|
||||||
model_spec=image_classifier.SupportedModels.RESNET_50,
|
|
||||||
hparams=image_classifier.HParams(
|
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite0',
|
testcase_name='efficientnet_lite0',
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
|
||||||
|
@ -75,21 +70,6 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
|
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
|
||||||
hparams=image_classifier.HParams(
|
hparams=image_classifier.HParams(
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
train_epochs=1, batch_size=1, shuffle=True)),
|
||||||
dict(
|
|
||||||
testcase_name='efficientnet_lite2',
|
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
|
|
||||||
hparams=image_classifier.HParams(
|
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
|
||||||
dict(
|
|
||||||
testcase_name='efficientnet_lite3',
|
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3,
|
|
||||||
hparams=image_classifier.HParams(
|
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
|
||||||
dict(
|
|
||||||
testcase_name='efficientnet_lite4',
|
|
||||||
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
|
|
||||||
hparams=image_classifier.HParams(
|
|
||||||
train_epochs=1, batch_size=1, shuffle=True)),
|
|
||||||
)
|
)
|
||||||
def test_create_and_train_model(self,
|
def test_create_and_train_model(self,
|
||||||
model_spec: image_classifier.SupportedModels,
|
model_spec: image_classifier.SupportedModels,
|
||||||
|
|
|
@ -48,11 +48,6 @@ mobilenet_v2_spec = functools.partial(
|
||||||
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
|
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
|
||||||
name='mobilenet_v2')
|
name='mobilenet_v2')
|
||||||
|
|
||||||
resnet_50_spec = functools.partial(
|
|
||||||
ModelSpec,
|
|
||||||
uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
|
|
||||||
name='resnet_50')
|
|
||||||
|
|
||||||
efficientnet_lite0_spec = functools.partial(
|
efficientnet_lite0_spec = functools.partial(
|
||||||
ModelSpec,
|
ModelSpec,
|
||||||
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
|
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
|
||||||
|
@ -64,36 +59,14 @@ efficientnet_lite1_spec = functools.partial(
|
||||||
input_image_shape=[240, 240],
|
input_image_shape=[240, 240],
|
||||||
name='efficientnet_lite1')
|
name='efficientnet_lite1')
|
||||||
|
|
||||||
efficientnet_lite2_spec = functools.partial(
|
|
||||||
ModelSpec,
|
|
||||||
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
|
|
||||||
input_image_shape=[260, 260],
|
|
||||||
name='efficientnet_lite2')
|
|
||||||
|
|
||||||
efficientnet_lite3_spec = functools.partial(
|
|
||||||
ModelSpec,
|
|
||||||
uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
|
|
||||||
input_image_shape=[280, 280],
|
|
||||||
name='efficientnet_lite3')
|
|
||||||
|
|
||||||
efficientnet_lite4_spec = functools.partial(
|
|
||||||
ModelSpec,
|
|
||||||
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
|
|
||||||
input_image_shape=[300, 300],
|
|
||||||
name='efficientnet_lite4')
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Document the exposed models.
|
# TODO: Document the exposed models.
|
||||||
@enum.unique
|
@enum.unique
|
||||||
class SupportedModels(enum.Enum):
|
class SupportedModels(enum.Enum):
|
||||||
"""Image classifier model supported by model maker."""
|
"""Image classifier model supported by model maker."""
|
||||||
MOBILENET_V2 = mobilenet_v2_spec
|
MOBILENET_V2 = mobilenet_v2_spec
|
||||||
RESNET_50 = resnet_50_spec
|
|
||||||
EFFICIENTNET_LITE0 = efficientnet_lite0_spec
|
EFFICIENTNET_LITE0 = efficientnet_lite0_spec
|
||||||
EFFICIENTNET_LITE1 = efficientnet_lite1_spec
|
EFFICIENTNET_LITE1 = efficientnet_lite1_spec
|
||||||
EFFICIENTNET_LITE2 = efficientnet_lite2_spec
|
|
||||||
EFFICIENTNET_LITE3 = efficientnet_lite3_spec
|
|
||||||
EFFICIENTNET_LITE4 = efficientnet_lite4_spec
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
|
||||||
|
|
|
@ -30,12 +30,6 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
|
expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
|
||||||
expected_name='mobilenet_v2',
|
expected_name='mobilenet_v2',
|
||||||
expected_input_image_shape=[224, 224]),
|
expected_input_image_shape=[224, 224]),
|
||||||
dict(
|
|
||||||
testcase_name='resnet_50_spec_test',
|
|
||||||
model_spec=ms.resnet_50_spec,
|
|
||||||
expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
|
|
||||||
expected_name='resnet_50',
|
|
||||||
expected_input_image_shape=[224, 224]),
|
|
||||||
dict(
|
dict(
|
||||||
testcase_name='efficientnet_lite0_spec_test',
|
testcase_name='efficientnet_lite0_spec_test',
|
||||||
model_spec=ms.efficientnet_lite0_spec,
|
model_spec=ms.efficientnet_lite0_spec,
|
||||||
|
@ -48,24 +42,6 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
|
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
|
||||||
expected_name='efficientnet_lite1',
|
expected_name='efficientnet_lite1',
|
||||||
expected_input_image_shape=[240, 240]),
|
expected_input_image_shape=[240, 240]),
|
||||||
dict(
|
|
||||||
testcase_name='efficientnet_lite2_spec_test',
|
|
||||||
model_spec=ms.efficientnet_lite2_spec,
|
|
||||||
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
|
|
||||||
expected_name='efficientnet_lite2',
|
|
||||||
expected_input_image_shape=[260, 260]),
|
|
||||||
dict(
|
|
||||||
testcase_name='efficientnet_lite3_spec_test',
|
|
||||||
model_spec=ms.efficientnet_lite3_spec,
|
|
||||||
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
|
|
||||||
expected_name='efficientnet_lite3',
|
|
||||||
expected_input_image_shape=[280, 280]),
|
|
||||||
dict(
|
|
||||||
testcase_name='efficientnet_lite4_spec_test',
|
|
||||||
model_spec=ms.efficientnet_lite4_spec,
|
|
||||||
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
|
|
||||||
expected_name='efficientnet_lite4',
|
|
||||||
expected_input_image_shape=[300, 300]),
|
|
||||||
)
|
)
|
||||||
def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec],
|
def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec],
|
||||||
expected_uri: str, expected_name: str,
|
expected_uri: str, expected_name: str,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user