| 68 | |
| 69 | |
class Decoder(nn.Module):
    """Decode latent vectors into images with a small fully connected network.

    Layers: Linear -> LeakyReLU(0.2) -> Linear -> BatchNorm1d -> LeakyReLU(0.2)
    -> Linear -> Tanh. The flat Tanh output is reshaped to
    ``(batch, *output_shape)``, so every value lies in [-1, 1].

    Args:
        latent_dim: Length of each latent vector. Defaults to the script-level
            ``opt.latent_dim`` so the zero-argument ``Decoder()`` still works.
        output_shape: Shape of a single decoded image, excluding the batch
            axis. Defaults to the script-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, output_shape=None):
        super().__init__()

        # Fall back to the script configuration when not given explicitly,
        # keeping the original zero-argument constructor backward-compatible.
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if output_shape is None:
            output_shape = img_shape
        # Capture the shape once so the reshape in forward() always agrees
        # with the size of the final Linear layer built below.
        self.output_shape = tuple(int(dim) for dim in output_shape)

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            # NOTE: in training mode BatchNorm1d needs more than one sample
            # per batch to compute batch statistics.
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(self.output_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode a batch of latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *output_shape)`` with values in [-1, 1].
        """
        img_flat = self.model(z)
        return img_flat.view(img_flat.shape[0], *self.output_shape)
| 89 | |
| 90 | class Discriminator(nn.Module): |