diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 3258a18a6..8ed6de7ad 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -66,8 +66,13 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): hparams=image_classifier.HParams( train_epochs=1, batch_size=1, shuffle=True)), dict( - testcase_name='efficientnet_lite1', - model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1, + testcase_name='efficientnet_lite2', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite4', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4, hparams=image_classifier.HParams( train_epochs=1, batch_size=1, shuffle=True)), ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py index a38b77f86..ef44f86e6 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py @@ -53,11 +53,17 @@ efficientnet_lite0_spec = functools.partial( uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', name='efficientnet_lite0') -efficientnet_lite1_spec = functools.partial( +efficientnet_lite2_spec = functools.partial( ModelSpec, - uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2', - input_image_shape=[240, 240], - name='efficientnet_lite1') + uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', + input_image_shape=[260, 260], + name='efficientnet_lite2') + +efficientnet_lite4_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', + input_image_shape=[300, 300], + name='efficientnet_lite4') # TODO: Document the exposed models. @@ -66,7 +72,8 @@ class SupportedModels(enum.Enum): """Image classifier model supported by model maker.""" MOBILENET_V2 = mobilenet_v2_spec EFFICIENTNET_LITE0 = efficientnet_lite0_spec - EFFICIENTNET_LITE1 = efficientnet_lite1_spec + EFFICIENTNET_LITE2 = efficientnet_lite2_spec + EFFICIENTNET_LITE4 = efficientnet_lite4_spec @classmethod def get(cls, spec: 'SupportedModels') -> 'ModelSpec': diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py index cf6aa8f5b..63f360ab9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py @@ -37,11 +37,17 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase): expected_name='efficientnet_lite0', expected_input_image_shape=[224, 224]), dict( - testcase_name='efficientnet_lite1_spec_test', - model_spec=ms.efficientnet_lite1_spec, - expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2', - expected_name='efficientnet_lite1', - expected_input_image_shape=[240, 240]), + testcase_name='efficientnet_lite2_spec_test', + model_spec=ms.efficientnet_lite2_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', + expected_name='efficientnet_lite2', + expected_input_image_shape=[260, 260]), + dict( + testcase_name='efficientnet_lite4_spec_test', + model_spec=ms.efficientnet_lite4_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', + expected_name='efficientnet_lite4', + expected_input_image_shape=[300, 300]), ) def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec], expected_uri: str, expected_name: str,