| 13 | |
| 14 | |
class SimpleModel(torch.nn.Module):
    """Four chained square ``Linear`` layers scored with ``CrossEntropyLoss``.

    Args:
        hidden_dim: Input and output width of every linear layer. The output of
            the last layer is passed straight to ``CrossEntropyLoss``. For a 2-D
            ``x`` of shape (batch, hidden_dim), those activations therefore act
            as logits over ``hidden_dim`` classes.
        empty_grad: When true, also register ``layers2``, a ``ModuleList``
            holding one extra linear layer. ``forward`` never uses it, so its
            parameters never receive a gradient. Presumably this is meant to
            exercise code paths for parameters without grads (confirm with
            callers).
    """

    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__()

        def square_linear():
            # Every layer maps hidden_dim -> hidden_dim.
            return torch.nn.Linear(hidden_dim, hidden_dim)

        # Construction order matters. Each Linear draws its initial weights
        # from the global RNG, and registration order fixes the order of
        # parameters() and state_dict().
        self.linear = square_linear()
        self.linear2 = square_linear()
        self.linear3 = square_linear()
        self.linear4 = square_linear()
        if empty_grad:
            # Registered, but intentionally left out of forward().
            self.layers2 = torch.nn.ModuleList([square_linear()])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Pass ``x`` through the four layers and return the loss against ``y``."""
        out = x
        for layer in (self.linear, self.linear2, self.linear3, self.linear4):
            out = layer(out)
        return self.cross_entropy_loss(out, y)
| 34 | |
| 35 | |
| 36 | def create_config_from_dict(tmpdir, config_dict): |
no outgoing calls
no test coverage detected
searching dependent graphs…