Improve image classifier model maker documentation and replace the legacy code.
PiperOrigin-RevId: 481979922
This commit is contained in:
parent
51879ae81a
commit
f2821d840d
|
@ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
|
||||||
|
|
||||||
def _load_image(path: str) -> tf.Tensor:
|
def _load_image(path: str) -> tf.Tensor:
|
||||||
"""Loads image."""
|
"""Loads a jpeg/png image and returns an image tensor."""
|
||||||
image_raw = tf.io.read_file(path)
|
image_raw = tf.io.read_file(path)
|
||||||
image_tensor = tf.cond(
|
image_tensor = tf.cond(
|
||||||
tf.image.is_jpeg(image_raw),
|
tf.io.is_jpeg(image_raw),
|
||||||
lambda: tf.image.decode_jpeg(image_raw, channels=3),
|
lambda: tf.io.decode_jpeg(image_raw, channels=3),
|
||||||
lambda: tf.image.decode_png(image_raw, channels=3))
|
lambda: tf.io.decode_png(image_raw, channels=3))
|
||||||
return image_tensor
|
return image_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dirname: Name of the directory containing the data files.
|
dirname: Name of the directory containing the data files.
|
||||||
shuffle: boolean, if shuffle, random shuffle data.
|
shuffle: boolean, if true, random shuffle data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset containing images and labels and other related info.
|
Dataset containing images and labels and other related info.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the input data directory is empty.
|
ValueError: if the input data directory is empty.
|
||||||
"""
|
"""
|
||||||
|
@ -94,20 +93,20 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
|
||||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||||
|
|
||||||
autotune = tf.data.AUTOTUNE
|
image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
|
||||||
image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
|
|
||||||
|
|
||||||
# Loads label.
|
# Load label
|
||||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
tf.cast(all_image_labels, tf.int64))
|
tf.cast(all_image_labels, tf.int64))
|
||||||
|
|
||||||
# Creates a dataset if (image, label) pairs.
|
# Create a dataset if (image, label) pairs
|
||||||
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||||
|
|
||||||
tf.compat.v1.logging.info(
|
tf.compat.v1.logging.info(
|
||||||
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
|
||||||
all_label_size, ', '.join(label_names))
|
all_label_size, ', '.join(label_names))
|
||||||
return Dataset(image_label_ds, all_image_size, label_names)
|
return Dataset(
|
||||||
|
dataset=image_label_ds, size=all_image_size, index_to_label=label_names)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_tf_dataset(
|
def load_tf_dataset(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user