| 194 | |
| 195 | |
class VAEApprox(torch.nn.Module):
    """Eight-layer convolutional network mapping a 4-channel tensor to 3 channels.

    Judging by the name, this is a cheap approximation of a VAE decoder
    (latent -> RGB preview); NOTE(review): confirm against callers.

    ``forward`` expects a 4-D tensor ``(N, 4, H, W)`` and returns a tensor of
    shape ``(N, 3, 2*H, 2*W)``.
    """

    def __init__(self):
        super().__init__()
        # The attribute names conv1..conv8 become the state_dict keys
        # ("conv1.weight", "conv1.bias", ...); keep them stable so any saved
        # weights still load.
        self.conv1 = torch.nn.Conv2d(4, 8, kernel_size=7)
        self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=5)
        self.conv3 = torch.nn.Conv2d(16, 32, kernel_size=3)
        self.conv4 = torch.nn.Conv2d(32, 64, kernel_size=3)
        self.conv5 = torch.nn.Conv2d(64, 32, kernel_size=3)
        self.conv6 = torch.nn.Conv2d(32, 16, kernel_size=3)
        self.conv7 = torch.nn.Conv2d(16, 8, kernel_size=3)
        self.conv8 = torch.nn.Conv2d(8, 3, kernel_size=3)
        # Not read anywhere in this class. NOTE(review): presumably set by
        # callers (e.g. to remember the dtype the weights were cast to) — confirm.
        self.current_type = None

    def forward(self, x):
        """Upsample ``x`` 2x, pad it, and run it through all eight conv layers.

        Args:
            x: 4-D tensor ``(N, 4, H, W)``. The two-element size passed to
                ``interpolate`` requires exactly two spatial dimensions, and
                ``conv1`` expects 4 input channels.

        Returns:
            Tensor of shape ``(N, 3, 2*H, 2*W)``.
        """
        # None of the convolutions pad, so together they trim
        # (7-1) + (5-1) + 6 * (3-1) = 22 pixels per spatial dim. Padding 11 on
        # every side therefore keeps the output at the upsampled size.
        margin = 11
        doubled_size = (2 * x.shape[2], 2 * x.shape[3])
        h = torch.nn.functional.interpolate(x, doubled_size)  # default mode: nearest
        h = torch.nn.functional.pad(h, (margin,) * 4)
        stack = (
            self.conv1, self.conv2, self.conv3, self.conv4,
            self.conv5, self.conv6, self.conv7, self.conv8,
        )
        for conv in stack:
            # Leaky ReLU follows every layer, including the last one, so the
            # final output is not clamped to a fixed range.
            h = torch.nn.functional.leaky_relu(conv(h), 0.1)
        return h
| 217 | |
| 218 | |
# Module-level mapping that starts out empty. NOTE(review): presumably filled
# elsewhere in this file with loaded VAEApprox instances — confirm against callers.
VAE_approx_models = dict()