(self, x, y)
| 28 | self.empty_grad = empty_grad |
| 29 | |
def forward(self, x, y):
    """Run ``x`` through ``self.linears`` and return the loss against ``y``.

    Behavior depends on the number of layers:

    * Exactly one layer: ``x`` is passed through it once.
    * Otherwise: for each layer at position ``idx``, the new activation is the
      sum of two outputs, both applied to the current activation. One comes
      from the layer at position ``idx // 2``, the other from the layer itself.
      Because ``idx // 2 <= idx``, earlier layers are invoked more than once.
    * No layers at all: the loop body never runs, so ``x`` reaches the loss
      unchanged.

    Returns whatever ``self.cross_entropy_loss(x, y)`` returns.
    """
    layers = self.linears
    if len(layers) != 1:
        for idx, layer in enumerate(layers):
            # Pair each layer with the one at half its index. This re-uses
            # earlier layers' parameters (presumably deliberate; confirm intent).
            shared = layers[idx // 2]
            x = shared(x) + layer(x)
    else:
        # A lone layer is applied once. Going through the loop above would
        # apply it twice and sum the results.
        x = layers[0](x)
    return self.cross_entropy_loss(x, y)
| 37 | |
| 38 | |
| 39 | class SimpleFrozenModel(torch.nn.Module): |