| 304 | |
| 305 | |
class GPTOutput(nn.Module):
    """Output head: apply ``Norm`` to the input, then a ``Linear`` projection.

    Args:
        dim: Value passed to ``Norm`` and as the first argument of ``Linear``;
            also used to derive the weight-initialization scale.
        vocab_size: Second argument of ``Linear`` (presumably the number of
            output features -- confirm against ``Linear``'s signature).
    """

    def __init__(self, dim, vocab_size):
        super().__init__()
        # ``norm`` is created before ``output``; keep this order so that
        # submodule registration (and any init-time RNG use) stays stable.
        self.norm = Norm(dim)
        self.output = Linear(dim, vocab_size)
        self.dim = dim

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``output.weight`` from a truncated normal distribution.

        The standard deviation is ``1 / sqrt(dim**2)`` (i.e. ``1 / |dim|``)
        and samples are truncated to three standard deviations either side
        of zero. Only ``self.output.weight`` is touched here.
        """
        dim_squared = self.dim**2
        sigma = 1.0 / math.sqrt(dim_squared)
        cutoff = 3 * sigma
        nn.init.trunc_normal_(
            self.output.weight,
            std=sigma,
            a=-cutoff,
            b=cutoff,
        )

    def forward(self, x):
        """Return ``self.output(self.norm(x))``."""
        normalized = self.norm(x)
        return self.output(normalized)
| 321 | |
| 322 | @si_module |
| 323 | class Stack(nn.Module): |