Update image_classifier demo with new ImageClassifierOption changes

PiperOrigin-RevId: 489031381
This commit is contained in:
MediaPipe Team 2022-11-16 13:58:21 -08:00 committed by Copybara-Service
parent 512a531b9e
commit 74474d859e

View File

@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str,
data = image_classifier.Dataset.from_folder(data_dir) data = image_classifier.Dataset.from_folder(data_dir)
train_data, rest_data = data.split(0.8) train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5) validation_data, test_data = rest_data.split(0.5)
model_options = image_classifier.ImageClassifierOptions(
supported_model=model_spec,
hparams=image_classifier.HParams(export_dir=export_dir),
)
model = image_classifier.ImageClassifier.create( model = image_classifier.ImageClassifier.create(
model_spec=model_spec,
train_data=train_data, train_data=train_data,
validation_data=validation_data, validation_data=validation_data,
hparams=image_classifier.HParams(model_dir=export_dir)) options=model_options)
_, acc = model.evaluate(test_data) _, acc = model.evaluate(test_data)
print('Test accuracy: %f' % acc) print('Test accuracy: %f' % acc)
@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str,
raise ValueError(f'Quantization: {quantization} is not recognized') raise ValueError(f'Quantization: {quantization} is not recognized')
model.export_model(quantization_config=quantization_config) model.export_model(quantization_config=quantization_config)
model.export_labels(export_dir)
def main(_) -> None: def main(_) -> None: