Update image_classifier demo with new ImageClassifierOption changes
PiperOrigin-RevId: 489031381
This commit is contained in:
parent
512a531b9e
commit
74474d859e
|
@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str,
|
|||
data = image_classifier.Dataset.from_folder(data_dir)
|
||||
train_data, rest_data = data.split(0.8)
|
||||
validation_data, test_data = rest_data.split(0.5)
|
||||
|
||||
model_options = image_classifier.ImageClassifierOptions(
|
||||
supported_model=model_spec,
|
||||
hparams=image_classifier.HParams(export_dir=export_dir),
|
||||
)
|
||||
model = image_classifier.ImageClassifier.create(
|
||||
model_spec=model_spec,
|
||||
train_data=train_data,
|
||||
validation_data=validation_data,
|
||||
hparams=image_classifier.HParams(model_dir=export_dir))
|
||||
options=model_options)
|
||||
|
||||
_, acc = model.evaluate(test_data)
|
||||
print('Test accuracy: %f' % acc)
|
||||
|
@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str,
|
|||
raise ValueError(f'Quantization: {quantization} is not recognized')
|
||||
|
||||
model.export_model(quantization_config=quantization_config)
|
||||
model.export_labels(export_dir)
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user