| 8 | |
| 9 | |
class Conv1D(nn.Module):
    """Affine projection over the last dimension of the input.

    Despite the name, this is a linear layer: ``y = x @ weight + bias``.
    Unlike ``nn.Linear``, the weight is stored with shape ``(nx, nf)``
    (input features first), so no transpose is needed in ``forward``.

    Args:
        nf: Number of output features.
        nx: Number of input features (size of the input's last dimension).
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        # Small-std normal init for the weight; the bias starts at zero.
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Project the last dimension of ``x`` from ``nx`` to ``nf`` features.

        Args:
            x: Tensor of shape ``(..., nx)``. It may be non-contiguous.

        Returns:
            Tensor of shape ``(..., nf)``.
        """
        size_out = x.shape[:-1] + (self.nf,)
        # reshape (not view) so non-contiguous inputs (e.g. after a transpose)
        # work; it still returns a view whenever the memory layout allows.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        # addmm produces a fresh contiguous tensor, so view is safe here.
        x = x.view(size_out)
        return x
| 24 | |
| 25 | |
| 26 | @clear_cache_before_run() |
no outgoing calls
searching dependent graphs…