Fix preprocess Callable typing
PiperOrigin-RevId: 515818356
This commit is contained in:
parent
3e8fd58400
commit
c94de4032d
|
@ -18,7 +18,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Callable, Optional, Tuple, TypeVar
|
from typing import Any, Callable, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
# Dependency imports
|
# Dependency imports
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -66,12 +66,14 @@ class Dataset(object):
|
||||||
"""
|
"""
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
def gen_tf_dataset(self,
|
def gen_tf_dataset(
|
||||||
batch_size: int = 1,
|
self,
|
||||||
is_training: bool = False,
|
batch_size: int = 1,
|
||||||
shuffle: bool = False,
|
is_training: bool = False,
|
||||||
preprocess: Optional[Callable[..., bool]] = None,
|
shuffle: bool = False,
|
||||||
drop_remainder: bool = False) -> tf.data.Dataset:
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
|
drop_remainder: bool = False,
|
||||||
|
) -> tf.data.Dataset:
|
||||||
"""Generates a batched tf.data.Dataset for training/evaluation.
|
"""Generates a batched tf.data.Dataset for training/evaluation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._hparams: hp.BaseHParams = None
|
self._hparams: hp.BaseHParams = None
|
||||||
self._history: tf.keras.callbacks.History = None
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
|
||||||
def _train_model(self,
|
def _train_model(
|
||||||
train_data: classification_ds.ClassificationDataset,
|
self,
|
||||||
validation_data: classification_ds.ClassificationDataset,
|
train_data: classification_ds.ClassificationDataset,
|
||||||
preprocessor: Optional[Callable[..., bool]] = None,
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
checkpoint_path: Optional[str] = None):
|
preprocessor: Optional[Callable[..., Any]] = None,
|
||||||
|
checkpoint_path: Optional[str] = None,
|
||||||
|
):
|
||||||
"""Trains the classifier model.
|
"""Trains the classifier model.
|
||||||
|
|
||||||
Compiles and fits the tf.keras `_model` and records the `_history`.
|
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||||
|
|
|
@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
def convert_to_tflite(
|
def convert_to_tflite(
|
||||||
model: tf.keras.Model,
|
model: tf.keras.Model,
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
supported_ops: Tuple[tf.lite.OpsSet,
|
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
|
||||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
tf.lite.OpsSet.TFLITE_BUILTINS,
|
||||||
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
|
),
|
||||||
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
|
) -> bytearray:
|
||||||
"""Converts the input Keras model to TFLite format.
|
"""Converts the input Keras model to TFLite format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user