MCPcopy
hub / github.com/microsoft/Swin-Transformer / MultiStepLRScheduler

Class MultiStepLRScheduler

lr_scheduler.py:118–152  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

116
117
class MultiStepLRScheduler(Scheduler):
    """Step-decay learning-rate scheduler with optional linear warmup.

    For ``t < warmup_t`` the LR of each param group grows linearly from
    ``warmup_lr_init`` by a per-group increment, reaching that group's base
    value at ``t == warmup_t``. From then on the base value is multiplied by
    ``gamma`` once for every milestone ``m`` with ``m <= t``.

    Base values come from ``self.base_values``, which is set by the
    ``Scheduler`` base class (presumably each group's initial LR; confirm
    against the base class).

    Args:
        optimizer: Optimizer whose param groups' ``"lr"`` field is scheduled.
        milestones: Steps/epochs at which the LR is multiplied by ``gamma``.
            Any iterable of numbers. It is sorted internally because the
            lookup uses ``bisect``. It may be empty, in which case no decay
            is applied after warmup.
        gamma: Multiplicative decay factor applied at each milestone.
        warmup_t: Length of the linear warmup, in the same unit as ``t``.
            ``0`` (or any falsy value) disables warmup.
        warmup_lr_init: LR used at ``t == 0`` when warmup is enabled. It is
            also written to the optimizer immediately via ``update_groups``.
        t_in_epochs: If True, ``get_epoch_values`` computes the schedule and
            ``get_update_values`` returns ``None``. If False, the reverse.

    Raises:
        ValueError: If ``warmup_t`` is greater than the smallest milestone.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1,
                 warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
        # bisect_right in _get_lr requires ascending order; sorting also
        # accepts tuples/generators and is a no-op for already-sorted input.
        milestones = sorted(milestones)
        # Validate before touching the optimizer, so a bad config leaves it
        # unmodified. This is an explicit raise (not ``assert``) so the check
        # survives ``python -O``. An empty milestone list is allowed.
        if milestones and warmup_t > milestones[0]:
            raise ValueError(
                f"warmup_t ({warmup_t}) must not exceed the first milestone "
                f"({milestones[0]})"
            )

        super().__init__(optimizer, param_group_field="lr")

        self.milestones = milestones
        self.gamma = gamma
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-group increment chosen so the LR hits the base value at t == warmup_t.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            # Placeholder: with warmup_t == 0 the warmup branch of _get_lr is
            # never taken for non-negative t.
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of per-param-group LRs for step/epoch ``t``."""
        if t < self.warmup_t:
            # Linear warmup starting at warmup_lr_init.
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]
        # bisect_right counts milestones m with m <= t, i.e. decays already applied.
        decay = self.gamma ** bisect.bisect_right(self.milestones, t)
        return [v * decay for v in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return per-group LRs for ``epoch`` if scheduling per epoch, else ``None``."""
        return self._get_lr(epoch) if self.t_in_epochs else None

    def get_update_values(self, num_updates: int):
        """Return per-group LRs for ``num_updates`` if scheduling per update, else ``None``."""
        return self._get_lr(num_updates) if not self.t_in_epochs else None

Callers 1

build_scheduler (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected