From 48aa88f39df68f8dbd4c98741b243be8ae6803aa Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 26 Apr 2023 12:10:26 -0700 Subject: [PATCH] Change object detector learning rate decay to cosine decay. PiperOrigin-RevId: 527337105 --- docs/framework_concepts/graphs.md | 2 +- mediapipe/framework/api2/contract.h | 2 +- mediapipe/framework/calculator_base.h | 2 +- mediapipe/framework/calculator_context.h | 2 +- mediapipe/framework/calculator_runner.h | 2 +- mediapipe/framework/output_side_packet.h | 2 +- mediapipe/framework/output_side_packet_impl.h | 2 +- mediapipe/framework/output_stream.h | 2 +- mediapipe/framework/scheduler_queue.h | 2 +- .../tracked_anchor_manager_calculator.cc | 2 +- .../glutil/ExternalTextureRenderer.java | 2 +- .../vision/object_detector/hyperparameters.py | 53 ++++--------------- .../vision/object_detector/object_detector.py | 23 ++++---- .../segmentation_postprocessor_gl.cc | 2 +- mediapipe/util/tracking/box_tracker.h | 4 +- 15 files changed, 35 insertions(+), 69 deletions(-) diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index 23e20d052..5f9c68e08 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -273,7 +273,7 @@ defined in the enclosing protobuf in order to be traversed using ## Cycles - + By default, MediaPipe requires calculator graphs to be acyclic and treats cycles in a graph as errors. If a graph is intended to have cycles, the cycles need to diff --git a/mediapipe/framework/api2/contract.h b/mediapipe/framework/api2/contract.h index 9b92212cb..90e4c38cd 100644 --- a/mediapipe/framework/api2/contract.h +++ b/mediapipe/framework/api2/contract.h @@ -164,7 +164,7 @@ class Contract { std::tuple items; - // TODO -, check for conflicts. + // TODO: when forwarding nested items (e.g. ports), check for conflicts. decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)}; constexpr auto inputs() const { diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h index 4bd3d7398..19f37f9de 100644 --- a/mediapipe/framework/calculator_base.h +++ b/mediapipe/framework/calculator_base.h @@ -150,7 +150,7 @@ class CalculatorBase { // Packets may be output during a call to Close(). However, output packets // are silently discarded if Close() is called after a graph run has ended. // - // NOTE - needs to perform an action only when processing is + // NOTE: If Close() needs to perform an action only when processing is // complete, Close() must check if cc->GraphStatus() is OK. virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); } diff --git a/mediapipe/framework/calculator_context.h b/mediapipe/framework/calculator_context.h index 4eafea79f..284226d92 100644 --- a/mediapipe/framework/calculator_context.h +++ b/mediapipe/framework/calculator_context.h @@ -111,7 +111,7 @@ class CalculatorContext { // Returns the status of the graph run. // - // NOTE -. + // NOTE: This method should only be called during CalculatorBase::Close(). absl::Status GraphStatus() const { return graph_status_; } ProfilingContext* GetProfilingContext() const { diff --git a/mediapipe/framework/calculator_runner.h b/mediapipe/framework/calculator_runner.h index 350fb535c..fb1020de1 100644 --- a/mediapipe/framework/calculator_runner.h +++ b/mediapipe/framework/calculator_runner.h @@ -66,7 +66,7 @@ class CalculatorRunner { explicit CalculatorRunner(const std::string& node_config_string); // Convenience constructor to initialize a calculator which uses indexes // (not tags) for all its fields. - // NOTE -, which + // NOTE: This constructor calls proto_ns::TextFormat::ParseFromString(), which // is not available when using lite protos. CalculatorRunner(const std::string& calculator_type, const std::string& options_string, int num_inputs, diff --git a/mediapipe/framework/output_side_packet.h b/mediapipe/framework/output_side_packet.h index 44eb07085..9a0c8cbd2 100644 --- a/mediapipe/framework/output_side_packet.h +++ b/mediapipe/framework/output_side_packet.h @@ -30,7 +30,7 @@ class OutputSidePacket { // Sets the output side packet. The Packet must contain the data. // - // NOTE - cannot report errors via the return value. It uses an error + // NOTE: Set() cannot report errors via the return value. It uses an error // callback function to report errors. virtual void Set(const Packet& packet) = 0; }; diff --git a/mediapipe/framework/output_side_packet_impl.h b/mediapipe/framework/output_side_packet_impl.h index 7b16eb32b..7e7d639cd 100644 --- a/mediapipe/framework/output_side_packet_impl.h +++ b/mediapipe/framework/output_side_packet_impl.h @@ -48,7 +48,7 @@ class OutputSidePacketImpl : public OutputSidePacket { // Sets the output side packet. The Packet must contain the data. // - // NOTE - cannot report errors via the return value. It uses an error + // NOTE: Set() cannot report errors via the return value. It uses an error // callback function to report errors. void Set(const Packet& packet) override; diff --git a/mediapipe/framework/output_stream.h b/mediapipe/framework/output_stream.h index 55679066d..191c26fd7 100644 --- a/mediapipe/framework/output_stream.h +++ b/mediapipe/framework/output_stream.h @@ -50,7 +50,7 @@ class OutputStream { // the only packet in the stream. // Violation of any of these conditions causes a CHECK-failure. // - // NOTE - cannot report errors via the return value. Instead of a + // NOTE: AddPacket() cannot report errors via the return value. Instead of a // CHECK-failure, a subclass of OutputStream should use a callback function // to report errors. virtual void AddPacket(const Packet& packet) = 0; diff --git a/mediapipe/framework/scheduler_queue.h b/mediapipe/framework/scheduler_queue.h index 345da7dc2..f6777f42b 100644 --- a/mediapipe/framework/scheduler_queue.h +++ b/mediapipe/framework/scheduler_queue.h @@ -102,7 +102,7 @@ class SchedulerQueue : public TaskQueue { // Implements the TaskQueue interface. void RunNextTask() override; - // NOTE -, the caller must call + // NOTE: After calling SetRunning(true), the caller must call // SubmitWaitingTasksToExecutor since tasks may have been added while the // queue was not running. void SetRunning(bool running) ABSL_LOCKS_EXCLUDED(mutex_); diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc b/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc index 9468b901e..446aee781 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc +++ b/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc @@ -25,7 +25,7 @@ constexpr char kAnchorsTag[] = "ANCHORS"; constexpr char kBoxesInputTag[] = "BOXES"; constexpr char kBoxesOutputTag[] = "START_POS"; constexpr char kCancelTag[] = "CANCEL_ID"; -// TODO - +// TODO: Find optimal Height/Width (0.1-0.3) constexpr float kBoxEdgeSize = 0.2f; // Used to establish tracking box dimensions constexpr float kUsToMs = diff --git a/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java b/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java index 2dc6112a3..4dd35f865 100644 --- a/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java +++ b/mediapipe/java/com/google/mediapipe/glutil/ExternalTextureRenderer.java @@ -106,7 +106,7 @@ public class ExternalTextureRenderer { * *

