(self, x)
| 17 | self.bias = nn.Parameter(torch.zeros(nf)) |
| 18 | |
def forward(self, x):
    """Apply the affine map ``x @ self.weight + self.bias`` over the last dim.

    All leading dimensions of ``x`` are flattened into a single batch
    dimension so that one fused ``torch.addmm`` call can be used. The leading
    dimensions are then restored on the result.

    Args:
        x: Input tensor whose last dimension matches the first dimension of
            ``self.weight``. It may be non-contiguous.

    Returns:
        Tensor with the same leading dimensions as ``x`` and a last dimension
        of size ``self.nf``.
    """
    size_out = x.shape[:-1] + (self.nf,)
    # ``reshape`` (not ``view``) so non-contiguous inputs such as transposed or
    # permuted tensors are accepted; it copies only when a view is impossible.
    x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
    # addmm returns a fresh contiguous tensor, so ``view`` is always valid here.
    return x.view(size_out)
| 24 | |
| 25 | |
| 26 | @clear_cache_before_run() |