MCPcopy
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / BertAdam

Class BertAdam

model/optimizer.py:63–219  ·  view source on GitHub ↗

Implements the BERT version of the Adam algorithm with the weight decay fix. Params: lr: learning rate; warmup: portion of t_total for the warmup, -1 means no warmup (default: -1); t_total: total number of training steps for the learning rate schedule, -1 means constant learning rate (default: -1).

Source from the content-addressed store, hash-verified

61
62
63class BertAdam(Optimizer):
64 """Implements BERT version of Adam algorithm with weight decay fix.
65 Params:
66 lr: learning rate
67 warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
68 t_total: total number of training steps for the learning
69 rate schedule, -1 means constant learning rate. Default: -1
70 schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
71 b1: Adams b1. Default: 0.9
72 b2: Adams b2. Default: 0.999
73 e: Adams epsilon. Default: 1e-6
74 weight_decay: Weight decay. Default: 0.01
75 max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
76 """
77
def __init__(self, params, lr=required, warmup=-1, t_total=-1,
             schedule='warmup_linear',
             b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
             max_grad_norm=1.0):
    """Validate the hyper-parameters and hand them to the base optimizer.

    Args:
        params: forwarded unchanged to the parent ``__init__``.
        lr: learning rate; must be >= 0.0 unless left at the ``required``
            sentinel, in which case the range check is skipped here.
        warmup: portion of ``t_total`` used for warmup, in [0.0, 1.0[,
            or -1 for no warmup.
        t_total: total number of training steps for the schedule, -1 for
            a constant learning rate. Not validated here.
        schedule: key into ``SCHEDULES``.
        b1: Adam beta1, in [0.0, 1.0[.
        b2: Adam beta2, in [0.0, 1.0[.
        e: Adam epsilon, must be >= 0.0.
        weight_decay: weight decay. Not validated here.
        max_grad_norm: maximum gradient norm (-1 means no clipping per
            the class docstring). Not validated here.

    Raises:
        ValueError: if ``lr``, ``schedule``, ``warmup``, ``b1``, ``b2`` or
            ``e`` fails its check; checks run in that order, so the first
            offending argument is the one reported.
    """
    lr_supplied = lr is not required
    if lr_supplied and lr < 0.0:
        raise ValueError(
            "Invalid learning rate: {} - should be >= 0.0".format(lr))

    if schedule not in SCHEDULES:
        raise ValueError("Invalid schedule parameter: {}".format(schedule))

    # -1 is the "no warmup" marker and is accepted outside the range.
    if not (0.0 <= warmup < 1.0 or warmup == -1):
        raise ValueError(
            "Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(
                warmup))

    b1_in_range = 0.0 <= b1 < 1.0
    if not b1_in_range:
        raise ValueError(
            "Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))

    b2_in_range = 0.0 <= b2 < 1.0
    if not b2_in_range:
        raise ValueError(
            "Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))

    # Written as `not (e >= 0.0)` so a NaN epsilon is rejected as well.
    if not (e >= 0.0):
        raise ValueError(
            "Invalid epsilon value: {} - should be >= 0.0".format(e))

    defaults = {
        'lr': lr,
        'schedule': schedule,
        'warmup': warmup,
        't_total': t_total,
        'b1': b1,
        'b2': b2,
        'e': e,
        'weight_decay': weight_decay,
        'max_grad_norm': max_grad_norm,
    }
    super(BertAdam, self).__init__(params, defaults)
105
def get_lr(self):
    """Return the current learning rate of every parameter.

    Walks each group in ``self.param_groups`` and each parameter in the
    group's ``'params'`` list, appending one rate per parameter.

    * If a parameter's entry in ``self.state`` is empty, the walk stops
      immediately and ``[0]`` is returned, discarding any rates already
      collected.
    * If the group's ``'t_total'`` is -1 the rate is the group's ``'lr'``.
    * Otherwise the rate is ``'lr'`` multiplied by the group's schedule
      function from ``SCHEDULES``, called with ``step / t_total`` and the
      group's ``'warmup'``.

    Returns:
        list: one rate per parameter, or ``[0]`` on the first parameter
        with empty state.
    """
    rates = []
    for group in self.param_groups:
        for param in group['params']:
            param_state = self.state[param]
            if not param_state:
                # Empty state for this parameter: report a single zero.
                return [0]

            if group['t_total'] == -1:
                # No schedule configured: constant learning rate.
                rates.append(group['lr'])
                continue

            schedule_fn = SCHEDULES[group['schedule']]
            progress = param_state['step'] / group['t_total']
            rates.append(
                group['lr'] * schedule_fn(progress, group['warmup']))
    return rates

Callers 1

get_optimizerFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected