Fix preprocess Callable typing

PiperOrigin-RevId: 515818356
This commit is contained in:
MediaPipe Team 2023-03-10 21:49:05 -08:00 committed by Copybara-Service
parent 3e8fd58400
commit c94de4032d
3 changed files with 21 additions and 15 deletions

View File

@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools import functools
from typing import Callable, Optional, Tuple, TypeVar from typing import Any, Callable, Optional, Tuple, TypeVar
# Dependency imports # Dependency imports
import tensorflow as tf import tensorflow as tf
@ -66,12 +66,14 @@ class Dataset(object):
""" """
return self._size return self._size
def gen_tf_dataset(self, def gen_tf_dataset(
batch_size: int = 1, self,
is_training: bool = False, batch_size: int = 1,
shuffle: bool = False, is_training: bool = False,
preprocess: Optional[Callable[..., bool]] = None, shuffle: bool = False,
drop_remainder: bool = False) -> tf.data.Dataset: preprocess: Optional[Callable[..., Any]] = None,
drop_remainder: bool = False,
) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation. """Generates a batched tf.data.Dataset for training/evaluation.
Args: Args:

View File

@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None self._history: tf.keras.callbacks.History = None
def _train_model(self, def _train_model(
train_data: classification_ds.ClassificationDataset, self,
validation_data: classification_ds.ClassificationDataset, train_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None, validation_data: classification_ds.ClassificationDataset,
checkpoint_path: Optional[str] = None): preprocessor: Optional[Callable[..., Any]] = None,
checkpoint_path: Optional[str] = None,
):
"""Trains the classifier model. """Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`. Compiles and fits the tf.keras `_model` and records the `_history`.

View File

@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
def convert_to_tflite( def convert_to_tflite(
model: tf.keras.Model, model: tf.keras.Model,
quantization_config: Optional[quantization.QuantizationConfig] = None, quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet, supported_ops: Tuple[tf.lite.OpsSet, ...] = (
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), tf.lite.OpsSet.TFLITE_BUILTINS,
preprocess: Optional[Callable[..., bool]] = None) -> bytearray: ),
preprocess: Optional[Callable[..., Any]] = None,
) -> bytearray:
"""Converts the input Keras model to TFLite format. """Converts the input Keras model to TFLite format.
Args: Args: