| 106 | self.up4 = nn.PixelShuffle(2) |
| 107 | |
def forward(self, img0, img1, flow, c0, c1, flow_gt=None):
    """Warp both frames with ``flow`` and run them through the down/up stack.

    ``img0`` is warped by ``flow`` and ``img1`` by ``-flow``. The two warped
    images and ``flow`` are concatenated on dim 1 and fed to ``conv0`` and
    then ``down0``..``down3``. Before each of ``down1``..``down3`` and
    ``up0``, the matching entries of ``c0``/``c1`` are concatenated in.
    ``up1``..``up3`` additionally receive the skip tensors ``s2``, ``s1`` and
    ``s0``. The result goes through ``conv`` and finally ``up4``, which is
    built as ``nn.PixelShuffle(2)`` in ``__init__``.

    Args:
        img0: first input image, warped by ``flow``.
        img1: second input image, warped by ``-flow``.
        flow: flow used for warping; also concatenated into ``conv0``'s input.
        c0: indexable sequence of at least four feature tensors; entry ``i``
            is concatenated (dim 1) into the ``i``-th fusion point.
        c1: same layout as ``c0``, for the second image.
        flow_gt: optional ground-truth flow. Channels ``0:2`` (dim 1) warp
            ``img0`` and channels ``2:4`` warp ``img1``. Defaults to ``None``,
            which skips the ground-truth warps.

    Returns:
        Tuple ``(x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)``.
        The two ``*_gt`` entries are ``None`` when ``flow_gt`` is ``None``.
    """
    warped_img0 = warp(img0, flow)
    warped_img1 = warp(img1, -flow)
    # Identity check instead of ``== None``: flow_gt is presumably a tensor
    # when given, and ``==`` on a tensor is an elementwise op, not a None test.
    if flow_gt is None:
        warped_img0_gt, warped_img1_gt = None, None
    else:
        warped_img0_gt = warp(img0, flow_gt[:, :2])
        warped_img1_gt = warp(img1, flow_gt[:, 2:4])

    x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
    # Down path: fuse the per-stage c0/c1 features before each stage.
    s0 = self.down0(x)
    s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
    s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
    s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
    # Up path: the last c0/c1 entries, then skip connections s2, s1, s0.
    x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
    x = self.up1(torch.cat((x, s2), 1))
    x = self.up2(torch.cat((x, s1), 1))
    x = self.up3(torch.cat((x, s0), 1))
    x = self.up4(self.conv(x))
    return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
| 127 | |
| 128 | |
| 129 | class Model: |