| 64 | |
| 65 | |
class LinearLRScheduler(Scheduler):
    """Linear learning-rate decay with an optional linear warmup.

    For ``t < warmup_t`` the learning rate rises linearly from
    ``warmup_lr_init`` towards each param group's base value ``v``. After
    warmup it decays linearly from ``v`` down to ``v * lr_min_rate``, reaching
    that floor at ``t == t_initial`` and staying there for any later ``t``.

    ``t`` counts epochs when ``t_in_epochs`` is true and optimizer updates
    otherwise (see ``get_epoch_values`` / ``get_update_values``).
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        """Create the scheduler.

        Args:
            optimizer: Optimizer whose param groups' ``"lr"`` field is scheduled.
            t_initial: Step (epoch or update) at which the decay reaches its
                floor; counted from 0, including the warmup steps.
            lr_min_rate: Final lr as a fraction of each group's base lr.
            warmup_t: Number of warmup steps; 0 disables warmup.
            warmup_lr_init: Learning rate at the start of warmup.
            t_in_epochs: If true the schedule advances per epoch, otherwise
                per optimizer update.
            noise_range_t, noise_pct, noise_std, noise_seed, initialize:
                Forwarded unchanged to the base ``Scheduler``; their exact
                semantics are defined there (not visible in this file).
        """
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-step increment so that warmup ends at each group's base value.
            # NOTE(review): self.base_values is provided by the base Scheduler,
            # presumably the initial per-group lrs — confirm there.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            # Start the optimizer at the warmup lr rather than the base lr.
            super().update_groups(self.warmup_lr_init)
        else:
            # Not read when warmup is disabled; kept so the attribute always exists.
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of per-param-group learning rates for step ``t``."""
        if t < self.warmup_t:
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]

        decay_t = self.t_initial - self.warmup_t
        if decay_t <= 0:
            # No decay window (t_initial <= warmup_t): the decay is already
            # over, so go straight to the floor instead of dividing by zero
            # (or, for a negative window, "decaying" upwards).
            progress = 1.0
        else:
            # Clamp so the lr holds at v * lr_min_rate after t_initial instead
            # of continuing down the line (and eventually going negative).
            progress = min((t - self.warmup_t) / decay_t, 1.0)
        return [v - ((v - v * self.lr_min_rate) * progress) for v in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return per-group lrs for ``epoch`` if scheduling per epoch, else None."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        """Return per-group lrs for ``num_updates`` if scheduling per update, else None."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
| 116 | |
| 117 | |
| 118 | class MultiStepLRScheduler(Scheduler): |