| 92 | self.block3 = IFBlock(8, scale=1, c=48) |
| 93 | |
def forward(self, x, scale=1.0):
    """Estimate flow with one coarse pass and three warp-and-refine passes.

    ``x`` is first resampled by ``0.5 * scale``. ``block0`` predicts an
    initial flow from it. Each of ``block1``..``block3`` then receives the two
    channel halves of ``x`` warped by the current estimate (the first half by
    the flow, the second half by its negation), concatenated with that
    estimate. Each block predicts a delta that is added to the running total.

    Args:
        x: Batch whose channels ``[:3]`` and ``[3:]`` are warped as two
            separate images -- presumably two stacked RGB frames of shape
            (N, 6, H, W); TODO confirm against callers.
        scale: Extra resampling factor applied on top of the fixed 0.5.
            Must be non-zero, since it is used as a divisor.

    Returns:
        ``(final_flow, [F1, F2, F3, F4])``. ``final_flow`` is the summed
        estimate resampled by ``1 / scale`` and divided by ``scale``; it is
        also the last list element. ``F1``..``F3`` are NOT rescaled and stay
        at the internal (``0.5 * scale``) resolution.
    """
    x = F.interpolate(x, scale_factor=0.5 * scale, mode="bilinear",
                      align_corners=False)
    frame0, frame1 = x[:, :3], x[:, 3:]

    deltas = [self.block0(x)]  # per-stage flow predictions
    stages = [deltas[0]]       # cumulative estimates F1, F2, F3, F4
    for refine in (self.block1, self.block2, self.block3):
        flow = stages[-1]
        # Warp the two halves in opposite directions and give the next block
        # both warped halves alongside the current estimate.
        features = torch.cat(
            (warp(frame0, flow), warp(frame1, -flow), flow), 1)
        deltas.append(refine(features))
        # Re-sum every delta left to right at each stage, i.e. the same
        # association as flow0 + flow1 + ... + flowk.
        total = deltas[0]
        for delta in deltas[1:]:
            total = total + delta
        stages.append(total)

    # Resample the final estimate by 1/scale and divide its values by scale.
    # Presumably the division keeps flow magnitudes in pixel units of the
    # resampled grid -- NOTE(review): confirm with the warp convention.
    final = F.interpolate(stages[-1], scale_factor=1 / scale,
                          mode="bilinear", align_corners=False) / scale
    stages[-1] = final
    return final, stages
| 114 | |
| 115 | if __name__ == '__main__': |
| 116 | img0 = torch.zeros(3, 3, 256, 256).float().to(device) |