| 43 | self.conv4 = Conv2(4*c, 8*c) |
| 44 | |
def forward(self, x, flow):
    """Run ``x`` through four conv stages and warp each stage's output by ``flow``.

    Stages ``self.conv1`` .. ``self.conv4`` are applied in sequence. After each
    stage, the resulting features are passed to ``warp`` together with the
    current flow.

    Before stages 2-4 the flow is resized with bilinear
    ``F.interpolate(scale_factor=0.5)``, and its values are multiplied by 0.5.
    This keeps the displacement magnitudes consistent with the halved spatial
    size.

    NOTE(review): stage 1 uses ``flow`` without resizing. This presumably
    assumes the caller passes ``flow`` at the same resolution as
    ``self.conv1``'s output — confirm against the caller.

    Args:
        x: input tensor fed to ``self.conv1``.
        flow: flow tensor accepted by ``warp`` and ``F.interpolate``.

    Returns:
        A list ``[f1, f2, f3, f4]`` with one warped feature map per stage,
        in stage order.
    """
    features = []
    stages = (self.conv1, self.conv2, self.conv3, self.conv4)
    for level, conv in enumerate(stages):
        x = conv(x)
        if level > 0:
            # Halve the flow field's resolution and rescale its values so they
            # stay in the pixel units of this smaller stage.
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        features.append(warp(x, flow))
    return features
| 59 | |
| 60 | class Unet(nn.Module): |
| 61 | def __init__(self): |