| 107 | self.updates_created=False |
| 108 | |
def get_lr(self):
    """Return the current (scheduled) learning rate for each parameter.

    Walks ``self.param_groups`` in order and yields one list entry per
    parameter in each group's ``'params'``.

    If a group's ``'t_total'`` equals ``-1``, the group's base ``'lr'`` is
    reported unchanged. Otherwise the base ``'lr'`` is multiplied by
    ``SCHEDULES[group['schedule']](state['step'] / group['t_total'],
    group['warmup'])``.

    Returns:
        A list of learning rates, one per parameter visited. As soon as a
        parameter with an empty entry in ``self.state`` is encountered, the
        walk stops and ``[0]`` is returned instead.
    """
    rates = []
    for param_group in self.param_groups:
        for param in param_group['params']:
            param_state = self.state[param]
            if not param_state:
                # Empty state (presumably no step() has run for this
                # parameter yet) short-circuits the whole result.
                return [0]
            if param_group['t_total'] == -1:
                # No schedule configured: the base learning rate applies as-is.
                rates.append(param_group['lr'])
                continue
            schedule = SCHEDULES[param_group['schedule']]
            base_lr = param_group['lr']
            progress = param_state['step'] / param_group['t_total']
            rates.append(base_lr * schedule(progress, param_group['warmup']))
    return rates
| 123 | |
| 124 | def apply_gradients(self, dummy_overflow_buf, lr_scheduled, per_param_decay, grad_list, param_list, momentum, velocity, update): |
| 125 | # Compute global gradient norm |