Internal change

PiperOrigin-RevId: 482259130
This commit is contained in:
MediaPipe Team 2022-10-19 11:32:42 -07:00 committed by Copybara-Service
parent c260074abb
commit 3d588bae8b
3 changed files with 30 additions and 12 deletions

View File

@ -66,8 +66,13 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite1',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
testcase_name='efficientnet_lite2',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite4',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
)

View File

@ -53,11 +53,17 @@ efficientnet_lite0_spec = functools.partial(
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
name='efficientnet_lite0')
efficientnet_lite1_spec = functools.partial(
efficientnet_lite2_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
input_image_shape=[240, 240],
name='efficientnet_lite1')
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
input_image_shape=[260, 260],
name='efficientnet_lite2')
efficientnet_lite4_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
input_image_shape=[300, 300],
name='efficientnet_lite4')
# TODO: Document the exposed models.
@ -66,7 +72,8 @@ class SupportedModels(enum.Enum):
"""Image classifier model supported by model maker."""
MOBILENET_V2 = mobilenet_v2_spec
EFFICIENTNET_LITE0 = efficientnet_lite0_spec
EFFICIENTNET_LITE1 = efficientnet_lite1_spec
EFFICIENTNET_LITE2 = efficientnet_lite2_spec
EFFICIENTNET_LITE4 = efficientnet_lite4_spec
@classmethod
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':

View File

@ -37,11 +37,17 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
expected_name='efficientnet_lite0',
expected_input_image_shape=[224, 224]),
dict(
testcase_name='efficientnet_lite1_spec_test',
model_spec=ms.efficientnet_lite1_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
expected_name='efficientnet_lite1',
expected_input_image_shape=[240, 240]),
testcase_name='efficientnet_lite2_spec_test',
model_spec=ms.efficientnet_lite2_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
expected_name='efficientnet_lite2',
expected_input_image_shape=[260, 260]),
dict(
testcase_name='efficientnet_lite4_spec_test',
model_spec=ms.efficientnet_lite4_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
expected_name='efficientnet_lite4',
expected_input_image_shape=[300, 300]),
)
def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec],
expected_uri: str, expected_name: str,