| 57 | return [f1, f2, f3, f4] |
| 58 | |
class Unet(nn.Module):
    """Encoder-decoder network producing a 3-channel map with values in (0, 1).

    Four ``Conv2`` stages form the encoder and four ``deconv`` stages form the
    decoder. Each decoder stage after the first also receives the output of
    the matching encoder stage via channel concatenation. ``Conv2``,
    ``deconv`` and the base width ``c`` are defined elsewhere in this file.
    """

    def __init__(self):
        super().__init__()
        # Encoder. The first argument of every stage equals the channel count
        # of the tensor concatenated for it in ``forward`` (presumably the
        # in_channels of Conv2 -- confirm against its definition).
        self.down0 = Conv2(17, 2 * c)
        self.down1 = Conv2(4 * c, 4 * c)
        self.down2 = Conv2(8 * c, 8 * c)
        self.down3 = Conv2(16 * c, 16 * c)
        # Decoder. Attribute names and creation order are kept as-is because
        # they determine state_dict keys and parameter ordering.
        self.up0 = deconv(32 * c, 8 * c)
        self.up1 = deconv(16 * c, 4 * c)
        self.up2 = deconv(8 * c, 2 * c)
        self.up3 = deconv(4 * c, c)
        # Final 3x3 convolution (stride 1, padding 1) down to 3 channels.
        self.conv = nn.Conv2d(c, 3, 3, 1, 1)

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        """Run the network and return the sigmoid of the final convolution.

        Args:
            img0, img1, warped_img0, warped_img1, mask, flow: tensors that are
                concatenated along dim 1 to form the first encoder input.
                NOTE(review): ``down0`` is built with 17 as its first
                argument, so these presumably total 17 channels -- confirm.
            c0, c1: indexable collections whose elements 0..3 are tensors;
                element ``i`` of each is concatenated (dim 1) onto the input
                of encoder stage ``i + 1`` (element 3 feeds ``up0``).

        Returns:
            ``torch.sigmoid`` applied to the output of ``self.conv``.
        """
        stem_input = torch.cat(
            (img0, img1, warped_img0, warped_img1, mask, flow), dim=1
        )

        # Encoder path: each later stage sees the previous stage's output
        # together with one feature from c0 and one from c1.
        enc0 = self.down0(stem_input)
        enc1 = self.down1(torch.cat((enc0, c0[0], c1[0]), dim=1))
        enc2 = self.down2(torch.cat((enc1, c0[1], c1[1]), dim=1))
        enc3 = self.down3(torch.cat((enc2, c0[2], c1[2]), dim=1))

        # Decoder path: every stage after the first is fed the previous
        # decoder output joined with the matching encoder output.
        dec = self.up0(torch.cat((enc3, c0[3], c1[3]), dim=1))
        dec = self.up1(torch.cat((dec, enc2), dim=1))
        dec = self.up2(torch.cat((dec, enc1), dim=1))
        dec = self.up3(torch.cat((dec, enc0), dim=1))

        return torch.sigmoid(self.conv(dec))