(self, closure=None)
| 72 | super().__init__(params, defaults) |
| 73 | |
def step(self, closure=None):
    """Apply one update to every parameter that currently has a gradient.

    For each such parameter the method keeps two running averages in
    ``self.state`` (of the gradient and of the squared gradient), divides
    each by its bias-correction term, forms the ratio
    ``m_hat / (sqrt(v_hat) + eps)``, adds ``weight_decay * param`` to it,
    and then scales the learning rate by ``||param|| / ||update||`` before
    subtracting the update from the parameter in place.

    Args:
        closure: Optional callable. When given, it is invoked once before
            any parameter is touched and its return value is returned.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure was
        supplied.
    """
    loss = closure() if closure is not None else None

    for group in self.param_groups:
        # Parameters without a gradient are skipped entirely: no state is
        # created for them and their values are left alone.
        for param in (q for q in group['params'] if q.grad is not None):
            gradient = param.grad.data
            param_state = self.state[param]

            # First visit for this parameter: create the step counter and
            # zero-filled moment buffers shaped like the parameter.
            if not param_state:
                param_state['step'] = 0
                param_state['exp_avg'] = torch.zeros_like(param.data)
                param_state['exp_avg_sq'] = torch.zeros_like(param.data)

            first_moment = param_state['exp_avg']
            second_moment = param_state['exp_avg_sq']
            beta1, beta2 = group['betas']
            param_state['step'] += 1
            step_count = param_state['step']

            # Exponential moving averages, updated in place so the tensors
            # stored in the state dict carry over to the next call.
            first_moment.mul_(beta1).add_(gradient, alpha=1 - beta1)
            second_moment.mul_(beta2).addcmul_(
                gradient, gradient, value=1 - beta2
            )

            # Bias-corrected moments are fresh tensors; the stored buffers
            # themselves stay uncorrected.
            corrected_first = first_moment / (1 - beta1 ** step_count)
            corrected_second = second_moment / (1 - beta2 ** step_count)

            update = corrected_first / (corrected_second.sqrt() + group['eps'])
            # Weight decay is folded into the update before its norm is
            # measured, so it also influences the scaling factor below.
            update.add_(param.data, alpha=group['weight_decay'])

            weight_norm = param.data.norm()
            update_norm = update.norm()

            # Fall back to an unscaled step unless both norms are strictly
            # positive (this also avoids dividing by a zero update norm).
            trust_ratio = (
                weight_norm / update_norm
                if weight_norm > 0 and update_norm > 0
                else 1.0
            )

            param.data.add_(update, alpha=-group['lr'] * trust_ratio)

    return loss
| 119 | |
| 120 | class RAdam(Optimizer): |
| 121 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): |
no test coverage detected