| 54 | self.conv4 = Conv2(4 * c, 8 * c) |
| 55 | |
def forward(self, x, flow):
    """Return flow-warped features from each of the four conv stages.

    Starting from ``x``, the stages ``conv1`` .. ``conv4`` are applied in
    sequence.  After every stage the flow is bilinearly resampled with a
    scale factor of 0.5, its values are multiplied by 0.5, and the stage
    output is warped with the resulting flow.

    Args:
        x: Input passed to ``self.conv1``.
        flow: Flow field that is repeatedly downscaled and handed to
            ``warp`` together with the features of each stage.

    Returns:
        A list ``[f1, f2, f3, f4]`` holding the warped output of each
        stage, in stage order.
    """
    # NOTE(review): warping stage features with a flow resampled at 0.5x
    # presumably relies on every conv stage halving the spatial size too
    # -- confirm against the Conv2 definition.
    stages = (self.conv1, self.conv2, self.conv3, self.conv4)
    warped_features = []
    for stage in stages:
        x = stage(x)
        downscaled = F.interpolate(
            flow,
            scale_factor=0.5,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
        # The flow values are halved alongside the resolution (presumably
        # because displacements are measured in pixels -- TODO confirm).
        flow = downscaled * 0.5
        warped_features.append(warp(x, flow))
    return warped_features
| 82 | |
| 83 | |
| 84 | class Unet(nn.Module): |