| 188 | |
| 189 | |
| 190 | class VGGDecoder(nn.Module): |
| 191 | def __init__(self, level): |
| 192 | super(VGGDecoder, self).__init__() |
| 193 | self.level = level |
| 194 | |
| 195 | if level > 3: |
| 196 | self.pad4_1 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 197 | self.conv4_1 = nn.Conv2d(512, 256, 3, 1, 0) |
| 198 | self.relu4_1 = nn.ReLU(inplace=True) |
| 199 | # 28 x 28 |
| 200 | |
| 201 | self.unpool3 = nn.MaxUnpool2d(kernel_size=2, stride=2) |
| 202 | # 56 x 56 |
| 203 | |
| 204 | self.pad3_4 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 205 | self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0) |
| 206 | self.relu3_4 = nn.ReLU(inplace=True) |
| 207 | # 56 x 56 |
| 208 | |
| 209 | self.pad3_3 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 210 | self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0) |
| 211 | self.relu3_3 = nn.ReLU(inplace=True) |
| 212 | # 56 x 56 |
| 213 | |
| 214 | self.pad3_2 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 215 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0) |
| 216 | self.relu3_2 = nn.ReLU(inplace=True) |
| 217 | # 56 x 56 |
| 218 | |
| 219 | if level > 2: |
| 220 | self.pad3_1 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 221 | self.conv3_1 = nn.Conv2d(256, 128, 3, 1, 0) |
| 222 | self.relu3_1 = nn.ReLU(inplace=True) |
| 223 | # 56 x 56 |
| 224 | |
| 225 | self.unpool2 = nn.MaxUnpool2d(kernel_size=2, stride=2) |
| 226 | # 112 x 112 |
| 227 | |
| 228 | self.pad2_2 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 229 | self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0) |
| 230 | self.relu2_2 = nn.ReLU(inplace=True) |
| 231 | # 112 x 112 |
| 232 | |
| 233 | if level > 1: |
| 234 | self.pad2_1 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 235 | self.conv2_1 = nn.Conv2d(128, 64, 3, 1, 0) |
| 236 | self.relu2_1 = nn.ReLU(inplace=True) |
| 237 | # 112 x 112 |
| 238 | |
| 239 | self.unpool1 = nn.MaxUnpool2d(kernel_size=2, stride=2) |
| 240 | # 224 x 224 |
| 241 | |
| 242 | self.pad1_2 = nn.ReflectionPad2d((1, 1, 1, 1)) |
| 243 | self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0) |
| 244 | self.relu1_2 = nn.ReLU(inplace=True) |
| 245 | # 224 x 224 |
| 246 | |
| 247 | if level > 0: |
no outgoing calls
no test coverage detected