| 13 | |
| 14 | |
class SimpleModel(torch.nn.Module):
    """Minimal module: one bias-free linear layer applied twice, then a loss.

    ``forward`` feeds its input through ``self.linear`` two times (the same
    weights are reused for both applications) and returns the cross-entropy
    loss of the result against the given targets.
    """

    def __init__(self, hidden_dim: int, empty_grad: bool = False) -> None:
        """Build the layers.

        Args:
            hidden_dim: Input and output feature size of every linear layer.
            empty_grad: When true, additionally register ``layers2``, a
                ``ModuleList`` holding one ``Linear`` that ``forward`` never
                calls, so its parameters get no gradient from this module's
                loss.
        """
        super().__init__()
        # The layer is created exactly once. An earlier revision first built
        # a bias=True layer and immediately overwrote it with this one; the
        # bias-free layer is the one that was actually in effect.
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        if empty_grad:
            # Original note on this line: QuantizeLinear(hidden_dim, hidden_dim)
            self.layers2 = torch.nn.ModuleList(
                [torch.nn.Linear(hidden_dim, hidden_dim)]
            )
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy loss of the twice-projected input.

        Args:
            x: Input handed to ``self.linear``; presumably shaped
                ``(batch, hidden_dim)`` -- TODO confirm against callers.
            y: Targets passed unchanged to ``CrossEntropyLoss``.

        Returns:
            The value produced by ``self.cross_entropy_loss``.
        """
        first_pass = self.linear(x)
        # Same layer on purpose: both applications share one weight matrix.
        second_pass = self.linear(first_pass)
        return self.cross_entropy_loss(second_pass, y)
| 31 | |
| 32 | |
| 33 | def create_config_from_dict(tmpdir, config_dict): |
no outgoing calls
no test coverage detected
searching dependent graphs…