Cleans up unused function from image_classifier's Dataset.
PiperOrigin-RevId: 482703775
This commit is contained in:
parent
348c4e6652
commit
55ba23ce9a
|
@ -16,7 +16,7 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
|
@ -107,32 +107,3 @@ class Dataset(classification_dataset.ClassificationDataset):
|
|||
all_label_size, ', '.join(label_names))
|
||||
return Dataset(
|
||||
dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
|
||||
|
||||
@classmethod
|
||||
def load_tf_dataset(
|
||||
cls, name: str
|
||||
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
|
||||
Optional[classification_dataset.ClassificationDataset],
|
||||
Optional[classification_dataset.ClassificationDataset]]:
|
||||
"""Loads data from tensorflow_datasets.
|
||||
|
||||
Args:
|
||||
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
|
||||
documentation of tfds.load for more details.
|
||||
|
||||
Returns:
|
||||
A tuple of Datasets for the train/validation/test.
|
||||
|
||||
Raises:
|
||||
ValueError: if the input tf dataset does not have train/validation/test
|
||||
labels.
|
||||
"""
|
||||
data, info = tfds.load(name, with_info=True)
|
||||
if 'label' not in info.features:
|
||||
raise ValueError('info.features need to contain \'label\' key.')
|
||||
label_names = info.features['label'].names
|
||||
|
||||
train_data = _create_data('train', data, info, label_names)
|
||||
validation_data = _create_data('validation', data, info, label_names)
|
||||
test_data = _create_data('test', data, info, label_names)
|
||||
return train_data, validation_data, test_data
|
||||
|
|
Loading…
Reference in New Issue
Block a user