Format improvement.
PiperOrigin-RevId: 563321343
This commit is contained in:
parent
3ce457006f
commit
5f4a6e313e
|
@ -34,7 +34,8 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
|
||||||
|
|
||||||
|
|
||||||
def get_default_callbacks(
|
def get_default_callbacks(
|
||||||
export_dir: str) -> Sequence[tf.keras.callbacks.Callback]:
|
export_dir: str,
|
||||||
|
) -> Sequence[tf.keras.callbacks.Callback]:
|
||||||
"""Gets default callbacks."""
|
"""Gets default callbacks."""
|
||||||
summary_dir = os.path.join(export_dir, 'summaries')
|
summary_dir = os.path.join(export_dir, 'summaries')
|
||||||
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
|
||||||
|
@ -43,12 +44,14 @@ def get_default_callbacks(
|
||||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||||
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
os.path.join(checkpoint_path, 'model-{epoch:04d}'),
|
||||||
save_weights_only=True,
|
save_weights_only=True,
|
||||||
period=5)
|
period=5,
|
||||||
|
)
|
||||||
return [summary_callback, checkpoint_callback]
|
return [summary_callback, checkpoint_callback]
|
||||||
|
|
||||||
|
|
||||||
def load_keras_model(model_path: str,
|
def load_keras_model(
|
||||||
compile_on_load: bool = False) -> tf.keras.Model:
|
model_path: str, compile_on_load: bool = False
|
||||||
|
) -> tf.keras.Model:
|
||||||
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
"""Loads a tensorflow Keras model from file and returns the Keras model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -82,9 +85,11 @@ def load_tflite_model_buffer(model_path: str) -> bytearray:
|
||||||
return tflite_model_buffer
|
return tflite_model_buffer
|
||||||
|
|
||||||
|
|
||||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
def get_steps_per_epoch(
|
||||||
batch_size: Optional[int] = None,
|
steps_per_epoch: Optional[int] = None,
|
||||||
train_data: Optional[dataset.Dataset] = None) -> int:
|
batch_size: Optional[int] = None,
|
||||||
|
train_data: Optional[dataset.Dataset] = None,
|
||||||
|
) -> int:
|
||||||
"""Gets the estimated training steps per epoch.
|
"""Gets the estimated training steps per epoch.
|
||||||
|
|
||||||
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
|
1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
|
||||||
|
@ -201,17 +206,20 @@ def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
|
||||||
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
|
with tf.io.gfile.GFile(tflite_file, 'wb') as f:
|
||||||
f.write(tflite_model)
|
f.write(tflite_model)
|
||||||
tf.compat.v1.logging.info(
|
tf.compat.v1.logging.info(
|
||||||
'TensorFlow Lite model exported successfully to: %s' % tflite_file)
|
'TensorFlow Lite model exported successfully to: %s' % tflite_file
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
"""Applies a warmup schedule on a given learning rate decay schedule."""
|
"""Applies a warmup schedule on a given learning rate decay schedule."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
initial_learning_rate: float,
|
self,
|
||||||
decay_schedule_fn: Callable[[Any], Any],
|
initial_learning_rate: float,
|
||||||
warmup_steps: int,
|
decay_schedule_fn: Callable[[Any], Any],
|
||||||
name: Optional[str] = None):
|
warmup_steps: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
"""Initializes a new instance of the `WarmUp` class.
|
"""Initializes a new instance of the `WarmUp` class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -239,14 +247,15 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
global_step_float < warmup_steps_float,
|
global_step_float < warmup_steps_float,
|
||||||
lambda: warmup_learning_rate,
|
lambda: warmup_learning_rate,
|
||||||
lambda: self.decay_schedule_fn(step),
|
lambda: self.decay_schedule_fn(step),
|
||||||
name=name)
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
def get_config(self) -> Dict[str, Any]:
|
def get_config(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
'initial_learning_rate': self.initial_learning_rate,
|
'initial_learning_rate': self.initial_learning_rate,
|
||||||
'decay_schedule_fn': self.decay_schedule_fn,
|
'decay_schedule_fn': self.decay_schedule_fn,
|
||||||
'warmup_steps': self.warmup_steps,
|
'warmup_steps': self.warmup_steps,
|
||||||
'name': self.name
|
'name': self.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -280,7 +289,8 @@ class LiteRunner(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(input_tensors, list) and not isinstance(
|
if not isinstance(input_tensors, list) and not isinstance(
|
||||||
input_tensors, dict):
|
input_tensors, dict
|
||||||
|
):
|
||||||
input_tensors = [input_tensors]
|
input_tensors = [input_tensors]
|
||||||
|
|
||||||
interpreter = self.interpreter
|
interpreter = self.interpreter
|
||||||
|
@ -288,19 +298,18 @@ class LiteRunner(object):
|
||||||
# Reshape inputs
|
# Reshape inputs
|
||||||
for i, input_detail in enumerate(self.input_details):
|
for i, input_detail in enumerate(self.input_details):
|
||||||
input_tensor = _get_input_tensor(
|
input_tensor = _get_input_tensor(
|
||||||
input_tensors=input_tensors,
|
input_tensors=input_tensors, input_details=self.input_details, index=i
|
||||||
input_details=self.input_details,
|
)
|
||||||
index=i)
|
|
||||||
interpreter.resize_tensor_input(
|
interpreter.resize_tensor_input(
|
||||||
input_index=input_detail['index'], tensor_size=input_tensor.shape)
|
input_index=input_detail['index'], tensor_size=input_tensor.shape
|
||||||
|
)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
# Feed input to the interpreter
|
# Feed input to the interpreter
|
||||||
for i, input_detail in enumerate(self.input_details):
|
for i, input_detail in enumerate(self.input_details):
|
||||||
input_tensor = _get_input_tensor(
|
input_tensor = _get_input_tensor(
|
||||||
input_tensors=input_tensors,
|
input_tensors=input_tensors, input_details=self.input_details, index=i
|
||||||
input_details=self.input_details,
|
)
|
||||||
index=i)
|
|
||||||
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
|
||||||
# Quantize the input
|
# Quantize the input
|
||||||
scale, zero_point = input_detail['quantization']
|
scale, zero_point = input_detail['quantization']
|
||||||
|
@ -331,9 +340,11 @@ def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
|
||||||
return lite_runner
|
return lite_runner
|
||||||
|
|
||||||
|
|
||||||
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
|
def _get_input_tensor(
|
||||||
tf.Tensor]],
|
input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
|
||||||
input_details: Dict[str, Any], index: int) -> tf.Tensor:
|
input_details: Dict[str, Any],
|
||||||
|
index: int,
|
||||||
|
) -> tf.Tensor:
|
||||||
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
|
"""Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
|
||||||
if isinstance(input_tensors, dict):
|
if isinstance(input_tensors, dict):
|
||||||
# Gets the mapped input tensor.
|
# Gets the mapped input tensor.
|
||||||
|
@ -341,7 +352,9 @@ def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
|
||||||
for input_tensor_name, input_tensor in input_tensors.items():
|
for input_tensor_name, input_tensor in input_tensors.items():
|
||||||
if input_tensor_name in input_detail['name']:
|
if input_tensor_name in input_detail['name']:
|
||||||
return input_tensor
|
return input_tensor
|
||||||
raise ValueError('Input tensors don\'t contains a tensor that mapped the '
|
raise ValueError(
|
||||||
'input detail %s' % str(input_detail))
|
"Input tensors don't contains a tensor that mapped the input detail %s"
|
||||||
|
% str(input_detail)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return input_tensors[index]
|
return input_tensors[index]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user