Cleans up unused function from image_classifier's Dataset.

PiperOrigin-RevId: 482703775
This commit is contained in:
MediaPipe Team 2022-10-20 23:57:21 -07:00 committed by Copybara-Service
parent 348c4e6652
commit 55ba23ce9a

View File

@ -16,7 +16,7 @@
import os import os
import random import random
from typing import List, Optional, Tuple from typing import List, Optional
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
@ -107,32 +107,3 @@ class Dataset(classification_dataset.ClassificationDataset):
all_label_size, ', '.join(label_names)) all_label_size, ', '.join(label_names))
return Dataset( return Dataset(
dataset=image_label_ds, size=all_image_size, index_by_label=label_names) dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
@classmethod
def load_tf_dataset(
cls, name: str
) -> Tuple[Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset],
Optional[classification_dataset.ClassificationDataset]]:
"""Loads data from tensorflow_datasets.
Args:
name: the registered name of the tfds.core.DatasetBuilder. Refer to the
documentation of tfds.load for more details.
Returns:
A tuple of Datasets for the train/validation/test.
Raises:
ValueError: if the input tf dataset does not have train/validation/test
labels.
"""
data, info = tfds.load(name, with_info=True)
if 'label' not in info.features:
raise ValueError('info.features need to contain \'label\' key.')
label_names = info.features['label'].names
train_data = _create_data('train', data, info, label_names)
validation_data = _create_data('validation', data, info, label_names)
test_data = _create_data('test', data, info, label_names)
return train_data, validation_data, test_data