Change object detector learning rate decay to cosine decay.

PiperOrigin-RevId: 527337105
This commit is contained in:
MediaPipe Team 2023-04-26 12:10:26 -07:00 committed by Copybara-Service
parent 507ed0d91d
commit 48aa88f39d
15 changed files with 35 additions and 69 deletions

View File

@ -273,7 +273,7 @@ defined in the enclosing protobuf in order to be traversed using
## Cycles ## Cycles
<!-- TODO --> <!-- TODO: add discussion of PreviousLoopbackCalculator -->
By default, MediaPipe requires calculator graphs to be acyclic and treats cycles By default, MediaPipe requires calculator graphs to be acyclic and treats cycles
in a graph as errors. If a graph is intended to have cycles, the cycles need to in a graph as errors. If a graph is intended to have cycles, the cycles need to

View File

@ -164,7 +164,7 @@ class Contract {
std::tuple<T...> items; std::tuple<T...> items;
// TODO -, check for conflicts. // TODO: when forwarding nested items (e.g. ports), check for conflicts.
decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)}; decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)};
constexpr auto inputs() const { constexpr auto inputs() const {

View File

@ -150,7 +150,7 @@ class CalculatorBase {
// Packets may be output during a call to Close(). However, output packets // Packets may be output during a call to Close(). However, output packets
// are silently discarded if Close() is called after a graph run has ended. // are silently discarded if Close() is called after a graph run has ended.
// //
// NOTE - needs to perform an action only when processing is // NOTE: If Close() needs to perform an action only when processing is
// complete, Close() must check if cc->GraphStatus() is OK. // complete, Close() must check if cc->GraphStatus() is OK.
virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); } virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); }

View File

