hub / github.com/Standard-Intelligence/hertz-dev / GPTOutput

Class GPTOutput

transformer.py:306–320

Source from the content-addressed store, hash-verified

class GPTOutput(nn.Module):
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.dim = dim
        self.norm = Norm(dim)
        self.output = Linear(dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim**2)
        nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x):
        return self.output(self.norm(x))
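The initialisation in reset_parameters can be checked by hand: for positive dim, math.sqrt(self.dim**2) is simply dim, so the standard deviation works out to 1 / dim and the truncation bounds to ±3 / dim. The sketch below illustrates that arithmetic using only the standard library. The helper names output_init_std and sample_trunc_normal are hypothetical and do not appear in hertz-dev, and the rejection sampler is only a stand-in for nn.init.trunc_normal_, not a description of how PyTorch implements it.

```python
import math
import random


def output_init_std(dim: int) -> float:
    """Standard deviation computed the same way as in GPTOutput.reset_parameters."""
    # Same expression as the source; sqrt(dim**2) equals dim for positive dim.
    return 1.0 / math.sqrt(dim**2)


def sample_trunc_normal(std: float, a: float, b: float, rng: random.Random) -> float:
    """Draw from N(0, std^2) restricted to [a, b] by rejection (illustrative stand-in)."""
    while True:
        value = rng.gauss(0.0, std)
        if a <= value <= b:
            return value


dim = 512  # hypothetical model width, chosen only for the example
std = output_init_std(dim)
rng = random.Random(0)

# Mimic initialising 10,000 output weights with bounds at three standard deviations.
weights = [sample_trunc_normal(std, -3 * std, 3 * std, rng) for _ in range(10_000)]

print(math.isclose(std, 1.0 / dim))              # std reduces to 1 / dim
print(max(abs(w) for w in weights) <= 3 * std)   # every sample respects the bounds
```

Under these assumptions, a wider model gets proportionally smaller initial output weights, since std shrinks as 1 / dim rather than 1 / sqrt(dim).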

Callers 1

__init__ · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected