Fix preprocess Callable typing

PiperOrigin-RevId: 515818356
This commit is contained in:
MediaPipe Team 2023-03-10 21:49:05 -08:00 committed by Copybara-Service
parent 3e8fd58400
commit c94de4032d
3 changed files with 21 additions and 15 deletions

View File

@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
import functools
from typing import Callable, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
@ -66,12 +66,14 @@ class Dataset(object):
"""
return self._size
def gen_tf_dataset(self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., bool]] = None,
drop_remainder: bool = False) -> tf.data.Dataset:
def gen_tf_dataset(
self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., Any]] = None,
drop_remainder: bool = False,
) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation.
Args:

View File

@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None
def _train_model(self,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None,
checkpoint_path: Optional[str] = None):
def _train_model(
self,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., Any]] = None,
checkpoint_path: Optional[str] = None,
):
"""Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`.

View File

@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
def convert_to_tflite(
model: tf.keras.Model,
quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
tf.lite.OpsSet.TFLITE_BUILTINS,
),
preprocess: Optional[Callable[..., Any]] = None,
) -> bytearray:
"""Converts the input Keras model to TFLite format.
Args: