MCP · Copy · Index your code
hub / github.com/pytorch/examples / VAE

Class VAE

vae/main.py:46–72  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

44
45
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``input_dim -> hidden_dim -> (mu, logvar)``, each head of size
    ``latent_dim``. Decoder: ``latent_dim -> hidden_dim -> input_dim`` with a
    sigmoid output, so reconstructions lie in ``[0, 1]``.

    The defaults (784 / 400 / 20) reproduce the original hard-coded
    architecture. Layer attribute names (``fc1``, ``fc21``, ``fc22``, ``fc3``,
    ``fc4``) are unchanged, so existing ``state_dict`` checkpoints still load.

    Args:
        input_dim: Number of features per flattened input sample.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs ``(N, input_dim)`` to ``(mu, logvar)``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Routing the randomness through ``eps`` keeps ``z`` differentiable
        with respect to ``mu`` and ``logvar``. ``logvar`` is the log of the
        variance, hence ``std = exp(0.5 * logvar)``. This samples in both
        train and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``(N, latent_dim)`` to reconstructions in ``[0, 1]``."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to ``(-1, input_dim)``; for example, a batch of
        shape ``(N, 1, 28, 28)`` becomes ``(N, 784)`` with the default
        ``input_dim``. ``reshape`` is used instead of ``view`` so that
        non-contiguous inputs (e.g. permuted tensors) are accepted too.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.fc1.in_features))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
73
74
# Build the VAE, then move its parameters to `device`. `device` is presumably
# chosen earlier in main.py (not visible here); confirm against the script.
model = VAE()
model = model.to(device)  # nn.Module.to() returns the module itself

Callers 1

main.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected