Implements BERT version of Adam algorithm with weight decay fix. Params: lr: learning rate; warmup: portion of t_total for the warmup, -1 means no warmup (default: -1); t_total: total number of training steps for the learning rate schedule, -1 means constant learning rate (default: -1).
| 61 | |
| 62 | |
class BertAdam(Optimizer):
    """Implements the BERT version of the Adam algorithm with the weight decay fix.

    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate. Default: -1
        schedule: name of the warmup schedule to use; must be a key of
            ``SCHEDULES``. Default: 'warmup_linear'
        b1: Adam's b1. Default: 0.9
        b2: Adam's b2. Default: 0.999
        e: Adam's epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """

    def __init__(self, params, lr=required, warmup=-1, t_total=-1,
                 schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
                 max_grad_norm=1.0):
        """Validate the hyper-parameters and register them as group defaults.

        Raises:
            ValueError: if ``lr`` is negative, ``schedule`` is not a key of
                ``SCHEDULES``, ``warmup`` is outside ``[0.0, 1.0)`` and not -1,
                ``t_total`` is neither -1 nor positive, ``b1``/``b2`` are
                outside ``[0.0, 1.0)``, or ``e``/``weight_decay`` are negative.
        """
        if lr is not required and lr < 0.0:
            raise ValueError(
                "Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError(
                "Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(
                    warmup))
        # get_lr() computes ``step / t_total`` whenever t_total != -1, so 0
        # would raise ZeroDivisionError there, and a negative value other
        # than the -1 sentinel would make that progress ratio negative.
        if not (t_total == -1 or t_total > 0):
            raise ValueError(
                "Invalid t_total: {} - should be > 0 or -1".format(t_total))
        if not 0.0 <= b1 < 1.0:
            raise ValueError(
                "Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError(
                "Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError(
                "Invalid epsilon value: {} - should be >= 0.0".format(e))
        # Written as ``not x >= 0.0`` (like the epsilon check) so NaN is
        # rejected as well.
        if not weight_decay >= 0.0:
            raise ValueError(
                "Invalid weight_decay value: {} - should be >= 0.0".format(
                    weight_decay))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup,
                        t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        """Return the current (scheduled) learning rate of each parameter.

        Returns:
            A list with one entry per parameter that has optimizer state, in
            ``param_groups`` order. Each entry is ``group['lr']`` if the
            group's ``t_total`` is -1; otherwise it is ``group['lr']`` scaled
            by the group's schedule function, evaluated at
            ``state['step'] / t_total``. Returns ``[0]`` when no parameter
            has any optimizer state yet.
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Skip parameters with empty state instead of short-circuiting
                # to [0]; otherwise a single parameter that never gets state
                # (presumably one without a gradient -- confirm against step())
                # would hide the learning rate of every other parameter.
                if not state:
                    continue
                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(
                        state['step'] / group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']
                lr.append(lr_scheduled)
        # No parameter has state yet: keep the historical [0] sentinel.
        return lr or [0]