Format improvement.

PiperOrigin-RevId: 563321343
This commit is contained in:
MediaPipe Team 2023-09-06 22:59:32 -07:00 committed by Copybara-Service
parent 3ce457006f
commit 5f4a6e313e

View File

@ -34,7 +34,8 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
def get_default_callbacks( def get_default_callbacks(
export_dir: str) -> Sequence[tf.keras.callbacks.Callback]: export_dir: str,
) -> Sequence[tf.keras.callbacks.Callback]:
"""Gets default callbacks.""" """Gets default callbacks."""
summary_dir = os.path.join(export_dir, 'summaries') summary_dir = os.path.join(export_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
@ -43,12 +44,14 @@ def get_default_callbacks(
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
os.path.join(checkpoint_path, 'model-{epoch:04d}'), os.path.join(checkpoint_path, 'model-{epoch:04d}'),
save_weights_only=True, save_weights_only=True,
period=5) period=5,
)
return [summary_callback, checkpoint_callback] return [summary_callback, checkpoint_callback]
def load_keras_model(model_path: str, def load_keras_model(
compile_on_load: bool = False) -> tf.keras.Model: model_path: str, compile_on_load: bool = False
) -> tf.keras.Model:
"""Loads a tensorflow Keras model from file and returns the Keras model. """Loads a tensorflow Keras model from file and returns the Keras model.
Args: Args:
@ -82,9 +85,11 @@ def load_tflite_model_buffer(model_path: str) -> bytearray:
return tflite_model_buffer return tflite_model_buffer
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, def get_steps_per_epoch(
batch_size: Optional[int] = None, steps_per_epoch: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None) -> int: batch_size: Optional[int] = None,
train_data: Optional[dataset.Dataset] = None,
) -> int:
"""Gets the estimated training steps per epoch. """Gets the estimated training steps per epoch.
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly. 1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
@ -201,17 +206,20 @@ def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
with tf.io.gfile.GFile(tflite_file, 'wb') as f: with tf.io.gfile.GFile(tflite_file, 'wb') as f:
f.write(tflite_model) f.write(tflite_model)
tf.compat.v1.logging.info( tf.compat.v1.logging.info(
'TensorFlow Lite model exported successfully to: %s' % tflite_file) 'TensorFlow Lite model exported successfully to: %s' % tflite_file
)
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applies a warmup schedule on a given learning rate decay schedule.""" """Applies a warmup schedule on a given learning rate decay schedule."""
def __init__(self, def __init__(
initial_learning_rate: float, self,
decay_schedule_fn: Callable[[Any], Any], initial_learning_rate: float,
warmup_steps: int, decay_schedule_fn: Callable[[Any], Any],
name: Optional[str] = None): warmup_steps: int,
name: Optional[str] = None,
):
"""Initializes a new instance of the `WarmUp` class. """Initializes a new instance of the `WarmUp` class.
Args: Args:
@ -239,14 +247,15 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
global_step_float < warmup_steps_float, global_step_float < warmup_steps_float,
lambda: warmup_learning_rate, lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step), lambda: self.decay_schedule_fn(step),
name=name) name=name,
)
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
return { return {
'initial_learning_rate': self.initial_learning_rate, 'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn, 'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps, 'warmup_steps': self.warmup_steps,
'name': self.name 'name': self.name,
} }
@ -280,7 +289,8 @@ class LiteRunner(object):
""" """
if not isinstance(input_tensors, list) and not isinstance( if not isinstance(input_tensors, list) and not isinstance(
input_tensors, dict): input_tensors, dict
):
input_tensors = [input_tensors] input_tensors = [input_tensors]
interpreter = self.interpreter interpreter = self.interpreter
@ -288,19 +298,18 @@ class LiteRunner(object):
# Reshape inputs # Reshape inputs
for i, input_detail in enumerate(self.input_details): for i, input_detail in enumerate(self.input_details):
input_tensor = _get_input_tensor( input_tensor = _get_input_tensor(
input_tensors=input_tensors, input_tensors=input_tensors, input_details=self.input_details, index=i
input_details=self.input_details, )
index=i)
interpreter.resize_tensor_input( interpreter.resize_tensor_input(
input_index=input_detail['index'], tensor_size=input_tensor.shape) input_index=input_detail['index'], tensor_size=input_tensor.shape
)
interpreter.allocate_tensors() interpreter.allocate_tensors()
# Feed input to the interpreter # Feed input to the interpreter
for i, input_detail in enumerate(self.input_details): for i, input_detail in enumerate(self.input_details):
input_tensor = _get_input_tensor( input_tensor = _get_input_tensor(
input_tensors=input_tensors, input_tensors=input_tensors, input_details=self.input_details, index=i
input_details=self.input_details, )
index=i)
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
# Quantize the input # Quantize the input
scale, zero_point = input_detail['quantization'] scale, zero_point = input_detail['quantization']
@ -331,9 +340,11 @@ def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
return lite_runner return lite_runner
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str, def _get_input_tensor(
tf.Tensor]], input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
input_details: Dict[str, Any], index: int) -> tf.Tensor: input_details: Dict[str, Any],
index: int,
) -> tf.Tensor:
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`.""" """Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
if isinstance(input_tensors, dict): if isinstance(input_tensors, dict):
# Gets the mapped input tensor. # Gets the mapped input tensor.
@ -341,7 +352,9 @@ def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
for input_tensor_name, input_tensor in input_tensors.items(): for input_tensor_name, input_tensor in input_tensors.items():
if input_tensor_name in input_detail['name']: if input_tensor_name in input_detail['name']:
return input_tensor return input_tensor
raise ValueError('Input tensors don\'t contains a tensor that mapped the ' raise ValueError(
'input detail %s' % str(input_detail)) "Input tensors don't contains a tensor that mapped the input detail %s"
% str(input_detail)
)
else: else:
return input_tensors[index] return input_tensors[index]