| 77 | |
| 78 | |
class FusionNet(nn.Module):
    """Encoder-decoder network that fuses two flow-warped images.

    The encoder path (``conv0``, ``down0`` .. ``down3``) and the decoder path
    (``up0`` .. ``up3``) are joined by skip connections: each decoder stage
    receives the previous decoder output concatenated with the matching
    encoder output. External feature lists ``c0`` / ``c1`` are concatenated
    into the encoder at every stage after ``down0``.

    NOTE(review): ``Conv2``, ``deconv``, ``warp`` and the channel width ``c``
    are defined outside this class; presumably ``Conv2`` downsamples and
    ``deconv`` upsamples by a factor of two -- confirm against their
    definitions.
    """

    def __init__(self):
        super(FusionNet, self).__init__()
        # Encoder. 10 input channels = warped_img0 + warped_img1 + flow
        # concatenated on dim 1 (presumably 3 + 3 + 4 -- TODO confirm).
        self.conv0 = Conv2(10, c)
        self.down0 = Conv2(c, 2*c)
        # Input widths below include the c0[i] / c1[i] features that
        # forward() concatenates in before each stage.
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        # Decoder. Input widths include the concatenated skip tensors.
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        # Output head: 4 output channels, kernel 4, stride 2, padding 1.
        self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        """Warp both images with ``flow`` and run the fusion network.

        Args:
            img0: First image batch; warped with ``flow[:, :2]``.
            img1: Second image batch; warped with ``flow[:, 2:4]``.
            flow: Flow tensor; channels 0-1 are used for ``img0`` and
                channels 2-3 for ``img1``. The whole tensor is also fed to
                the network as input.
            c0: Indexable collection of at least four feature tensors
                (``c0[0]`` .. ``c0[3]``) concatenated into the encoder.
                Presumably context features for ``img0`` -- verify against
                the caller.
            c1: Same as ``c0``, presumably for ``img1``.
            flow_gt: Optional reference flow with the same channel layout as
                ``flow``, or ``None`` to skip the reference warps.

        Returns:
            Tuple ``(x, warped_img0, warped_img1, warped_img0_gt,
            warped_img1_gt)`` where ``x`` is the 4-channel network output
            and the two ``*_gt`` entries are ``None`` when ``flow_gt`` is
            ``None``.
        """
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        # Identity check: `== None` on a tensor goes through Tensor.__eq__
        # instead of testing whether the argument was supplied.
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])
        # Encoder: external features are concatenated before each stage.
        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
        s0 = self.down0(x)
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        # Decoder: each stage consumes the previous output plus a skip tensor.
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
| 112 | |
| 113 | |
| 114 | class Model: |