r"""Fused Linear and activation Module.
# Activation callables selectable by name; a string ``act`` argument is
# resolved through this table.
ACT2FN = {
    "gelu": gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}
| 150 | |
class LinearActivation(Module):
    r"""Fused Linear and activation Module.

    Computes ``act(F.linear(input, weight, bias))``. When a bias is present
    and ``act`` is the string ``'gelu'`` or ``'tanh'``, the bias addition and
    the activation are delegated together to ``bias_gelu`` / ``bias_tanh``
    instead of being applied as two separate steps.

    Args:
        in_features: number of columns of ``weight``.
        out_features: number of rows of ``weight`` (and length of ``bias``).
        act: either an activation name (``'gelu'``, ``'tanh'`` or a key of
            ``ACT2FN``) or a callable applied to the linear output.
        bias: if ``False``, no bias parameter is created and the fused
            paths are never used.

    Raises:
        KeyError: if ``act`` is a string that is neither ``'gelu'``/``'tanh'``
            nor a key of ``ACT2FN``.
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, act='gelu', bias=True):
        super(LinearActivation, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # At most one of these flags is set; forward() checks them before
        # falling back to self.act_fn.
        self.fused_gelu = False
        self.fused_tanh = False
        # `unicode` is only evaluated on Python 2 thanks to short-circuiting.
        if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)):
            if bias and act == 'gelu':
                self.fused_gelu = True
            elif bias and act == 'tanh':
                self.fused_tanh = True
            elif act == 'tanh':
                # Without a bias there is nothing to fuse, and the ACT2FN
                # table defined in this file has no 'tanh' entry, so the
                # lookup below would raise KeyError. Use plain tanh instead.
                self.act_fn = torch.tanh
            else:
                self.act_fn = ACT2FN[act]
        else:
            # Anything that is not a string is used as the activation as-is.
            self.act_fn = act
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``weight`` (Kaiming uniform) and ``bias`` (uniform)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against in_features == 0, where 1 / sqrt(fan_in) would
            # raise ZeroDivisionError.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the linear map followed by the configured activation."""
        if self.fused_gelu:
            # The bias is handed to the fused kernel, so the linear is bias-free.
            return bias_gelu(self.bias, F.linear(input, self.weight, None))
        elif self.fused_tanh:
            return bias_tanh(self.bias, F.linear(input, self.weight, None))
        else:
            return self.act_fn(F.linear(input, self.weight, self.bias))

    def extra_repr(self):
        """Return the feature sizes and bias flag shown in the module repr."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
| 197 | |
| 198 | |
| 199 | class BertConfig(object): |