MCP · Copy · Index your code
hub / github.com/hzwer/ECCV2022-RIFE / IFNet

Class IFNet

model/IFNet.py:53–108  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

51 return flow, mask
52
class IFNet(nn.Module):
    """Flow-estimation network with a 3-stage student cascade, an optional
    teacher stage that also sees the ground-truth frame, and a final
    context/U-Net refinement of the last blended frame.
    """

    def __init__(self):
        super().__init__()
        # Student stages (coarse to fine), applied in order by forward().
        # block0 receives cat(img0, img1) = 6 channels. block1/block2 receive
        # img0, img1, both warped images and the running mask = 13 channels,
        # which implies a 1-channel mask. The extra "+4" presumably accounts
        # for the 4-channel flow passed to IFBlock as a separate argument;
        # TODO confirm against IFBlock.
        self.block0 = IFBlock(6, c=240)
        self.block1 = IFBlock(13+4, c=150)
        self.block2 = IFBlock(13+4, c=90)
        # Teacher stage: the student inputs plus the 3-channel gt (16 = 13 + 3).
        self.block_tea = IFBlock(16+4, c=90)
        self.contextnet = Contextnet()
        self.unet = Unet()

    def forward(self, x, scale=(4, 2, 1), timestep=0.5):
        """Estimate flow/mask in three student stages and blend the inputs.

        Args:
            x: Channel-stacked input. Channels 0:3 are img0, 3:6 are img1, and
                channels 6: are the ground-truth frame (gt). When x has only
                6 channels, gt is an empty slice, and the teacher and
                distillation paths are skipped.
            scale: Per-stage scale passed to the three student blocks, in
                order. Must provide at least 3 entries.
            timestep: Not used by this method.

        Returns:
            Tuple (flow_list, mask, merged, flow_teacher, merged_teacher,
            loss_distill):
              - flow_list: the accumulated flow after each of the 3 stages.
              - mask: sigmoid of the final-stage mask.
              - merged: 3 mask-blended warped frames, one per stage. The last
                one is refined by the U-Net residual and clamped to [0, 1].
              - flow_teacher / merged_teacher: teacher flow and blended frame,
                or None when gt is absent.
              - loss_distill: sum over stages of the masked flow-distillation
                loss, or the int 0 when gt is absent.
        """
        img0 = x[:, :3]
        img1 = x[:, 3:6]
        # NOTE: without gt channels this is an empty tensor (shape[1] == 0),
        # not None.
        gt = x[:, 6:]
        has_gt = gt.shape[1] == 3

        stu = [self.block0, self.block1, self.block2]
        flow_list = []
        mask_list = []
        warps = []  # (warped_img0, warped_img1) per stage; blended below
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        loss_distill = 0

        # Student cascade: the first stage predicts flow/mask from the raw
        # pair; later stages predict residual updates given the current
        # warps, mask logits and flow.
        for i, block in enumerate(stu):
            if flow is not None:
                flow_d, mask_d = block(
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow,
                    scale=scale[i],
                )
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = block(torch.cat((img0, img1), 1), None, scale=scale[i])
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            # flow[:, :2] warps img0 and flow[:, 2:4] warps img1.
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            warps.append((warped_img0, warped_img1))

        # Teacher: one extra refinement stage at scale 1 that also sees gt.
        if has_gt:
            flow_d, mask_d = self.block_tea(
                torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1),
                flow,
                scale=1,
            )
            flow_teacher = flow + flow_d
            warped_img0_teacher = warp(img0, flow_teacher[:, :2])
            warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
            mask_teacher = torch.sigmoid(mask + mask_d)
            merged_teacher = (warped_img0_teacher * mask_teacher
                              + warped_img1_teacher * (1 - mask_teacher))
        else:
            flow_teacher = None
            merged_teacher = None

        # Blend each stage's warped pair with that stage's sigmoid mask.
        merged = [w0 * m + w1 * (1 - m) for (w0, w1), m in zip(warps, mask_list)]

        if has_gt:
            # Loop-invariant teacher quantities, computed once.
            teacher_err = (merged_teacher - gt).abs().mean(1, True)
            flow_teacher_sg = flow_teacher.detach()
            for student, student_flow in zip(merged, flow_list):
                # Distill only at pixels where the student's mean abs error
                # exceeds the teacher's by more than 0.01. The mask itself
                # carries no gradient.
                loss_mask = ((student - gt).abs().mean(1, True)
                             > teacher_err + 0.01).float().detach()
                # Per-pixel L2 distance between teacher and student flow,
                # masked and averaged.
                flow_dist = ((flow_teacher_sg - student_flow) ** 2).mean(1, True) ** 0.5
                loss_distill += (flow_dist * loss_mask).mean()

        # Refinement: context features from each input warped by its half of
        # the final flow, then a U-Net that also sees the raw mask logits.
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        # Maps the first 3 output channels from [0, 1] to [-1, 1], assuming
        # Unet emits values in [0, 1]; TODO confirm in Unet.
        res = tmp[:, :3] * 2 - 1
        merged[2] = torch.clamp(merged[2] + res, 0, 1)
        return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill

Callers 1

__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected