| 157 | class LinearStack(torch.nn.Module): |
| 158 | |
| 159 | def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4): |
| 160 | super().__init__() |
| 161 | self.input_dim = input_dim |
| 162 | self.output_dim = output_dim |
| 163 | self.hidden_dim = hidden_dim |
| 164 | |
| 165 | self.input_layer = torch.nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim) |
| 166 | self.layers = torch.nn.ModuleList([ |
| 167 | torch.nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim, bias=False) |
| 168 | for x in range(num_layers) |
| 169 | ]) |
| 170 | self.output_layer = torch.nn.Linear(in_features=self.hidden_dim, out_features=self.output_dim) |
| 171 | |
| 172 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() |
| 173 | |
| 174 | def forward(self, x, y): |
| 175 | x = self.input_layer(x) |