| 235 | return output |
| 236 | |
| 237 | class Inpaint_Edge_Net(BaseNetwork): |
def __init__(self, residual_blocks=8, init_weights=True):
    """Build the encoder / residual-middle / decoder layers of the edge net.

    Stages created here, in construction order:
      * encoder_0: ReflectionPad2d(3) + 7x7 conv, 7 -> 64 channels.
      * encoder_1: 4x4 / stride 2 / padding 1 conv, 64 -> 128 channels.
      * encoder_2: 4x4 / stride 2 / padding 1 conv, 128 -> 256 channels.
      * middle: ``residual_blocks`` copies of ``ResnetBlock(256, 2)``.
      * decoder_0: 4x4 / stride 2 / padding 1 transposed conv, 256+256 -> 128.
      * decoder_1: 4x4 / stride 2 / padding 1 transposed conv, 128+128 -> 64.
      * decoder_2: ReflectionPad2d(3) + 7x7 conv, 64+64 -> 1 channel.

    Each stage except decoder_2 passes its convolution through the project's
    ``spectral_norm(..., True)`` helper. It then follows the convolution with
    a non-affine ``InstanceNorm2d`` (``track_running_stats=False``) and an
    in-place ``ReLU``. decoder_2 has no normalization or activation.

    NOTE(review): each decoder input width is double the matching encoder
    width. This suggests that ``forward`` (not visible here) concatenates
    encoder features as skip connections. Confirm against ``forward``.

    Args:
        residual_blocks: number of ResnetBlock modules stacked in
            ``self.middle``.
        init_weights: if true, ``self.init_weights()`` is called once all
            submodules exist. That method is presumably inherited from
            BaseNetwork; verify there.
    """
    super(Inpaint_Edge_Net, self).__init__()
    n_in, n_out = 7, 1
    # NOTE(review): never populated in __init__. It is kept only because
    # other methods of this class might read ``self.encoder``.
    self.encoder = []

    def norm_relu(conv, width):
        # Shared tail for every normalized stage. It returns a plain list
        # (not a Sequential) so callers can splat it. That keeps the flat
        # child indices, and hence the state_dict keys, identical.
        return [conv, nn.InstanceNorm2d(width, track_running_stats=False), nn.ReLU(True)]

    def down(c_in, c_out):
        # Spectral-normalized 4x4 / stride-2 / padding-1 convolution stage.
        conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1)
        return nn.Sequential(*norm_relu(spectral_norm(conv, True), c_out))

    def up(c_in, c_out):
        # Spectral-normalized 4x4 / stride-2 / padding-1 transposed-conv stage.
        deconv = nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1)
        return nn.Sequential(*norm_relu(spectral_norm(deconv, True), c_out))

    # Submodules are created in exactly the original order. Layer
    # constructors (and spectral_norm's vector init) draw from the global RNG,
    # so reordering would change the initial parameters when init_weights is
    # False.
    self.encoder_0 = nn.Sequential(
        nn.ReflectionPad2d(3),
        *norm_relu(
            spectral_norm(nn.Conv2d(in_channels=n_in, out_channels=64, kernel_size=7, padding=0), True),
            64,
        ),
    )
    self.encoder_1 = down(64, 128)
    self.encoder_2 = down(128, 256)

    self.middle = nn.Sequential(*[ResnetBlock(256, 2) for _ in range(residual_blocks)])

    self.decoder_0 = up(256 + 256, 128)
    self.decoder_1 = up(128 + 128, 64)
    self.decoder_2 = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(in_channels=64 + 64, out_channels=n_out, kernel_size=7, padding=0),
    )

    if init_weights:
        self.init_weights()
| 284 | |
| 285 | def add_border(self, input, channel_pad_1=None): |
| 286 | h = input.shape[-2] |
| 287 | w = input.shape[-1] |
| 288 | require_len_unit = 16 |
| 289 | residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit |
| 290 | residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit |
| 291 | enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device) |
| 292 | if channel_pad_1 is not None: |
| 293 | for channel in channel_pad_1: |
| 294 | enlarge_input[:, channel] = 1 |
no outgoing calls
no test coverage detected