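Anneals the learning rate after an optional linear warmup, decaying from start_lr toward start_lr * decay_ratio along a linear or cosine curve (or keeping it constant).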
| 20 | |
| 21 | |
| 22 | class AnnealingLR(_LRScheduler): |
| 23 | """Anneals the learning rate from start to zero along a cosine curve.""" |
| 24 | |
| 25 | DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] |
| 26 | |
| 27 | def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5): |
| 28 | assert warmup_iter <= num_iters |
| 29 | self.optimizer = optimizer |
| 30 | self.start_lr = start_lr |
| 31 | self.warmup_iter = warmup_iter |
| 32 | self.num_iters = last_iter + 1 |
| 33 | self.end_iter = num_iters |
| 34 | self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None |
| 35 | self.decay_ratio = 1 / decay_ratio |
| 36 | self.step(self.num_iters) |
| 37 | if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: |
| 38 | print(f'learning rate decaying style {self.decay_style}, ratio {self.decay_ratio}') |
| 39 | |
| 40 | def get_lr(self): |
| 41 | # https://openreview.net/pdf?id=BJYwwY9ll pg. 4 |
| 42 | if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: |
| 43 | return float(self.start_lr) * self.num_iters / self.warmup_iter |
| 44 | else: |
| 45 | if self.decay_style == self.DECAY_STYLES[0]: |
| 46 | decay_step_ratio = (self.num_iters - self.warmup_iter) / self.end_iter |
| 47 | return self.start_lr - self.start_lr * (1 - 1 / self.decay_ratio) * decay_step_ratio |
| 48 | elif self.decay_style == self.DECAY_STYLES[1]: |
| 49 | decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter) |
| 50 | return self.start_lr / self.decay_ratio * ( |
| 51 | (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1) |
| 52 | elif self.decay_style == self.DECAY_STYLES[2]: |
| 53 | # TODO: implement exponential decay |
| 54 | return self.start_lr |
| 55 | else: |
| 56 | return self.start_lr |
| 57 | |
| 58 | def step(self, step_num=None): |
| 59 | if step_num is None: |
| 60 | step_num = self.num_iters + 1 |
| 61 | self.num_iters = step_num |
| 62 | new_lr = self.get_lr() |
| 63 | for group in self.optimizer.param_groups: |
| 64 | group['lr'] = new_lr |
| 65 | |
| 66 | def state_dict(self): |
| 67 | sd = { |
| 68 | # 'start_lr': self.start_lr, |
| 69 | 'warmup_iter': self.warmup_iter, |
| 70 | 'num_iters': self.num_iters, |
| 71 | 'decay_style': self.decay_style, |
| 72 | 'end_iter': self.end_iter, |
| 73 | 'decay_ratio': self.decay_ratio |
| 74 | } |
| 75 | return sd |
| 76 | |
| 77 | def load_state_dict(self, sd): |
| 78 | # self.start_lr = sd['start_lr'] |
| 79 | self.warmup_iter = sd['warmup_iter'] |
no outgoing calls