Improve image classifier model maker documentation and replace the legacy code.

PiperOrigin-RevId: 481979922
This commit is contained in:
MediaPipe Team 2022-10-18 11:45:36 -07:00 committed by Copybara-Service
parent 51879ae81a
commit f2821d840d

View File

@ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset
def _load_image(path: str) -> tf.Tensor:
"""Loads image."""
"""Loads a jpeg/png image and returns an image tensor."""
image_raw = tf.io.read_file(path)
image_tensor = tf.cond(
tf.image.is_jpeg(image_raw),
lambda: tf.image.decode_jpeg(image_raw, channels=3),
lambda: tf.image.decode_png(image_raw, channels=3))
tf.io.is_jpeg(image_raw),
lambda: tf.io.decode_jpeg(image_raw, channels=3),
lambda: tf.io.decode_png(image_raw, channels=3))
return image_tensor
@ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset):
Args:
dirname: Name of the directory containing the data files.
shuffle: boolean, if shuffle, random shuffle data.
shuffle: boolean, if true, random shuffle data.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
@ -94,20 +93,20 @@ class Dataset(classification_dataset.ClassificationDataset):
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
autotune = tf.data.AUTOTUNE
image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
# Loads label.
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64))
# Creates a dataset if (image, label) pairs.
# Create a dataset if (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
tf.compat.v1.logging.info(
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
all_label_size, ', '.join(label_names))
return Dataset(image_label_ds, all_image_size, label_names)
return Dataset(
dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
@classmethod
def load_tf_dataset(