@ -111,7 +111,7 @@ class CalculatorContext {
// Returns the status of the graph run. // Returns the status of the graph run.
// //
// NOTE -. // NOTE: This method should only be called during CalculatorBase::Close().
absl::Status GraphStatus() const { return graph_status_; } absl::Status GraphStatus() const { return graph_status_; }
ProfilingContext* GetProfilingContext() const { ProfilingContext* GetProfilingContext() const {

View File

@ -66,7 +66,7 @@ class CalculatorRunner {
explicit CalculatorRunner(const std::string& node_config_string); explicit CalculatorRunner(const std::string& node_config_string);
// Convenience constructor to initialize a calculator which uses indexes // Convenience constructor to initialize a calculator which uses indexes
// (not tags) for all its fields. // (not tags) for all its fields.
// NOTE -, which // NOTE: This constructor calls proto_ns::TextFormat::ParseFromString(), which
// is not available when using lite protos. // is not available when using lite protos.
CalculatorRunner(const std::string& calculator_type, CalculatorRunner(const std::string& calculator_type,
const std::string& options_string, int num_inputs, const std::string& options_string, int num_inputs,

View File

@ -30,7 +30,7 @@ class OutputSidePacket {
// Sets the output side packet. The Packet must contain the data. // Sets the output side packet. The Packet must contain the data.
// //
// NOTE - cannot report errors via the return value. It uses an error // NOTE: Set() cannot report errors via the return value. It uses an error
// callback function to report errors. // callback function to report errors.
virtual void Set(const Packet& packet) = 0; virtual void Set(const Packet& packet) = 0;
}; };

View File

@ -48,7 +48,7 @@ class OutputSidePacketImpl : public OutputSidePacket {
// Sets the output side packet. The Packet must contain the data. // Sets the output side packet. The Packet must contain the data.
// //
// NOTE - cannot report errors via the return value. It uses an error // NOTE: Set() cannot report errors via the return value. It uses an error
// callback function to report errors. // callback function to report errors.
void Set(const Packet& packet) override; void Set(const Packet& packet) override;

View File

@ -50,7 +50,7 @@ class OutputStream {
// the only packet in the stream. // the only packet in the stream.
// Violation of any of these conditions causes a CHECK-failure. // Violation of any of these conditions causes a CHECK-failure.
// //
// NOTE - cannot report errors via the return value. Instead of a // NOTE: AddPacket() cannot report errors via the return value. Instead of a
// CHECK-failure, a subclass of OutputStream should use a callback function // CHECK-failure, a subclass of OutputStream should use a callback function
// to report errors. // to report errors.
virtual void AddPacket(const Packet& packet) = 0; virtual void AddPacket(const Packet& packet) = 0;

View File

@ -102,7 +102,7 @@ class SchedulerQueue : public TaskQueue {
// Implements the TaskQueue interface. // Implements the TaskQueue interface.
void RunNextTask() override; void RunNextTask() override;
// NOTE -, the caller must call // NOTE: After calling SetRunning(true), the caller must call
// SubmitWaitingTasksToExecutor since tasks may have been added while the // SubmitWaitingTasksToExecutor since tasks may have been added while the
// queue was not running. // queue was not running.
void SetRunning(bool running) ABSL_LOCKS_EXCLUDED(mutex_); void SetRunning(bool running) ABSL_LOCKS_EXCLUDED(mutex_);

View File

@ -25,7 +25,7 @@ constexpr char kAnchorsTag[] = "ANCHORS";
constexpr char kBoxesInputTag[] = "BOXES"; constexpr char kBoxesInputTag[] = "BOXES";
constexpr char kBoxesOutputTag[] = "START_POS"; constexpr char kBoxesOutputTag[] = "START_POS";
constexpr char kCancelTag[] = "CANCEL_ID"; constexpr char kCancelTag[] = "CANCEL_ID";
// TODO - // TODO: Find optimal Height/Width (0.1-0.3)
constexpr float kBoxEdgeSize = constexpr float kBoxEdgeSize =
0.2f; // Used to establish tracking box dimensions 0.2f; // Used to establish tracking box dimensions
constexpr float kUsToMs = constexpr float kUsToMs =

View File

@ -106,7 +106,7 @@ public class ExternalTextureRenderer {
* *
* <p>Before calling this, {@link #setup} must have been called. * <p>Before calling this, {@link #setup} must have been called.
* *
* <p>NOTE -} on passed surface texture. * <p>NOTE: Calls {@link SurfaceTexture#updateTexImage()} on passed surface texture.
*/ */
public void render(SurfaceTexture surfaceTexture) { public void render(SurfaceTexture surfaceTexture) {
GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT); GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT);

View File

@ -14,7 +14,7 @@
"""Hyperparameters for training object detection models.""" """Hyperparameters for training object detection models."""
import dataclasses import dataclasses
from typing import List from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp from mediapipe.model_maker.python.core import hyperparameters as hp
@ -29,12 +29,13 @@ class HParams(hp.BaseHParams):
epochs: Number of training iterations over the dataset. epochs: Number of training iterations over the dataset.
do_fine_tuning: If true, the base module is trained together with the do_fine_tuning: If true, the base module is trained together with the
classification layer on top. classification layer on top.
learning_rate_epoch_boundaries: List of epoch boundaries where cosine_decay_epochs: The number of epochs for cosine decay learning rate.
learning_rate_epoch_boundaries[i] is the epoch where the learning rate See
will decay to learning_rate * learning_rate_decay_multipliers[i]. https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay
learning_rate_decay_multipliers: List of learning rate multipliers which for more info.
calculates the learning rate at the ith boundary as learning_rate * cosine_decay_alpha: The alpha value for cosine decay learning rate. See
learning_rate_decay_multipliers[i]. https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay
for more info.
""" """
# Parameters from BaseHParams class. # Parameters from BaseHParams class.
@ -42,41 +43,9 @@ class HParams(hp.BaseHParams):
batch_size: int = 32 batch_size: int = 32
epochs: int = 10 epochs: int = 10
# Parameters for learning rate decay # Parameters for cosine learning rate decay
learning_rate_epoch_boundaries: List[int] = dataclasses.field( cosine_decay_epochs: Optional[int] = None
default_factory=lambda: [] cosine_decay_alpha: float = 0.0
)
learning_rate_decay_multipliers: List[float] = dataclasses.field(
default_factory=lambda: []
)
def __post_init__(self):
# Validate stepwise learning rate parameters
lr_boundary_len = len(self.learning_rate_epoch_boundaries)
lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
if lr_boundary_len != lr_decay_multipliers_len:
raise ValueError(
"Length of learning_rate_epoch_boundaries and ",
"learning_rate_decay_multipliers do not match: ",
f"{lr_boundary_len}!={lr_decay_multipliers_len}",
)
# Validate learning_rate_epoch_boundaries
if (
sorted(self.learning_rate_epoch_boundaries)
!= self.learning_rate_epoch_boundaries
):
raise ValueError(
"learning_rate_epoch_boundaries is not in ascending order: ",
self.learning_rate_epoch_boundaries,
)
if (
self.learning_rate_epoch_boundaries
and self.learning_rate_epoch_boundaries[-1] > self.epochs
):
raise ValueError(
"Values in learning_rate_epoch_boundaries cannot be greater ",
"than epochs",
)
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -354,19 +354,16 @@ class ObjectDetector(classifier.Classifier):
A tf.keras.optimizer.Optimizer for model training. A tf.keras.optimizer.Optimizer for model training.
""" """
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
if self._hparams.learning_rate_epoch_boundaries: decay_epochs = (
lr_values = [init_lr] + [ self._hparams.cosine_decay_epochs
init_lr * m for m in self._hparams.learning_rate_decay_multipliers if self._hparams.cosine_decay_epochs
] else self._hparams.epochs
lr_step_boundaries = [ )
steps_per_epoch * epoch_boundary learning_rate = tf.keras.optimizers.schedules.CosineDecay(
for epoch_boundary in self._hparams.learning_rate_epoch_boundaries init_lr,
] steps_per_epoch * decay_epochs,
learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay( self._hparams.cosine_decay_alpha,
lr_step_boundaries, lr_values )
)
else:
learning_rate = init_lr
return tf.keras.optimizers.experimental.SGD( return tf.keras.optimizers.experimental.SGD(
learning_rate=learning_rate, momentum=0.9 learning_rate=learning_rate, momentum=0.9
) )

View File

@ -581,7 +581,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
// Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and // Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and
// normalization) to create softmax-transformed chunks before channel // normalization) to create softmax-transformed chunks before channel
// extraction. // extraction.
// NOTE - / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So // NOTE: exp(x-C) / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So
// theoretically we can skip the max shader step entirely. However, // theoretically we can skip the max shader step entirely. However,
// applying it does bring all our values into a nice (0, 1] range, so it // applying it does bring all our values into a nice (0, 1] range, so it
// will likely be better for precision, especially when dealing with an // will likely be better for precision, especially when dealing with an

View File

@ -200,7 +200,7 @@ class BoxTracker {
// Cancels all ongoing tracks. To avoid race conditions all NewBoxTrack's in // Cancels all ongoing tracks. To avoid race conditions all NewBoxTrack's in
// flight will also be canceled. Future NewBoxTrack's will be canceled. // flight will also be canceled. Future NewBoxTrack's will be canceled.
// NOTE - before // NOTE: To resume execution, you have to call ResumeTracking() before
// issuing more NewBoxTrack calls. // issuing more NewBoxTrack calls.
void CancelAllOngoingTracks() ABSL_LOCKS_EXCLUDED(status_mutex_); void CancelAllOngoingTracks() ABSL_LOCKS_EXCLUDED(status_mutex_);
void ResumeTracking() ABSL_LOCKS_EXCLUDED(status_mutex_); void ResumeTracking() ABSL_LOCKS_EXCLUDED(status_mutex_);
@ -208,7 +208,7 @@ class BoxTracker {
// Waits for all ongoing tracks to complete. // Waits for all ongoing tracks to complete.
// Optionally accepts a timeout in microseconds (== 0 for infinite wait). // Optionally accepts a timeout in microseconds (== 0 for infinite wait).
// Returns true on success, false if timeout is reached. // Returns true on success, false if timeout is reached.
// NOTE - must // NOTE: If WaitForAllOngoingTracks timed out, CancelAllOngoingTracks() must
// be called before destructing the BoxTracker object or dangling running // be called before destructing the BoxTracker object or dangling running
// threads might try to access invalid data. // threads might try to access invalid data.
bool WaitForAllOngoingTracks(int timeout_us = 0) bool WaitForAllOngoingTracks(int timeout_us = 0)