From 74474d859e0891fc97b4038b7b8ecb9420c4b522 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 13:58:21 -0800 Subject: [PATCH] Update image_classifier demo with new ImageClassifierOption changes PiperOrigin-RevId: 489031381 --- .../vision/image_classifier/image_classifier_demo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py index 5832ea53a..f382e28aa 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str, data = image_classifier.Dataset.from_folder(data_dir) train_data, rest_data = data.split(0.8) validation_data, test_data = rest_data.split(0.5) - + model_options = image_classifier.ImageClassifierOptions( + supported_model=model_spec, + hparams=image_classifier.HParams(export_dir=export_dir), + ) model = image_classifier.ImageClassifier.create( - model_spec=model_spec, train_data=train_data, validation_data=validation_data, - hparams=image_classifier.HParams(model_dir=export_dir)) + options=model_options) _, acc = model.evaluate(test_data) print('Test accuracy: %f' % acc) @@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str, raise ValueError(f'Quantization: {quantization} is not recognized') model.export_model(quantization_config=quantization_config) - model.export_labels(export_dir) def main(_) -> None: