| 62 | self.block3 = IFBlock(10, scale=1, c=48) |
| 63 | |
def forward(self, x, scale=1.0):
    """Estimate a flow tensor by running ``block0``..``block3`` in sequence.

    ``block0`` is applied to ``x`` directly. Each later block receives the
    two halves of ``x`` warped by the running flow, concatenated with that
    flow, and its output is added to the running sum.

    Args:
        x: Input tensor, split along dim 1 into ``x[:, :3]`` and
            ``x[:, 3:]``.
            NOTE(review): presumably two 3-channel frames stacked on the
            channel axis -- confirm against the caller.
        scale: Factor applied to ``x`` with bilinear interpolation before
            any block runs. When it is not 1.0, the final flow is resized
            by ``1 / scale`` and its values are divided by ``scale``.

    Returns:
        A tuple ``(flow, [F1, F2, F3, F4])`` where ``F1``..``F4`` are the
        running sums after each block and ``flow`` is ``F4``. When
        ``scale != 1.0`` only ``F4`` is rescaled; ``F1``..``F3`` are
        returned as computed on the resized input.
    """

    def resize(tensor, factor):
        # Every resampling step in this method uses the same bilinear setup.
        return F.interpolate(
            tensor, scale_factor=factor, mode="bilinear", align_corners=False
        )

    def refine(block, running_flow):
        # Upsample the running flow 2x and double its values so the
        # displacements stay consistent with the larger grid.
        enlarged = resize(running_flow, 2.0) * 2.0
        # Flow channels 0:2 warp the first half of x, channels 2:4 the rest.
        moved_first = warp(x[:, :3], enlarged[:, :2])
        moved_second = warp(x[:, 3:], enlarged[:, 2:4])
        return block(torch.cat((moved_first, moved_second, enlarged), 1))

    if scale != 1.0:
        x = resize(x, scale)

    flow0 = self.block0(x)
    sum_after_0 = flow0

    flow1 = refine(self.block1, sum_after_0)
    sum_after_1 = flow0 + flow1

    flow2 = refine(self.block2, sum_after_1)
    sum_after_2 = flow0 + flow1 + flow2

    flow3 = refine(self.block3, sum_after_2)
    final_flow = flow0 + flow1 + flow2 + flow3

    if scale != 1.0:
        # Undo the input resize: restore the resolution and rescale the
        # displacement magnitudes accordingly.
        final_flow = resize(final_flow, 1 / scale) / scale

    return final_flow, [sum_after_0, sum_after_1, sum_after_2, final_flow]
| 87 | |
| 88 | if __name__ == '__main__': |
| 89 | img0 = torch.zeros(3, 3, 256, 256).float().to(device) |