Change object detector learning rate decay to cosine decay.
PiperOrigin-RevId: 527337105
This commit is contained in:
parent
507ed0d91d
commit
48aa88f39d
|
@ -273,7 +273,7 @@ defined in the enclosing protobuf in order to be traversed using
|
||||||
|
|
||||||
## Cycles
|
## Cycles
|
||||||
|
|
||||||
<!-- TODO -->
|
<!-- TODO: add discussion of PreviousLoopbackCalculator -->
|
||||||
|
|
||||||
By default, MediaPipe requires calculator graphs to be acyclic and treats cycles
|
By default, MediaPipe requires calculator graphs to be acyclic and treats cycles
|
||||||
in a graph as errors. If a graph is intended to have cycles, the cycles need to
|
in a graph as errors. If a graph is intended to have cycles, the cycles need to
|
||||||
|
|
|
@ -164,7 +164,7 @@ class Contract {
|
||||||
|
|
||||||
std::tuple<T...> items;
|
std::tuple<T...> items;
|
||||||
|
|
||||||
// TODO -, check for conflicts.
|
// TODO: when forwarding nested items (e.g. ports), check for conflicts.
|
||||||
decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)};
|
decltype(ExtractNestedItems(items)) all_items{ExtractNestedItems(items)};
|
||||||
|
|
||||||
constexpr auto inputs() const {
|
constexpr auto inputs() const {
|
||||||
|
|
|
@ -150,7 +150,7 @@ class CalculatorBase {
|
||||||
// Packets may be output during a call to Close(). However, output packets
|
// Packets may be output during a call to Close(). However, output packets
|
||||||
// are silently discarded if Close() is called after a graph run has ended.
|
// are silently discarded if Close() is called after a graph run has ended.
|
||||||
//
|
//
|
||||||
// NOTE - needs to perform an action only when processing is
|
// NOTE: If Close() needs to perform an action only when processing is
|
||||||
// complete, Close() must check if cc->GraphStatus() is OK.
|
// complete, Close() must check if cc->GraphStatus() is OK.
|
||||||
virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); }
|
virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); }
|
||||||
|
|
||||||
|
|
|
@ -111,7 +111,7 @@ class CalculatorContext {
|
||||||
|
|
||||||
// Returns the status of the graph run.
|
// Returns the status of the graph run.
|
||||||
//
|
//
|
||||||
// NOTE -.
|
// NOTE: This method should only be called during CalculatorBase::Close().
|
||||||
absl::Status GraphStatus() const { return graph_status_; }
|
absl::Status GraphStatus() const { return graph_status_; }
|
||||||
|
|
||||||
ProfilingContext* GetProfilingContext() const {
|
ProfilingContext* GetProfilingContext() const {
|
||||||
|
|
|
@ -66,7 +66,7 @@ class CalculatorRunner {
|
||||||
explicit CalculatorRunner(const std::string& node_config_string);
|
explicit CalculatorRunner(const std::string& node_config_string);
|
||||||
// Convenience constructor to initialize a calculator which uses indexes
|
// Convenience constructor to initialize a calculator which uses indexes
|
||||||
// (not tags) for all its fields.
|
// (not tags) for all its fields.
|
||||||
// NOTE -, which
|
// NOTE: This constructor calls proto_ns::TextFormat::ParseFromString(), which
|
||||||
// is not available when using lite protos.
|
// is not available when using lite protos.
|
||||||
CalculatorRunner(const std::string& calculator_type,
|
CalculatorRunner(const std::string& calculator_type,
|
||||||
const std::string& options_string, int num_inputs,
|
const std::string& options_string, int num_inputs,
|
||||||
|
|
|
@ -30,7 +30,7 @@ class OutputSidePacket {
|
||||||
|
|
||||||
// Sets the output side packet. The Packet must contain the data.
|
// Sets the output side packet. The Packet must contain the data.
|
||||||
//
|
//
|
||||||
// NOTE - cannot report errors via the return value. It uses an error
|
// NOTE: Set() cannot report errors via the return value. It uses an error
|
||||||
// callback function to report errors.
|
// callback function to report errors.
|
||||||
virtual void Set(const Packet& packet) = 0;
|
virtual void Set(const Packet& packet) = 0;
|
||||||
};
|
};
|
||||||
|
|
|
@ -48,7 +48,7 @@ class OutputSidePacketImpl : public OutputSidePacket {
|
||||||
|
|
||||||
// Sets the output side packet. The Packet must contain the data.
|
// Sets the output side packet. The Packet must contain the data.
|
||||||
//
|
//
|
||||||
// NOTE - cannot report errors via the return value. It uses an error
|
// NOTE: Set() cannot report errors via the return value. It uses an error
|
||||||
// callback function to report errors.
|
// callback function to report errors.
|
||||||
void Set(const Packet& packet) override;
|
void Set(const Packet& packet) override;
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ class OutputStream {
|
||||||
// the only packet in the stream.
|
// the only packet in the stream.
|
||||||
// Violation of any of these conditions causes a CHECK-failure.
|
// Violation of any of these conditions causes a CHECK-failure.
|
||||||
//
|
//
|
||||||
// NOTE - cannot report errors via the return value. Instead of a
|
// NOTE: AddPacket() cannot report errors via the return value. Instead of a
|
||||||
// CHECK-failure, a subclass of OutputStream should use a callback function
|
// CHECK-failure, a subclass of OutputStream should use a callback function
|
||||||
// to report errors.
|
// to report errors.
|
||||||
virtual void AddPacket(const Packet& packet) = 0;
|
virtual void AddPacket(const Packet& packet) = 0;
|
||||||
|
|
|
@ -102,7 +102,7 @@ class SchedulerQueue : public TaskQueue {
|
||||||
// Implements the TaskQueue interface.
|
// Implements the TaskQueue interface.
|
||||||
void RunNextTask() override;
|
void RunNextTask() override;
|
||||||
|
|
||||||
// NOTE -, the caller must call
|
// NOTE: After calling SetRunning(true), the caller must call
|
||||||
// SubmitWaitingTasksToExecutor since tasks may have been added while the
|
// SubmitWaitingTasksToExecutor since tasks may have been added while the
|
||||||
// queue was not running.
|
// queue was not running.
|
||||||
void SetRunning(bool running) ABSL_LOCKS_EXCLUDED(mutex_);
|
void SetRunning(bool running) ABSL_LOCKS_EXCLUDED(mutex_);
|
||||||
|
|
|
@ -25,7 +25,7 @@ constexpr char kAnchorsTag[] = "ANCHORS";
|
||||||
constexpr char kBoxesInputTag[] = "BOXES";
|
constexpr char kBoxesInputTag[] = "BOXES";
|
||||||
constexpr char kBoxesOutputTag[] = "START_POS";
|
constexpr char kBoxesOutputTag[] = "START_POS";
|
||||||
constexpr char kCancelTag[] = "CANCEL_ID";
|
constexpr char kCancelTag[] = "CANCEL_ID";
|
||||||
// TODO -
|
// TODO: Find optimal Height/Width (0.1-0.3)
|
||||||
constexpr float kBoxEdgeSize =
|
constexpr float kBoxEdgeSize =
|
||||||
0.2f; // Used to establish tracking box dimensions
|
0.2f; // Used to establish tracking box dimensions
|
||||||
constexpr float kUsToMs =
|
constexpr float kUsToMs =
|
||||||
|
|
|
@ -106,7 +106,7 @@ public class ExternalTextureRenderer {
|
||||||
*
|
*
|
||||||
* <p>Before calling this, {@link #setup} must have been called.
|
* <p>Before calling this, {@link #setup} must have been called.
|
||||||
*
|
*
|
||||||
* <p>NOTE -} on passed surface texture.
|
* <p>NOTE: Calls {@link SurfaceTexture#updateTexImage()} on passed surface texture.
|
||||||
*/
|
*/
|
||||||
public void render(SurfaceTexture surfaceTexture) {
|
public void render(SurfaceTexture surfaceTexture) {
|
||||||
GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT);
|
GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT);
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
"""Hyperparameters for training object detection models."""
|
"""Hyperparameters for training object detection models."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import List
|
from typing import Optional
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
|
||||||
|
@ -29,12 +29,13 @@ class HParams(hp.BaseHParams):
|
||||||
epochs: Number of training iterations over the dataset.
|
epochs: Number of training iterations over the dataset.
|
||||||
do_fine_tuning: If true, the base module is trained together with the
|
do_fine_tuning: If true, the base module is trained together with the
|
||||||
classification layer on top.
|
classification layer on top.
|
||||||
learning_rate_epoch_boundaries: List of epoch boundaries where
|
cosine_decay_epochs: The number of epochs for cosine decay learning rate.
|
||||||
learning_rate_epoch_boundaries[i] is the epoch where the learning rate
|
See
|
||||||
will decay to learning_rate * learning_rate_decay_multipliers[i].
|
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay
|
||||||
learning_rate_decay_multipliers: List of learning rate multipliers which
|
for more info.
|
||||||
calculates the learning rate at the ith boundary as learning_rate *
|
cosine_decay_alpha: The alpha value for cosine decay learning rate. See
|
||||||
learning_rate_decay_multipliers[i].
|
https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay
|
||||||
|
for more info.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters from BaseHParams class.
|
# Parameters from BaseHParams class.
|
||||||
|
@ -42,41 +43,9 @@ class HParams(hp.BaseHParams):
|
||||||
batch_size: int = 32
|
batch_size: int = 32
|
||||||
epochs: int = 10
|
epochs: int = 10
|
||||||
|
|
||||||
# Parameters for learning rate decay
|
# Parameters for cosine learning rate decay
|
||||||
learning_rate_epoch_boundaries: List[int] = dataclasses.field(
|
cosine_decay_epochs: Optional[int] = None
|
||||||
default_factory=lambda: []
|
cosine_decay_alpha: float = 0.0
|
||||||
)
|
|
||||||
learning_rate_decay_multipliers: List[float] = dataclasses.field(
|
|
||||||
default_factory=lambda: []
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# Validate stepwise learning rate parameters
|
|
||||||
lr_boundary_len = len(self.learning_rate_epoch_boundaries)
|
|
||||||
lr_decay_multipliers_len = len(self.learning_rate_decay_multipliers)
|
|
||||||
if lr_boundary_len != lr_decay_multipliers_len:
|
|
||||||
raise ValueError(
|
|
||||||
"Length of learning_rate_epoch_boundaries and ",
|
|
||||||
"learning_rate_decay_multipliers do not match: ",
|
|
||||||
f"{lr_boundary_len}!={lr_decay_multipliers_len}",
|
|
||||||
)
|
|
||||||
# Validate learning_rate_epoch_boundaries
|
|
||||||
if (
|
|
||||||
sorted(self.learning_rate_epoch_boundaries)
|
|
||||||
!= self.learning_rate_epoch_boundaries
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"learning_rate_epoch_boundaries is not in ascending order: ",
|
|
||||||
self.learning_rate_epoch_boundaries,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
self.learning_rate_epoch_boundaries
|
|
||||||
and self.learning_rate_epoch_boundaries[-1] > self.epochs
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Values in learning_rate_epoch_boundaries cannot be greater ",
|
|
||||||
"than epochs",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|
|
@ -354,19 +354,16 @@ class ObjectDetector(classifier.Classifier):
|
||||||
A tf.keras.optimizer.Optimizer for model training.
|
A tf.keras.optimizer.Optimizer for model training.
|
||||||
"""
|
"""
|
||||||
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256
|
||||||
if self._hparams.learning_rate_epoch_boundaries:
|
decay_epochs = (
|
||||||
lr_values = [init_lr] + [
|
self._hparams.cosine_decay_epochs
|
||||||
init_lr * m for m in self._hparams.learning_rate_decay_multipliers
|
if self._hparams.cosine_decay_epochs
|
||||||
]
|
else self._hparams.epochs
|
||||||
lr_step_boundaries = [
|
)
|
||||||
steps_per_epoch * epoch_boundary
|
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
|
||||||
for epoch_boundary in self._hparams.learning_rate_epoch_boundaries
|
init_lr,
|
||||||
]
|
steps_per_epoch * decay_epochs,
|
||||||
learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
|
self._hparams.cosine_decay_alpha,
|
||||||
lr_step_boundaries, lr_values
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
learning_rate = init_lr
|
|
||||||
return tf.keras.optimizers.experimental.SGD(
|
return tf.keras.optimizers.experimental.SGD(
|
||||||
learning_rate=learning_rate, momentum=0.9
|
learning_rate=learning_rate, momentum=0.9
|
||||||
)
|
)
|
||||||
|
|
|
@ -581,7 +581,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
|
||||||
// Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and
|
// Step 2.5: For SOFTMAX, apply softmax shaders (max, transformAndSum, and
|
||||||
// normalization) to create softmax-transformed chunks before channel
|
// normalization) to create softmax-transformed chunks before channel
|
||||||
// extraction.
|
// extraction.
|
||||||
// NOTE - / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So
|
// NOTE: exp(x-C) / sum_over_x(exp(x-C)) = exp(x) / sum_over_x(exp(x)). So
|
||||||
// theoretically we can skip the max shader step entirely. However,
|
// theoretically we can skip the max shader step entirely. However,
|
||||||
// applying it does bring all our values into a nice (0, 1] range, so it
|
// applying it does bring all our values into a nice (0, 1] range, so it
|
||||||
// will likely be better for precision, especially when dealing with an
|
// will likely be better for precision, especially when dealing with an
|
||||||
|
|
|
@ -200,7 +200,7 @@ class BoxTracker {
|
||||||
|
|
||||||
// Cancels all ongoing tracks. To avoid race conditions all NewBoxTrack's in
|
// Cancels all ongoing tracks. To avoid race conditions all NewBoxTrack's in
|
||||||
// flight will also be canceled. Future NewBoxTrack's will be canceled.
|
// flight will also be canceled. Future NewBoxTrack's will be canceled.
|
||||||
// NOTE - before
|
// NOTE: To resume execution, you have to call ResumeTracking() before
|
||||||
// issuing more NewBoxTrack calls.
|
// issuing more NewBoxTrack calls.
|
||||||
void CancelAllOngoingTracks() ABSL_LOCKS_EXCLUDED(status_mutex_);
|
void CancelAllOngoingTracks() ABSL_LOCKS_EXCLUDED(status_mutex_);
|
||||||
void ResumeTracking() ABSL_LOCKS_EXCLUDED(status_mutex_);
|
void ResumeTracking() ABSL_LOCKS_EXCLUDED(status_mutex_);
|
||||||
|
@ -208,7 +208,7 @@ class BoxTracker {
|
||||||
// Waits for all ongoing tracks to complete.
|
// Waits for all ongoing tracks to complete.
|
||||||
// Optionally accepts a timeout in microseconds (== 0 for infinite wait).
|
// Optionally accepts a timeout in microseconds (== 0 for infinite wait).
|
||||||
// Returns true on success, false if timeout is reached.
|
// Returns true on success, false if timeout is reached.
|
||||||
// NOTE - must
|
// NOTE: If WaitForAllOngoingTracks timed out, CancelAllOngoingTracks() must
|
||||||
// be called before destructing the BoxTracker object or dangling running
|
// be called before destructing the BoxTracker object or dangling running
|
||||||
// threads might try to access invalid data.
|
// threads might try to access invalid data.
|
||||||
bool WaitForAllOngoingTracks(int timeout_us = 0)
|
bool WaitForAllOngoingTracks(int timeout_us = 0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user