| 116 | |
| 117 | |
class MultiStepLRScheduler(Scheduler):
    """Multi-step (step-decay) learning-rate scheduler with optional linear warmup.

    For ``t < warmup_t`` each param group's LR is
    ``warmup_lr_init + t * (base - warmup_lr_init) / warmup_t``, a linear ramp
    from ``warmup_lr_init`` toward the group's base value. From ``t >= warmup_t``
    onward the LR is ``base * gamma ** k``, where ``k`` is the number of
    milestones ``<= t`` (so the decay takes effect exactly at a milestone).

    ``t`` counts epochs when ``t_in_epochs`` is true, and optimizer updates
    otherwise. ``base`` comes from ``self.base_values``, which is provided by the
    ``Scheduler`` base class (presumably the groups' initial ``"lr"`` values —
    confirm against the base class).
    """

    def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
        """Configure the schedule.

        Args:
            optimizer: Optimizer whose param groups' ``"lr"`` field is scheduled.
            milestones: Iterable of step indices at which the LR is multiplied by
                ``gamma``. Any order is accepted. A value repeated ``n`` times
                applies ``gamma`` ``n`` times. May be empty, meaning no decay
                after warmup.
            gamma: Multiplicative decay factor applied at each milestone.
            warmup_t: Length of the linear warmup in steps; 0 disables warmup.
            warmup_lr_init: LR at step 0 of the warmup ramp.
            t_in_epochs: If true, the schedule is driven by epoch indices;
                otherwise by optimizer update counts.

        Raises:
            ValueError: If ``warmup_t`` is greater than the smallest milestone.
        """
        # bisect in _get_lr requires ascending order. Sorting a copy also accepts
        # tuples, generators and unsorted configs, which previously produced
        # silently wrong decay counts.
        sorted_milestones = sorted(milestones)

        # Validate before touching the optimizer, so a bad config does not leave
        # the param groups' LRs rewritten by the warmup initialisation below.
        # An empty milestone list is valid (constant LR after warmup).
        if sorted_milestones and warmup_t > sorted_milestones[0]:
            raise ValueError(
                f"warmup_t ({warmup_t}) must not exceed the first milestone "
                f"({sorted_milestones[0]})"
            )

        super().__init__(optimizer, param_group_field="lr")

        self.milestones = sorted_milestones
        self.gamma = gamma
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-group slope such that warmup_lr_init + warmup_t * slope == base.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            # Start training at the bottom of the ramp rather than at the base LR.
            super().update_groups(self.warmup_lr_init)
        else:
            # Placeholder only: with warmup_t == 0 the warmup branch in _get_lr
            # is never taken for non-negative t.
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of LRs (one per entry of ``self.base_values``) at step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]
        # bisect_right counts milestones m with m <= t, so decay applies at t == m.
        num_decays = bisect.bisect_right(self.milestones, t)
        return [v * (self.gamma ** num_decays) for v in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return per-group LRs for ``epoch`` when scheduling by epoch, else ``None``.

        ``None`` presumably tells the base class to leave LRs unchanged on this
        hook — confirm against ``Scheduler``.
        """
        if self.t_in_epochs:
            return self._get_lr(epoch)
        return None

    def get_update_values(self, num_updates: int):
        """Return per-group LRs after ``num_updates`` optimizer steps when
        scheduling by update, else ``None``."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        return None