Change object detector learning rate decay to cosine decay.
PiperOrigin-RevId: 527337105
This commit is contained in:
parent 507ed0d91d
commit 48aa88f39d
@@ -273,7 +273,7 @@ defined in the enclosing protobuf in order to be traversed using
 
 ## Cycles
 
-<!-- TODO -->
+<!-- TODO: add discussion of PreviousLoopbackCalculator -->
 
 By default, MediaPipe requires calculator graphs to be acyclic and treats cycles
 in a graph as errors. If a graph is intended to have cycles, the cycles need to
@@ -164,7 +164,7 @@ class Contract {
 
   std::tuple<T...> items;
 
-  // TODO -, check for conflicts.
+  // TODO: when forwarding nested items (e.g. ports), check for conflicts.
   decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)};
 
   constexpr auto inputs() const {
@@ -150,7 +150,7 @@ class CalculatorBase {
   // Packets may be output during a call to Close().  However, output packets
   // are silently discarded if Close() is called after a graph run has ended.
   //
-  // NOTE - needs to perform an action only when processing is
+  // NOTE: If Close() needs to perform an action only when processing is
   // complete, Close() must check if cc->GraphStatus() is OK.
   virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); }
 
@@ -111,7 +111,7 @@ class CalculatorContext {
 
   // Returns the status of the graph run.
   //
-  // NOTE -.
+  // NOTE: This method should only be called during CalculatorBase::Close().
   absl::Status GraphStatus() const { return graph_status_; }
 
   ProfilingContext* GetProfilingContext() const {
@@ -66,7 +66,7 @@ class CalculatorRunner {
   explicit CalculatorRunner(const std::string& node_config_string);
   // Convenience constructor to initialize a calculator which uses indexes
   // (not tags) for all its fields.
-  // NOTE -, which
+  // NOTE: This constructor calls proto_ns::TextFormat::ParseFromString(), which
   // is not available when using lite protos.
   CalculatorRunner(const std::string& calculator_type,
                    const std::string& options_string, int num_inputs,
@@ -30,7 +30,7 @@ class OutputSidePacket {
 
   // Sets the output side packet. The Packet must contain the data.
   //
-  // NOTE - cannot report errors via the return value. It uses an error
+  // NOTE: Set() cannot report errors via the return value. It uses an error
   // callback function to report errors.
   virtual void Set(const Packet& packet) = 0;
 };
@@ -48,7 +48,7 @@ class OutputSidePacketImpl : public OutputSidePacket {
 
   // Sets the output side packet. The Packet must contain the data.
   //
-  // NOTE - cannot report errors via the return value. It uses an error
+  // NOTE: Set() cannot report errors via the return value. It uses an error
   // callback function to report errors.
   void Set(const Packet& packet) override;
 
@@ -50,7 +50,7 @@ class OutputStream {
   //   the only packet in the stream.
   // Violation of any of these conditions causes a CHECK-failure.
   //
-  // NOTE - cannot report errors via the return value. Instead of a
+  // NOTE: AddPacket() cannot report errors via the return value. Instead of a
   // CHECK-failure, a subclass of OutputStream should use a callback function
   // to report errors.
   virtual void AddPacket(const Packet& packet) = 0;
@@ -102,7 +102,7 @@ class SchedulerQueue : public TaskQueue {
   // Implements the TaskQueue interface.
   void RunNextTask() override;
 
-  // NOTE -, the caller must call
+  // NOTE: After calling SetRunning(true), the caller must call
   // SubmitWaitingTasksToExecutor since tasks may have been added while the
   // queue was not running.
   void SetRunning(bool running) ABSL_LOCKS_EXCLUDED(mutex_);
@@ -25,7 +25,7 @@ constexpr char kAnchorsTag[] = "ANCHORS";
 constexpr char kBoxesInputTag[] = "BOXES";
 constexpr char kBoxesOutputTag[] = "START_POS";
 constexpr char kCancelTag[] = "CANCEL_ID";
-// TODO -
+// TODO: Find optimal Height/Width (0.1-0.3)
 constexpr float kBoxEdgeSize =
     0.2f;  // Used to establish tracking box dimensions
 constexpr float kUsToMs =
@@ -106,7 +106,7 @@ public class ExternalTextureRenderer {
    *
    * <p>Before calling this, {@link #setup} must have been called.
    *
-   * <p>NOTE -} on passed surface texture.
+   * <p>NOTE: Calls {@link SurfaceTexture#updateTexImage()} on passed surface texture.
    */
   public void render(SurfaceTexture surfaceTexture) {
     GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT);
@@ -14,7 +14,7 @@
 """Hyperparameters for training object detection models."""
 
 import dataclasses
-from typing import List
+from typing import Optional
 
 from mediapipe.model_maker.python.core import hyperparameters as hp
 
@@ -29,12 +29,13 @@ class HParams(hp.BaseHParams):
     epochs: Number of training iterations over the dataset.
     do_fine_tuning: If true, the base module is trained together with the
       classification layer on top.
-    learning_rate_epoch_boundaries: List of epoch boundaries where
-      learning_rate_epoch_boundaries[i] is the epoch where the learning rate
-      will decay to learning_rate * learning_rate_decay_multipliers[i].
-    learning_rate_decay_multipliers: List of learning rate multipliers which
-      calculates the learning rate at the ith boundary as learning_rate *
-      learning_rate_decay_multipliers[i].
+    cosine_decay_epochs: The number of epochs for cosine decay learning rate.
+      See
+      https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay
+        for more info.
+    cosine_decay_alpha: The alpha value for cosine decay learning rate. See
+      https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay
+        for more info.
   """
 
   # Parameters from BaseHParams class.
@@ -42,41 +43,9 @@ class HParams(hp.BaseHParams):
   batch_size: int = 32
   epochs: int = 10
 
-  # Parameters for learning rate decay
-  learning_rate_epoch_boundaries: List[int] = dataclasses.field(
-      default_factory=lambda: []
-  )
-  learning_rate_decay_multipliers: List[float] = dataclasses.field(
-      default_factory=lambda: []
-  )
-
-  def __post_init__(self):
-    # Validate stepwise learning rate parameters
-    lr_boundary_len = len(self.learning_rate_epoch_boundaries)
-    lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
-    if lr_boundary_len != lr_decay_multipliers_len:
-      raise ValueError(
-          "Length of learning_rate_epoch_boundaries and ",
-          "learning_rate_decay_multipliers do not match: ",
-          f"{lr_boundary_len}!={lr_decay_multipliers_len}",
-      )
-    # Validate learning_rate_epoch_boundaries
-    if (
-        sorted(self.learning_rate_epoch_boundaries)
-        != self.learning_rate_epoch_boundaries
-    ):
-      raise ValueError(
-          "learning_rate_epoch_boundaries is not in ascending order: ",
-          self.learning_rate_epoch_boundaries,
-      )
-    if (
-        self.learning_rate_epoch_boundaries
-        and self.learning_rate_epoch_boundaries[-1] > self.epochs
-    ):
-      raise ValueError(
-          "Values in learning_rate_epoch_boundaries cannot be greater ",
-          "than epochs",
-      )
+  # Parameters for cosine learning rate decay
+  cosine_decay_epochs: Optional[int] = None
+  cosine_decay_alpha: float = 0.0
 
 
 @dataclasses.dataclass
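The two new fields replace the stepwise-decay parameters removed above. A minimal usage sketch follows; the import path and the learning_rate value are assumptions for illustration (learning_rate and the other base fields come from hp.BaseHParams):

    from mediapipe.model_maker.python.vision.object_detector import hyperparameters

    # Hypothetical configuration: decay over the first 8 of 10 epochs, ending at
    # 5% of the initial learning rate (alpha scales the final value).
    hparams = hyperparameters.HParams(
        learning_rate=0.3,       # assumed value; this field is defined in hp.BaseHParams
        batch_size=32,
        epochs=10,
        cosine_decay_epochs=8,   # defaults to `epochs` when left as None
        cosine_decay_alpha=0.05,
    )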
@@ -354,19 +354,16 @@ class ObjectDetector(classifier.Classifier):
       A tf.keras.optimizer.Optimizer for model training.
     """
     init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
-    if self._hparams.learning_rate_epoch_boundaries:
-      lr_values = [init_lr] + [
-          init_lr * m for m in self._hparams.learning_rate_decay_multipliers
-      ]
-      lr_step_boundaries = [
-          steps_per_epoch * epoch_boundary
-          for epoch_boundary in self._hparams.learning_rate_epoch_boundaries
-      ]
-      learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
-          lr_step_boundaries, lr_values
-      )
-    else:
-      learning_rate = init_lr
+    decay_epochs = (
+        self._hparams.cosine_decay_epochs
+        if self._hparams.cosine_decay_epochs
+        else self._hparams.epochs
+    )
+    learning_rate = tf.keras.optimizers.schedules.CosineDecay(
+        init_lr,
+        steps_per_epoch * decay_epochs,
+        self._hparams.cosine_decay_alpha,
+    )
     return tf.keras.optimizers.experimental.SGD(
         learning_rate=learning_rate, momentum=0.9
     )
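For intuition, here is a standalone sketch of the schedule the new code builds; the learning_rate, steps_per_epoch, and epochs numbers are made up for illustration:

    import tensorflow as tf

    init_lr = 0.3 * 32 / 256            # mirrors the init_lr scaling above (assumed learning_rate=0.3, batch_size=32)
    steps_per_epoch, epochs = 100, 10   # hypothetical values
    schedule = tf.keras.optimizers.schedules.CosineDecay(
        init_lr, steps_per_epoch * epochs, alpha=0.0
    )
    print(float(schedule(0)))      # ~0.0375: starts at init_lr
    print(float(schedule(500)))    # ~0.0188: halfway through, lr is init_lr * 0.5
    print(float(schedule(1000)))   # ~0.0: ends at alpha * init_lr

With alpha=0.0 (the new default), the learning rate follows a half cosine from init_lr down to zero over the full run unless cosine_decay_epochs is set to a smaller value.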
@@ -581,7 +581,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
       // Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and
       // normalization) to create softmax-transformed chunks before channel
       // extraction.
-      // NOTE - / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So
+      // NOTE: exp(x-C) / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So
      //   theoretically we can skip the max shader step entirely. However,
      //   applying it does bring all our values into a nice (0, 1] range, so it
      //   will likely be better for precision, especially when dealing with an
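The NOTE restored in this hunk is the standard max-subtraction identity behind numerically stable softmax. A small NumPy illustration of why the max step is kept (not part of the change itself):

    import numpy as np

    x = np.array([1000.0, 1001.0, 1002.0])   # large activations: naive exp() overflows
    naive = np.exp(x) / np.sum(np.exp(x))    # -> [nan, nan, nan] (inf / inf)
    shifted = np.exp(x - x.max()) / np.sum(np.exp(x - x.max()))
    print(shifted)                           # ~[0.090, 0.245, 0.665], all within (0, 1]

Subtracting C = max(x) leaves the result mathematically unchanged but keeps every exp() in (0, 1], which is the precision benefit the comment describes.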
@@ -200,7 +200,7 @@ class BoxTracker {
 
   // Cancels all ongoing tracks. To avoid race conditions all NewBoxTrack's in
   // flight will also be canceled. Future NewBoxTrack's will be canceled.
-  // NOTE - before
+  // NOTE: To resume execution, you have to call ResumeTracking() before
   //       issuing more NewBoxTrack calls.
   void CancelAllOngoingTracks() ABSL_LOCKS_EXCLUDED(status_mutex_);
   void ResumeTracking() ABSL_LOCKS_EXCLUDED(status_mutex_);
@@ -208,7 +208,7 @@ class BoxTracker {
   // Waits for all ongoing tracks to complete.
   // Optionally accepts a timeout in microseconds (== 0 for infinite wait).
   // Returns true on success, false if timeout is reached.
-  // NOTE - must
+  // NOTE: If WaitForAllOngoingTracks timed out, CancelAllOngoingTracks() must
   // be called before destructing the BoxTracker object or dangeling running
   // threads might try to access invalid data.
   bool WaitForAllOngoingTracks(int timeout_us = 0)