Update image_classifier demo with new ImageClassifierOption changes
PiperOrigin-RevId: 489031381
This commit is contained in:
parent
512a531b9e
commit
74474d859e
|
@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str,
|
||||||
data = image_classifier.Dataset.from_folder(data_dir)
|
data = image_classifier.Dataset.from_folder(data_dir)
|
||||||
train_data, rest_data = data.split(0.8)
|
train_data, rest_data = data.split(0.8)
|
||||||
validation_data, test_data = rest_data.split(0.5)
|
validation_data, test_data = rest_data.split(0.5)
|
||||||
|
model_options = image_classifier.ImageClassifierOptions(
|
||||||
|
supported_model=model_spec,
|
||||||
|
hparams=image_classifier.HParams(export_dir=export_dir),
|
||||||
|
)
|
||||||
model = image_classifier.ImageClassifier.create(
|
model = image_classifier.ImageClassifier.create(
|
||||||
model_spec=model_spec,
|
|
||||||
train_data=train_data,
|
train_data=train_data,
|
||||||
validation_data=validation_data,
|
validation_data=validation_data,
|
||||||
hparams=image_classifier.HParams(model_dir=export_dir))
|
options=model_options)
|
||||||
|
|
||||||
_, acc = model.evaluate(test_data)
|
_, acc = model.evaluate(test_data)
|
||||||
print('Test accuracy: %f' % acc)
|
print('Test accuracy: %f' % acc)
|
||||||
|
@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str,
|
||||||
raise ValueError(f'Quantization: {quantization} is not recognized')
|
raise ValueError(f'Quantization: {quantization} is not recognized')
|
||||||
|
|
||||||
model.export_model(quantization_config=quantization_config)
|
model.export_model(quantization_config=quantization_config)
|
||||||
model.export_labels(export_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def main(_) -> None:
|
def main(_) -> None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user