Improve image classifier model maker documentation and replace the legacy code.
PiperOrigin-RevId: 481979922
This commit is contained in:
parent
51879ae81a
commit
f2821d840d
|
@ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset
|
|||
|
||||
|
||||
def _load_image(path: str) -> tf.Tensor:
|
||||
"""Loads image."""
|
||||
"""Loads a jpeg/png image and returns an image tensor."""
|
||||
image_raw = tf.io.read_file(path)
|
||||
image_tensor = tf.cond(
|
||||
tf.image.is_jpeg(image_raw),
|
||||
lambda: tf.image.decode_jpeg(image_raw, channels=3),
|
||||
lambda: tf.image.decode_png(image_raw, channels=3))
|
||||
tf.io.is_jpeg(image_raw),
|
||||
lambda: tf.io.decode_jpeg(image_raw, channels=3),
|
||||
lambda: tf.io.decode_png(image_raw, channels=3))
|
||||
return image_tensor
|
||||
|
||||
|
||||
|
@ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
|
||||
Args:
|
||||
dirname: Name of the directory containing the data files.
|
||||
shuffle: boolean, if shuffle, random shuffle data.
|
||||
shuffle: boolean, if true, random shuffle data.
|
||||
|
||||
Returns:
|
||||
Dataset containing images and labels and other related info.
|
||||
|
||||
Raises:
|
||||
ValueError: if the input data directory is empty.
|
||||
"""
|
||||
|
@ -94,20 +93,20 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
|
||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||
|
||||
autotune = tf.data.AUTOTUNE
|
||||
image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
|
||||
image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
|
||||
# Loads label.
|
||||
# Load label
|
||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||
tf.cast(all_image_labels, tf.int64))
|
||||
|
||||
# Creates a dataset if (image, label) pairs.
|
||||
# Create a dataset if (image, label) pairs
|
||||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||
|
||||
tf.compat.v1.logging.info(
|
||||
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||
all_label_size, ', '.join(label_names))
|
||||
return Dataset(image_label_ds, all_image_size, label_names)
|
||||
return Dataset(
|
||||
dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
|
||||
|
||||
@classmethod
|
||||
def load_tf_dataset(
|
||||
|
|
Loading…
Reference in New Issue
Block a user