| 42 | self.conv4 = Conv2(4*c, 8*c) |
| 43 | |
def forward(self, x, flow):
    """Run the four-stage encoder and warp each stage's features by ``flow``.

    Each stage applies the next conv (``conv1`` .. ``conv4``) to the running
    features. It also bilinearly resizes ``flow`` by a factor of 0.5 and
    multiplies the result by 0.5. The stage's features are then warped with
    that resized flow.

    NOTE(review): this presumably assumes each ``convN`` halves the spatial
    size, so features and the resized flow share a grid, and that the extra
    ``* 0.5`` rescales pixel displacements to the smaller grid. Confirm
    against ``Conv2`` and ``warp``.

    Args:
        x: input features fed to ``self.conv1``.
        flow: flow field passed to ``F.interpolate`` and ``warp``.

    Returns:
        list: the four warped feature maps, ordered from the ``conv1`` stage
        to the ``conv4`` stage.
    """
    warped_levels = []
    for encoder in (self.conv1, self.conv2, self.conv3, self.conv4):
        x = encoder(x)
        # Halve the flow's resolution, then scale its values by the same factor.
        downsampled = F.interpolate(
            flow,
            scale_factor=0.5,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
        flow = downsampled * 0.5
        warped_levels.append(warp(x, flow))
    return warped_levels
| 58 | |
| 59 | class Unet(nn.Module): |
| 60 | def __init__(self): |