Fix preprocess Callable typing
PiperOrigin-RevId: 515818356
commit c94de4032d
parent 3e8fd58400
@@ -18,7 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
-from typing import Callable, Optional, Tuple, TypeVar
+from typing import Any, Callable, Optional, Tuple, TypeVar
 
 # Dependency imports
 import tensorflow as tf
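The added Any import supports the widened annotations below. As a minimal sketch (not part of the commit; all names are illustrative), a preprocess callable returns transformed tensors rather than a bool, which is why Callable[..., Any] is the accurate type:

from typing import Any, Callable, Optional, Tuple

import tensorflow as tf

# Matches the new annotation used throughout this commit.
PreprocessFn = Optional[Callable[..., Any]]


def augment(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  # Returns an (image, label) tuple of tensors, not a bool.
  return tf.image.random_flip_left_right(image), label


# Accepted under Callable[..., Any]; a type checker would flag this assignment
# under the old Callable[..., bool] annotation.
preprocess: PreprocessFn = augment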
@@ -66,12 +66,14 @@ class Dataset(object):
     """
     return self._size
 
-  def gen_tf_dataset(self,
-                     batch_size: int = 1,
-                     is_training: bool = False,
-                     shuffle: bool = False,
-                     preprocess: Optional[Callable[..., bool]] = None,
-                     drop_remainder: bool = False) -> tf.data.Dataset:
+  def gen_tf_dataset(
+      self,
+      batch_size: int = 1,
+      is_training: bool = False,
+      shuffle: bool = False,
+      preprocess: Optional[Callable[..., Any]] = None,
+      drop_remainder: bool = False,
+  ) -> tf.data.Dataset:
     """Generates a batched tf.data.Dataset for training/evaluation.
 
     Args:
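A hedged usage sketch for the reformatted gen_tf_dataset signature above: only the keyword names come from the diff; the dataset instance, image size, and normalization are assumptions.

import tensorflow as tf


def resize_and_normalize(image: tf.Tensor, label: tf.Tensor):
  # Hypothetical preprocess callable: rescales pixels and returns tensors,
  # which is the kind of return value Callable[..., Any] now admits.
  image = tf.image.resize(image, [224, 224]) / 255.0
  return image, label


# Assuming `dataset` is an instance of the Dataset class shown above:
# train_ds = dataset.gen_tf_dataset(
#     batch_size=32,
#     is_training=True,
#     shuffle=True,
#     preprocess=resize_and_normalize,
#     drop_remainder=True,
# )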
@@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
     self._hparams: hp.BaseHParams = None
     self._history: tf.keras.callbacks.History = None
 
-  def _train_model(self,
-                   train_data: classification_ds.ClassificationDataset,
-                   validation_data: classification_ds.ClassificationDataset,
-                   preprocessor: Optional[Callable[..., bool]] = None,
-                   checkpoint_path: Optional[str] = None):
+  def _train_model(
+      self,
+      train_data: classification_ds.ClassificationDataset,
+      validation_data: classification_ds.ClassificationDataset,
+      preprocessor: Optional[Callable[..., Any]] = None,
+      checkpoint_path: Optional[str] = None,
+  ):
     """Trains the classifier model.
 
     Compiles and fits the tf.keras `_model` and records the `_history`.
@@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
 def convert_to_tflite(
     model: tf.keras.Model,
     quantization_config: Optional[quantization.QuantizationConfig] = None,
-    supported_ops: Tuple[tf.lite.OpsSet,
-                         ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
-    preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
+    supported_ops: Tuple[tf.lite.OpsSet, ...] = (
+        tf.lite.OpsSet.TFLITE_BUILTINS,
+    ),
+    preprocess: Optional[Callable[..., Any]] = None,
+) -> bytearray:
   """Converts the input Keras model to TFLite format.
 
   Args:
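A similar hedged sketch for convert_to_tflite: the call is left as comments because the defining module is not named in this diff, and keras_model plus the output path are assumptions.

# tflite_bytes = convert_to_tflite(
#     model=keras_model,
#     quantization_config=None,
#     preprocess=resize_and_normalize,  # any callable is now accepted
# )
# with open('classifier.tflite', 'wb') as f:
#   f.write(tflite_bytes)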