class LinearStackPipe(PipelineModule):
    """Pipeline module built from a stack of ``torch.nn.Linear`` layers.

    The layer list passed to ``PipelineModule`` is:

    1. ``Linear(input_dim, hidden_dim)``
    2. ``num_layers`` repetitions of ``Linear(hidden_dim, hidden_dim, bias=False)``
       followed by an identity callable
    3. ``Linear(hidden_dim, output_dim)``

    Args:
        input_dim: Input feature size of the first linear layer.
        hidden_dim: Feature size used by the hidden linear layers.
        output_dim: Output feature size of the final linear layer.
        num_layers: Number of (hidden linear, identity) pairs in the middle.
        **kwargs: Forwarded to ``PipelineModule.__init__``. If ``loss_fn`` is
            among them it is used as-is; otherwise ``torch.nn.CrossEntropyLoss()``
            is supplied as the loss.
    """

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4, **kwargs):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Input projection.
        layers = [LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim)]
        for _ in range(self.num_layers):
            layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.hidden_dim, bias=False))
            # Identity layer given as a plain callable rather than a LayerSpec;
            # it holds no parameters.
            layers.append(lambda x: x)
        # Output projection.
        layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        # Supply the default loss only when the caller did not pass one; passing
        # loss_fn explicitly alongside **kwargs would otherwise be a duplicate
        # keyword argument (TypeError).
        if "loss_fn" not in kwargs:
            kwargs["loss_fn"] = torch.nn.CrossEntropyLoss()
        super().__init__(layers=layers, **kwargs)
| 198 | |
| 199 | |
| 200 | class SimpleOptimizer(torch.optim.Optimizer): |