| 39 | class SimpleFrozenModel(torch.nn.Module): |
| 40 | |
| 41 | def __init__(self, hidden_dim, empty_grad=False): |
| 42 | super(SimpleFrozenModel, self).__init__() |
| 43 | self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(2)]) |
| 44 | if empty_grad: |
| 45 | self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) |
| 46 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() |
| 47 | self.empty_grad = empty_grad |
| 48 | # Freeze first layer |
| 49 | self.linears[0].weight.requires_grad = False |
| 50 | self.linears[0].bias.requires_grad = False |
| 51 | |
| 52 | def custom_state_dict(self, *args, **kwargs): |
| 53 | state_dict = super(SimpleFrozenModel, self).state_dict(*args, **kwargs) |