| 18 | ) |
| 19 | |
class IFBlock(nn.Module):
    """Refinement block that predicts a 4-channel flow and a 1-channel mask.

    ``forward`` resizes its input by ``1 / scale``, passes it through
    ``conv0``, a residual stack of eight ``conv(c, c)`` layers and a
    stride-2 ``ConvTranspose2d`` with 5 output channels, then resizes the
    result by ``scale`` and splits it into flow (channels 0-3) and mask
    (channel 4).

    Args:
        in_planes: Channel count of the tensor fed to ``conv0``. That is the
            channels of ``x`` plus those of ``flow`` when a flow is given
            (``forward`` concatenates them along dim 1).
        c: Width of the hidden feature maps.
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        # NOTE(review): the second layer's (3, 2, 1) arguments presumably mean
        # kernel 3 / stride 2 / padding 1 for the `conv` helper defined
        # elsewhere, i.e. a 2x downsample undone by `lastconv` — confirm.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 1, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # Eight identical layers; unpacking keeps submodule names
        # convblock.0 .. convblock.7, so existing checkpoints still load.
        self.convblock = nn.Sequential(*(conv(c, c) for _ in range(8)))
        # in=c, out=5 channels (4 flow + 1 mask), kernel 4, stride 2, padding 1.
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, flow, scale):
        """Predict a flow and a mask for ``x`` at the given scale.

        Args:
            x: Input tensor. Bilinear interpolation implies a 4-D tensor;
                presumably (N, C, H, W) — confirm against callers.
            flow: Previous flow estimate to concatenate onto ``x`` along
                dim 1, or ``None`` to use ``x`` alone.
            scale: Processing is done at ``1 / scale`` of the input size;
                outputs are resized back by ``scale``.

        Returns:
            Tuple ``(flow, mask)``. ``flow`` is channels 0-3 of the prediction
            multiplied by ``scale``; ``mask`` is channel 4, returned without
            any activation applied here.
        """
        # A bilinear resize with factor 1 is an exact copy, so skip it.
        if scale != 1:
            x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False)
        if flow is not None:
            if scale != 1:
                # Flow values are rescaled together with the spatial size,
                # consistent with displacements measured in pixels (presumably).
                flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False) * 1. / scale
            x = torch.cat((x, flow), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x  # residual connection around the conv stack
        tmp = self.lastconv(x)
        if scale != 1:
            tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask
| 52 | |
| 53 | class IFNet(nn.Module): |
| 54 | def __init__(self): |