| 51 | |
| 52 | |
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a latent mean and a noisy sample.

    Pipeline: a reflection-padded 7x7 conv stem, ``n_downsample`` stride-2
    convolutions (each doubling the channel count), three ``ResidualBlock``s,
    and finally the externally supplied ``shared_block``.

    Args:
        in_channels: Number of channels of the input image.
        dim: Channel count produced by the stem; doubled by every
            downsampling stage, so the trunk ends with
            ``dim * 2**n_downsample`` channels.
        n_downsample: Number of stride-2 downsampling stages.
        shared_block: Module applied after the trunk to produce ``mu``.
            Because it is assigned as a submodule, passing the same instance
            to several encoders shares its weights between them (presumably
            the shared latent layer of a UNIT-style model -- confirm against
            the caller). Must not be ``None`` when ``forward`` is called.
    """

    def __init__(self, in_channels=3, dim=64, n_downsample=2, shared_block=None):
        super(Encoder, self).__init__()

        # Stem: padding 3 on each side keeps H and W unchanged through the
        # unpadded 7x7 convolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, dim, 7),
            # Normalize over the stem's real channel count (previously
            # hard-coded to 64, which was wrong whenever dim != 64).
            nn.InstanceNorm2d(dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Downsampling: kernel 4 / stride 2 / padding 1 halves H and W
        # (rounding down) while doubling the channel count.
        for _ in range(n_downsample):
            layers += [
                nn.Conv2d(dim, dim * 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(dim * 2),
                nn.ReLU(inplace=True),
            ]
            dim *= 2

        # Three residual blocks on the downsampled features.
        for _ in range(3):
            layers += [ResidualBlock(dim)]

        self.model_blocks = nn.Sequential(*layers)
        self.shared_block = shared_block

    def reparameterization(self, mu):
        """Return ``mu`` plus standard-normal noise of the same shape.

        The noise comes from ``torch.randn_like``. It is created on ``mu``'s
        device with ``mu``'s dtype and follows ``torch.manual_seed``. The
        noise tensor does not require grad, so gradients flow to ``mu``
        unchanged (d z / d mu = 1).
        """
        return mu + torch.randn_like(mu)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(mu, z)`` where ``mu = shared_block(model_blocks(x))``
            and ``z = mu + N(0, I)`` noise of the same shape.

        Raises:
            TypeError: if the encoder was constructed without a
                ``shared_block``.
        """
        if self.shared_block is None:
            raise TypeError(
                "Encoder.shared_block is None; pass a module as shared_block"
            )
        features = self.model_blocks(x)
        mu = self.shared_block(features)
        z = self.reparameterization(mu)
        return mu, z
| 91 | |
| 92 | |
| 93 | class Generator(nn.Module): |