# Base width handed to ContextNet's Conv2 stages (as c, 2*c, 4*c and 8*c);
# presumably a channel count -- confirm against Conv2's signature.
c = 32
| 49 | |
class ContextNet(nn.Module):
    """Extract features at several stages and warp each one with a flow field.

    Five ``Conv2`` stages run back to back. The outputs of the last four are
    each passed through ``warp`` together with a progressively resized copy
    of the input flow.
    """

    def __init__(self):
        super().__init__()
        # Stages are registered in order conv0..conv4; the names are kept
        # unchanged because they become the module's parameter prefixes.
        self.conv0 = Conv2(3, c)
        self.conv1 = Conv2(c, c)
        self.conv2 = Conv2(c, 2 * c)
        self.conv3 = Conv2(2 * c, 4 * c)
        self.conv4 = Conv2(4 * c, 8 * c)

    def forward(self, x, flow):
        """Return the four warped feature maps ``[f1, f2, f3, f4]``.

        Args:
            x: Input fed to ``conv0``; its first Conv2 stage is built with 3
                input channels, so presumably an RGB image batch -- confirm
                against the caller.
            flow: Flow field passed to ``warp``; it is resized with
                ``F.interpolate`` so it must be a tensor that call accepts.

        Returns:
            A list with one warped feature tensor per stage conv1..conv4,
            in that order.
        """
        features = self.conv0(x)
        warped = []
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = stage(features)
            # Shrink the flow grid by half and halve its values as well, once
            # per stage (presumably to track a stride-2 Conv2 -- confirm).
            flow = F.interpolate(
                flow,
                scale_factor=0.5,
                mode="bilinear",
                align_corners=False,
            ) * 0.5
            warped.append(warp(features, flow))
        return warped
| 77 | |
| 78 | |
| 79 | class FusionNet(nn.Module): |