Cleans up unused function from image_classifier's Dataset.
PiperOrigin-RevId: 482703775
This commit is contained in:
parent
348c4e6652
commit
55ba23ce9a
|
@ -16,7 +16,7 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
|
@ -107,32 +107,3 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
all_label_size, ', '.join(label_names))
|
all_label_size, ', '.join(label_names))
|
||||||
return Dataset(
|
return Dataset(
|
||||||
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
|
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_tf_dataset(
|
|
||||||
cls, name: str
|
|
||||||
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
|
|
||||||
Optional[classification_dataset.ClassificationDataset],
|
|
||||||
Optional[classification_dataset.ClassificationDataset]]:
|
|
||||||
"""Loads data from tensorflow_datasets.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
|
|
||||||
documentation of tfds.load for more details.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of Datasets for the train/validation/test.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if the input tf dataset does not have train/validation/test
|
|
||||||
labels.
|
|
||||||
"""
|
|
||||||
data, info = tfds.load(name, with_info=True)
|
|
||||||
if 'label' not in info.features:
|
|
||||||
raise ValueError('info.features need to contain \'label\' key.')
|
|
||||||
label_names = info.features['label'].names
|
|
||||||
|
|
||||||
train_data = _create_data('train', data, info, label_names)
|
|
||||||
validation_data = _create_data('validation', data, info, label_names)
|
|
||||||
test_data = _create_data('test', data, info, label_names)
|
|
||||||
return train_data, validation_data, test_data
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user