Update image_classifier demo with new ImageClassifierOption changes
PiperOrigin-RevId: 489031381
This commit is contained in:
		
							parent
							
								
									512a531b9e
								
							
						
					
					
						commit
						74474d859e
					
				| 
						 | 
				
			
			@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str,
 | 
			
		|||
  data = image_classifier.Dataset.from_folder(data_dir)
 | 
			
		||||
  train_data, rest_data = data.split(0.8)
 | 
			
		||||
  validation_data, test_data = rest_data.split(0.5)
 | 
			
		||||
 | 
			
		||||
  model_options = image_classifier.ImageClassifierOptions(
 | 
			
		||||
      supported_model=model_spec,
 | 
			
		||||
      hparams=image_classifier.HParams(export_dir=export_dir),
 | 
			
		||||
  )
 | 
			
		||||
  model = image_classifier.ImageClassifier.create(
 | 
			
		||||
      model_spec=model_spec,
 | 
			
		||||
      train_data=train_data,
 | 
			
		||||
      validation_data=validation_data,
 | 
			
		||||
      hparams=image_classifier.HParams(model_dir=export_dir))
 | 
			
		||||
      options=model_options)
 | 
			
		||||
 | 
			
		||||
  _, acc = model.evaluate(test_data)
 | 
			
		||||
  print('Test accuracy: %f' % acc)
 | 
			
		||||
| 
						 | 
				
			
			@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str,
 | 
			
		|||
    raise ValueError(f'Quantization: {quantization} is not recognized')
 | 
			
		||||
 | 
			
		||||
  model.export_model(quantization_config=quantization_config)
 | 
			
		||||
  model.export_labels(export_dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(_) -> None:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user