Improve image classifier model maker documentation and replace the legacy code.
PiperOrigin-RevId: 481979922
This commit is contained in:
		
							parent
							
								
									51879ae81a
								
							
						
					
					
						commit
						f2821d840d
					
				|  | @ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset | |||
| 
 | ||||
| 
 | ||||
| def _load_image(path: str) -> tf.Tensor: | ||||
|   """Loads image.""" | ||||
|   """Loads a jpeg/png image and returns an image tensor.""" | ||||
|   image_raw = tf.io.read_file(path) | ||||
|   image_tensor = tf.cond( | ||||
|       tf.image.is_jpeg(image_raw), | ||||
|       lambda: tf.image.decode_jpeg(image_raw, channels=3), | ||||
|       lambda: tf.image.decode_png(image_raw, channels=3)) | ||||
|       tf.io.is_jpeg(image_raw), | ||||
|       lambda: tf.io.decode_jpeg(image_raw, channels=3), | ||||
|       lambda: tf.io.decode_png(image_raw, channels=3)) | ||||
|   return image_tensor | ||||
| 
 | ||||
| 
 | ||||
|  | @ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset): | |||
| 
 | ||||
|     Args: | ||||
|       dirname: Name of the directory containing the data files. | ||||
|       shuffle: boolean, if shuffle, random shuffle data. | ||||
|       shuffle: boolean, if true, random shuffle data. | ||||
| 
 | ||||
|     Returns: | ||||
|       Dataset containing images and labels and other related info. | ||||
| 
 | ||||
|     Raises: | ||||
|       ValueError: if the input data directory is empty. | ||||
|     """ | ||||
|  | @ -94,20 +93,20 @@ class Dataset(classification_dataset.ClassificationDataset): | |||
| 
 | ||||
|     path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) | ||||
| 
 | ||||
|     autotune = tf.data.AUTOTUNE | ||||
|     image_ds = path_ds.map(_load_image, num_parallel_calls=autotune) | ||||
|     image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) | ||||
| 
 | ||||
|     # Loads label. | ||||
|     # Load label | ||||
|     label_ds = tf.data.Dataset.from_tensor_slices( | ||||
|         tf.cast(all_image_labels, tf.int64)) | ||||
| 
 | ||||
|     # Creates  a dataset if (image, label) pairs. | ||||
|     # Create a dataset if (image, label) pairs | ||||
|     image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) | ||||
| 
 | ||||
|     tf.compat.v1.logging.info( | ||||
|         'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, | ||||
|         all_label_size, ', '.join(label_names)) | ||||
|     return Dataset(image_label_ds, all_image_size, label_names) | ||||
|     return Dataset( | ||||
|         dataset=image_label_ds, size=all_image_size, index_to_label=label_names) | ||||
| 
 | ||||
|   @classmethod | ||||
|   def load_tf_dataset( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user