| 44 | |
| 45 | |
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``input_dim -> hidden_dim`` (ReLU), followed by two parallel
    ``hidden_dim -> latent_dim`` heads that produce ``mu`` and ``logvar``.
    Decoder: ``latent_dim -> hidden_dim`` (ReLU) ``-> input_dim`` (sigmoid).

    The defaults (784, 400, 20) reproduce the original hard-coded sizes.
    NOTE(review): 784 presumably corresponds to flattened 28x28 images;
    confirm this against the data loader.

    Args:
        input_dim: Number of features in one flattened input sample.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Kept so forward() can flatten inputs to the configured width
        # (this used to be a second hard-coded 784).
        self.input_dim = input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean (mu) head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened input ``x`` to ``(mu, logvar)``.

        Both outputs have shape ``(N, latent_dim)``.
        """
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        The randomness is drawn into ``eps``, so gradients flow through
        ``mu`` and ``logvar``. This always samples, including in eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions squashed into [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``input_dim``; it is flattened to ``(-1, input_dim)``. ``reshape``
        is used instead of ``view`` so that non-contiguous inputs (for
        example transposed or sliced tensors) also work.

        Returns:
            Tuple ``(recon, mu, logvar)``. ``recon`` has shape
            ``(N, input_dim)``; ``mu`` and ``logvar`` have shape
            ``(N, latent_dim)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
| 73 | |
| 74 | |
# Build the VAE, then move its parameters and buffers onto `device`.
# nn.Module.to() moves the module in place and returns the same object.
# NOTE(review): `device` is defined outside this view; confirm it is set above.
model = VAE()
model = model.to(device)