Format improvement.
PiperOrigin-RevId: 563321343
This commit is contained in:
		
							parent
							
								
									3ce457006f
								
							
						
					
					
						commit
						5f4a6e313e
					
				| 
						 | 
					@ -34,7 +34,8 @@ ESTIMITED_STEPS_PER_EPOCH = 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_default_callbacks(
 | 
					def get_default_callbacks(
 | 
				
			||||||
    export_dir: str) -> Sequence[tf.keras.callbacks.Callback]:
 | 
					    export_dir: str,
 | 
				
			||||||
 | 
					) -> Sequence[tf.keras.callbacks.Callback]:
 | 
				
			||||||
  """Gets default callbacks."""
 | 
					  """Gets default callbacks."""
 | 
				
			||||||
  summary_dir = os.path.join(export_dir, 'summaries')
 | 
					  summary_dir = os.path.join(export_dir, 'summaries')
 | 
				
			||||||
  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 | 
					  summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 | 
				
			||||||
| 
						 | 
					@ -43,12 +44,14 @@ def get_default_callbacks(
 | 
				
			||||||
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 | 
					  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 | 
				
			||||||
      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
 | 
					      os.path.join(checkpoint_path, 'model-{epoch:04d}'),
 | 
				
			||||||
      save_weights_only=True,
 | 
					      save_weights_only=True,
 | 
				
			||||||
      period=5)
 | 
					      period=5,
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
  return [summary_callback, checkpoint_callback]
 | 
					  return [summary_callback, checkpoint_callback]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def load_keras_model(model_path: str,
 | 
					def load_keras_model(
 | 
				
			||||||
                     compile_on_load: bool = False) -> tf.keras.Model:
 | 
					    model_path: str, compile_on_load: bool = False
 | 
				
			||||||
 | 
					) -> tf.keras.Model:
 | 
				
			||||||
  """Loads a tensorflow Keras model from file and returns the Keras model.
 | 
					  """Loads a tensorflow Keras model from file and returns the Keras model.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Args:
 | 
					  Args:
 | 
				
			||||||
| 
						 | 
					@ -82,9 +85,11 @@ def load_tflite_model_buffer(model_path: str) -> bytearray:
 | 
				
			||||||
  return tflite_model_buffer
 | 
					  return tflite_model_buffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
 | 
					def get_steps_per_epoch(
 | 
				
			||||||
 | 
					    steps_per_epoch: Optional[int] = None,
 | 
				
			||||||
    batch_size: Optional[int] = None,
 | 
					    batch_size: Optional[int] = None,
 | 
				
			||||||
                        train_data: Optional[dataset.Dataset] = None) -> int:
 | 
					    train_data: Optional[dataset.Dataset] = None,
 | 
				
			||||||
 | 
					) -> int:
 | 
				
			||||||
  """Gets the estimated training steps per epoch.
 | 
					  """Gets the estimated training steps per epoch.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
 | 
					  1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly.
 | 
				
			||||||
| 
						 | 
					@ -201,17 +206,20 @@ def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
 | 
				
			||||||
  with tf.io.gfile.GFile(tflite_file, 'wb') as f:
 | 
					  with tf.io.gfile.GFile(tflite_file, 'wb') as f:
 | 
				
			||||||
    f.write(tflite_model)
 | 
					    f.write(tflite_model)
 | 
				
			||||||
  tf.compat.v1.logging.info(
 | 
					  tf.compat.v1.logging.info(
 | 
				
			||||||
      'TensorFlow Lite model exported successfully to: %s' % tflite_file)
 | 
					      'TensorFlow Lite model exported successfully to: %s' % tflite_file
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 | 
					class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 | 
				
			||||||
  """Applies a warmup schedule on a given learning rate decay schedule."""
 | 
					  """Applies a warmup schedule on a given learning rate decay schedule."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self,
 | 
					  def __init__(
 | 
				
			||||||
 | 
					      self,
 | 
				
			||||||
      initial_learning_rate: float,
 | 
					      initial_learning_rate: float,
 | 
				
			||||||
      decay_schedule_fn: Callable[[Any], Any],
 | 
					      decay_schedule_fn: Callable[[Any], Any],
 | 
				
			||||||
      warmup_steps: int,
 | 
					      warmup_steps: int,
 | 
				
			||||||
               name: Optional[str] = None):
 | 
					      name: Optional[str] = None,
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
    """Initializes a new instance of the `WarmUp` class.
 | 
					    """Initializes a new instance of the `WarmUp` class.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
| 
						 | 
					@ -239,14 +247,15 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 | 
				
			||||||
          global_step_float < warmup_steps_float,
 | 
					          global_step_float < warmup_steps_float,
 | 
				
			||||||
          lambda: warmup_learning_rate,
 | 
					          lambda: warmup_learning_rate,
 | 
				
			||||||
          lambda: self.decay_schedule_fn(step),
 | 
					          lambda: self.decay_schedule_fn(step),
 | 
				
			||||||
          name=name)
 | 
					          name=name,
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def get_config(self) -> Dict[str, Any]:
 | 
					  def get_config(self) -> Dict[str, Any]:
 | 
				
			||||||
    return {
 | 
					    return {
 | 
				
			||||||
        'initial_learning_rate': self.initial_learning_rate,
 | 
					        'initial_learning_rate': self.initial_learning_rate,
 | 
				
			||||||
        'decay_schedule_fn': self.decay_schedule_fn,
 | 
					        'decay_schedule_fn': self.decay_schedule_fn,
 | 
				
			||||||
        'warmup_steps': self.warmup_steps,
 | 
					        'warmup_steps': self.warmup_steps,
 | 
				
			||||||
        'name': self.name
 | 
					        'name': self.name,
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -280,7 +289,8 @@ class LiteRunner(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not isinstance(input_tensors, list) and not isinstance(
 | 
					    if not isinstance(input_tensors, list) and not isinstance(
 | 
				
			||||||
        input_tensors, dict):
 | 
					        input_tensors, dict
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
      input_tensors = [input_tensors]
 | 
					      input_tensors = [input_tensors]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    interpreter = self.interpreter
 | 
					    interpreter = self.interpreter
 | 
				
			||||||
| 
						 | 
					@ -288,19 +298,18 @@ class LiteRunner(object):
 | 
				
			||||||
    # Reshape inputs
 | 
					    # Reshape inputs
 | 
				
			||||||
    for i, input_detail in enumerate(self.input_details):
 | 
					    for i, input_detail in enumerate(self.input_details):
 | 
				
			||||||
      input_tensor = _get_input_tensor(
 | 
					      input_tensor = _get_input_tensor(
 | 
				
			||||||
          input_tensors=input_tensors,
 | 
					          input_tensors=input_tensors, input_details=self.input_details, index=i
 | 
				
			||||||
          input_details=self.input_details,
 | 
					      )
 | 
				
			||||||
          index=i)
 | 
					 | 
				
			||||||
      interpreter.resize_tensor_input(
 | 
					      interpreter.resize_tensor_input(
 | 
				
			||||||
          input_index=input_detail['index'], tensor_size=input_tensor.shape)
 | 
					          input_index=input_detail['index'], tensor_size=input_tensor.shape
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
    interpreter.allocate_tensors()
 | 
					    interpreter.allocate_tensors()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Feed input to the interpreter
 | 
					    # Feed input to the interpreter
 | 
				
			||||||
    for i, input_detail in enumerate(self.input_details):
 | 
					    for i, input_detail in enumerate(self.input_details):
 | 
				
			||||||
      input_tensor = _get_input_tensor(
 | 
					      input_tensor = _get_input_tensor(
 | 
				
			||||||
          input_tensors=input_tensors,
 | 
					          input_tensors=input_tensors, input_details=self.input_details, index=i
 | 
				
			||||||
          input_details=self.input_details,
 | 
					      )
 | 
				
			||||||
          index=i)
 | 
					 | 
				
			||||||
      if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
 | 
					      if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT):
 | 
				
			||||||
        # Quantize the input
 | 
					        # Quantize the input
 | 
				
			||||||
        scale, zero_point = input_detail['quantization']
 | 
					        scale, zero_point = input_detail['quantization']
 | 
				
			||||||
| 
						 | 
					@ -331,9 +340,11 @@ def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
 | 
				
			||||||
  return lite_runner
 | 
					  return lite_runner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
 | 
					def _get_input_tensor(
 | 
				
			||||||
                                                                 tf.Tensor]],
 | 
					    input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]],
 | 
				
			||||||
                      input_details: Dict[str, Any], index: int) -> tf.Tensor:
 | 
					    input_details: Dict[str, Any],
 | 
				
			||||||
 | 
					    index: int,
 | 
				
			||||||
 | 
					) -> tf.Tensor:
 | 
				
			||||||
  """Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
 | 
					  """Returns input tensor in `input_tensors` that maps `input_detail[i]`."""
 | 
				
			||||||
  if isinstance(input_tensors, dict):
 | 
					  if isinstance(input_tensors, dict):
 | 
				
			||||||
    # Gets the mapped input tensor.
 | 
					    # Gets the mapped input tensor.
 | 
				
			||||||
| 
						 | 
					@ -341,7 +352,9 @@ def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str,
 | 
				
			||||||
    for input_tensor_name, input_tensor in input_tensors.items():
 | 
					    for input_tensor_name, input_tensor in input_tensors.items():
 | 
				
			||||||
      if input_tensor_name in input_detail['name']:
 | 
					      if input_tensor_name in input_detail['name']:
 | 
				
			||||||
        return input_tensor
 | 
					        return input_tensor
 | 
				
			||||||
    raise ValueError('Input tensors don\'t contains a tensor that mapped the '
 | 
					    raise ValueError(
 | 
				
			||||||
                     'input detail %s' % str(input_detail))
 | 
					        "Input tensors don't contains a tensor that mapped the input detail %s"
 | 
				
			||||||
 | 
					        % str(input_detail)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
  else:
 | 
					  else:
 | 
				
			||||||
    return input_tensors[index]
 | 
					    return input_tensors[index]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user