| 71 | self.conv4 = ResBlock(4*c, 8*c) |
| 72 | |
def forward(self, x, flow):
    """Build a four-level list of flow-warped context features.

    ``x`` is passed through ``conv0`` and then ``conv1``. Then, for each of
    four levels, three steps run in order:

    1. ``flow`` is bilinearly resized by a factor of 0.5
       (``align_corners=False``) and its values are multiplied by 0.5.
    2. The current features are warped with ``warp`` using that flow.
    3. The features go through the next stage (``conv2``, ``conv3``,
       ``conv4``). The last level has no follow-up stage.

    NOTE(review): the 0.5 value scaling presumably keeps flow in pixel units
    at the halved resolution. Each conv stage is assumed to halve the feature
    resolution so the features track the flow. TODO confirm against the stage
    definitions and ``warp``.

    Args:
        x: input tensor fed to ``conv0``.
        flow: flow tensor; ``mode="bilinear"`` requires 4-D input.

    Returns:
        A list ``[f1, f2, f3, f4]`` of warped features, in the order they
        are produced (flow halved once for ``f1``, four times for ``f4``).
    """
    feats = self.conv1(self.conv0(x))
    # Stage run *after* each warp; None marks the final level (no follow-up).
    follow_up_stages = (self.conv2, self.conv3, self.conv4, None)
    pyramid = []
    for stage in follow_up_stages:
        flow = F.interpolate(
            flow, scale_factor=0.5, mode="bilinear", align_corners=False
        ) * 0.5
        pyramid.append(warp(feats, flow))
        if stage is not None:
            feats = stage(feats)
    return pyramid
| 91 | |
| 92 | |
| 93 | class FusionNet(nn.Module): |