Before calling this, {@link #setup} must have been called. * - *

NOTE -} on passed surface texture. + *

NOTE: Calls {@link SurfaceTexture#updateTexImage()} on passed surface texture. */ public void render(SurfaceTexture surfaceTexture) { GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT); diff --git a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py index 8b49564a0..1bc7514f2 100644 --- a/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py +++ b/mediapipe/model_maker/python/vision/object_detector/hyperparameters.py @@ -14,7 +14,7 @@ """Hyperparameters for training object detection models.""" import dataclasses -from typing import List +from typing import Optional from mediapipe.model_maker.python.core import hyperparameters as hp @@ -29,12 +29,13 @@ class HParams(hp.BaseHParams): epochs: Number of training iterations over the dataset. do_fine_tuning: If true, the base module is trained together with the classification layer on top. - learning_rate_epoch_boundaries: List of epoch boundaries where - learning_rate_epoch_boundaries[i] is the epoch where the learning rate - will decay to learning_rate * learning_rate_decay_multipliers[i]. - learning_rate_decay_multipliers: List of learning rate multipliers which - calculates the learning rate at the ith boundary as learning_rate * - learning_rate_decay_multipliers[i]. + cosine_decay_epochs: The number of epochs for cosine decay learning rate. + See + https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay + for more info. + cosine_decay_alpha: The alpha value for cosine decay learning rate. See + https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay + for more info. """ # Parameters from BaseHParams class. @@ -42,41 +43,9 @@ class HParams(hp.BaseHParams): batch_size: int = 32 epochs: int = 10 - # Parameters for learning rate decay - learning_rate_epoch_boundaries: List[int] = dataclasses.field( - default_factory=lambda: [] - ) - learning_rate_decay_multipliers: List[float] = dataclasses.field( - default_factory=lambda: [] - ) - - def __post_init__(self): - # Validate stepwise learning rate parameters - lr_boundary_len = len(self.learning_rate_epoch_boundaries) - lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers) - if lr_boundary_len != lr_decay_multipliers_len: - raise ValueError( - "Length of learning_rate_epoch_boundaries and ", - "learning_rate_decay_multipliers do not match: ", - f"{lr_boundary_len}!={lr_decay_multipliers_len}", - ) - # Validate learning_rate_epoch_boundaries - if ( - sorted(self.learning_rate_epoch_boundaries) - != self.learning_rate_epoch_boundaries - ): - raise ValueError( - "learning_rate_epoch_boundaries is not in ascending order: ", - self.learning_rate_epoch_boundaries, - ) - if ( - self.learning_rate_epoch_boundaries - and self.learning_rate_epoch_boundaries[-1] > self.epochs - ): - raise ValueError( - "Values in learning_rate_epoch_boundaries cannot be greater ", - "than epochs", - ) + # Parameters for cosine learning rate decay + cosine_decay_epochs: Optional[int] = None + cosine_decay_alpha: float = 0.0 @dataclasses.dataclass diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector.py b/mediapipe/model_maker/python/vision/object_detector/object_detector.py index d208a9b28..746eef1b3 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py @@ -354,19 +354,16 @@ class ObjectDetector(classifier.Classifier): A tf.keras.optimizer.Optimizer for model training. """ init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 - if self._hparams.learning_rate_epoch_boundaries: - lr_values = [init_lr] + [ - init_lr * m for m in self._hparams.learning_rate_decay_multipliers - ] - lr_step_boundaries = [ - steps_per_epoch * epoch_boundary - for epoch_boundary in self._hparams.learning_rate_epoch_boundaries - ] - learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay( - lr_step_boundaries, lr_values - ) - else: - learning_rate = init_lr + decay_epochs = ( + self._hparams.cosine_decay_epochs + if self._hparams.cosine_decay_epochs + else self._hparams.epochs + ) + learning_rate = tf.keras.optimizers.schedules.CosineDecay( + init_lr, + steps_per_epoch * decay_epochs, + self._hparams.cosine_decay_alpha, + ) return tf.keras.optimizers.experimental.SGD( learning_rate=learning_rate, momentum=0.9 ) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc index 96451617f..5b212069f 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.cc @@ -581,7 +581,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu( // Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and // normalization) to create softmax-transformed chunks before channel // extraction. - // NOTE - / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So + // NOTE: exp(x-C) / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So // theoretically we can skip the max shader step entirely. However, // applying it does bring all our values into a nice (0, 1] range, so it // will likely be better for precision, especially when dealing with an diff --git a/mediapipe/util/tracking/box_tracker.h b/mediapipe/util/tracking/box_tracker.h index 31ac5c117..8654e97fd 100644 --- a/mediapipe/util/tracking/box_tracker.h +++ b/mediapipe/util/tracking/box_tracker.h @@ -200,7 +200,7 @@ class BoxTracker { // Cancels all ongoing tracks. To avoid race conditions all NewBoxTrack's in // flight will also be canceled. Future NewBoxTrack's will be canceled. - // NOTE - before + // NOTE: To resume execution, you have to call ResumeTracking() before // issuing more NewBoxTrack calls. void CancelAllOngoingTracks() ABSL_LOCKS_EXCLUDED(status_mutex_); void ResumeTracking() ABSL_LOCKS_EXCLUDED(status_mutex_); @@ -208,7 +208,7 @@ class BoxTracker { // Waits for all ongoing tracks to complete. // Optionally accepts a timeout in microseconds (== 0 for infinite wait). // Returns true on success, false if timeout is reached. - // NOTE - must + // NOTE: If WaitForAllOngoingTracks timed out, CancelAllOngoingTracks() must // be called before destructing the BoxTracker object or dangeling running // threads might try to access invalid data. bool WaitForAllOngoingTracks(int timeout_us = 0)