MCPcopy Index your code
hub / github.com/THUDM/GLM / AnnealingLR

Class AnnealingLR

learning_rates.py:22–92  ·  view source on GitHub ↗

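Schedules the learning rate with an optional linear warmup followed by linear or cosine decay (or a constant rate), annealing toward start_lr × decay_ratio rather than all the way to zero.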

Source from the content-addressed store, hash-verified

20
21
22class AnnealingLR(_LRScheduler):
23 """Anneals the learning rate from start to zero along a cosine curve."""
24
25 DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
26
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5):
    """Set up the schedule and immediately apply the initial learning rate.

    Args:
        optimizer: object whose ``param_groups`` receive the learning rate.
        start_lr: peak learning rate, reached once warmup is over.
        warmup_iter: number of warmup iterations; must not exceed
            ``num_iters``.
        num_iters: decay horizon, stored as ``self.end_iter``.
        decay_style: name of the decay style; strings are lower-cased,
            any non-string value is stored as ``None``.
        last_iter: last completed iteration; the iteration counter
            ``self.num_iters`` starts at ``last_iter + 1``.
        decay_ratio: stored inverted, i.e. ``self.decay_ratio`` holds
            ``1 / decay_ratio``.

    Note:
        The base-class initializer is not invoked here.
    """
    assert warmup_iter <= num_iters

    self.optimizer = optimizer
    self.start_lr = start_lr
    self.warmup_iter = warmup_iter
    # ``self.num_iters`` is the running iteration counter, whereas the
    # ``num_iters`` argument is the horizon kept in ``self.end_iter``.
    self.num_iters = last_iter + 1
    self.end_iter = num_iters

    if isinstance(decay_style, str):
        self.decay_style = decay_style.lower()
    else:
        self.decay_style = None

    # Keep the reciprocal so that get_lr() can divide by it.
    self.decay_ratio = 1 / decay_ratio

    # Push the learning rate for the starting iteration into the optimizer.
    self.step(self.num_iters)

    # Report the configuration once: on rank 0, or when not running distributed.
    should_report = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    if should_report:
        print(f'learning rate decaying style {self.decay_style}, ratio {self.decay_ratio}')
39
def get_lr(self):
    """Return the learning rate for the current iteration ``self.num_iters``.

    During warmup (``warmup_iter > 0`` and ``num_iters <= warmup_iter``) the
    rate ramps linearly from 0 up to ``start_lr``.  Afterwards it depends on
    ``decay_style``:

    * ``'linear'``: straight-line decay from ``start_lr`` towards
      ``start_lr / self.decay_ratio``.
    * ``'cosine'``: half-cosine decay between the same two values.
    * anything else (including the not-yet-implemented ``'exponential'``):
      constant ``start_lr``.

    Decay progress is ``(num_iters - warmup_iter) / end_iter`` and is capped
    at 1.0 for both decaying styles, so the rate never drops below
    ``start_lr / self.decay_ratio`` when stepping past the horizon.

    Returns:
        The learning rate to apply at the current iteration.
    """
    # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
    if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
        return float(self.start_lr) * self.num_iters / self.warmup_iter

    linear_style, cosine_style = self.DECAY_STYLES[0], self.DECAY_STYLES[1]
    if self.decay_style not in (linear_style, cosine_style):
        # TODO: implement exponential decay; until then every other style
        # keeps the rate constant (and never divides by end_iter).
        return self.start_lr

    # Cap progress at 1.0 for linear as well as cosine: without the cap the
    # linear rate keeps shrinking past the horizon and can turn negative.
    decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter)

    if self.decay_style == linear_style:
        return self.start_lr - self.start_lr * (1 - 1 / self.decay_ratio) * decay_step_ratio

    return self.start_lr / self.decay_ratio * (
        (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
57
def step(self, step_num=None):
    """Move the schedule to ``step_num`` and update the optimizer.

    Args:
        step_num: iteration to jump to; when ``None`` the iteration counter
            is advanced by one.
    """
    self.num_iters = self.num_iters + 1 if step_num is None else step_num
    current_lr = self.get_lr()
    # Every parameter group receives the same learning rate.
    for param_group in self.optimizer.param_groups:
        param_group['lr'] = current_lr
65
def state_dict(self):
    """Return the scheduler fields to checkpoint as a plain dict.

    ``start_lr`` is deliberately left out of the saved state.
    """
    persisted_fields = ('warmup_iter', 'num_iters', 'decay_style', 'end_iter', 'decay_ratio')
    return {field: getattr(self, field) for field in persisted_fields}
76
77 def load_state_dict(self, sd):
78 # self.start_lr = sd['start_lr']
79 self.warmup_iter = sd['warmup_iter']

Callers 2

mainFunction · 0.90

Calls

no outgoing calls

Tested by 1

mainFunction · 0.72