(self, hidden_dim, empty_grad=False, nlayers=1)
| 20 | class SimpleModel(torch.nn.Module): |
| 21 | |
def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
    """Set up a stack of square linear layers and a cross-entropy loss.

    Args:
        hidden_dim: Input and output feature size of every ``Linear`` layer
            created here.
        empty_grad: When true, an additional ``linear2`` layer is registered
            on the module. The flag itself is also stored on the instance.
            NOTE(review): presumably ``linear2`` is meant to stay unused in
            the forward pass so it never receives a gradient -- confirm
            against ``forward``.
        nlayers: Number of ``Linear`` layers placed in ``self.linears``.
    """
    super().__init__()
    self.empty_grad = empty_grad

    # The stack is created before the optional extra layer so that parameter
    # registration order and random-initialisation order stay stable.
    stack = []
    for _ in range(nlayers):
        stack.append(torch.nn.Linear(hidden_dim, hidden_dim))
    self.linears = torch.nn.ModuleList(stack)

    if empty_grad:
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)

    self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
| 29 | |
| 30 | def forward(self, x, y): |
| 31 | if len(self.linears) == 1: |