(self, x: torch.Tensor)
| 158 | """Subclass torch's LayerNorm to handle fp16.""" |
| 159 | |
| 160 | def forward(self, x: torch.Tensor): |
| 161 | orig_type = x.dtype |
| 162 | ret = super().forward(x.type(torch.float32)) |
| 163 | return ret.type(orig_type) |
| 164 | |
| 165 | |
| 166 | class QuickGELU(nn.Module): |