MCP · copy · Index your code
hub / github.com/hzwer/ECCV2022-RIFE / FusionNet

Class FusionNet

model/oldmodel/RIFE_HDv2.py:79–111  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

77
78
class FusionNet(nn.Module):
    """Fuse two warped frames, their flow and per-frame context features.

    The network is an encoder/decoder:

    * ``conv0`` and ``down0`` encode the two warped frames plus the flow.
    * ``down1``..``down3`` continue the contracting path. Each one, and also
      ``up0``, additionally consumes one level of context features from
      each input frame (``c0[i]`` and ``c1[i]``).
    * ``up1``..``up3`` form the expanding path, with skip connections to
      ``s2``, ``s1`` and ``s0``.
    * A final transposed convolution (kernel 4, stride 2, padding 1) maps
      ``c`` channels to 4 channels and doubles the spatial size of its input.

    Relies on module-level names defined elsewhere in this file: the base
    channel width ``c`` and the helpers ``Conv2``, ``deconv`` and ``warp``.
    """

    def __init__(self):
        super().__init__()
        # 10 input channels = warped_img0 + warped_img1 + flow (see forward);
        # presumably 3 + 3 + 4 -- TODO confirm against callers.
        self.conv0 = Conv2(10, c)
        self.down0 = Conv2(c, 2*c)
        # The input widths of down1..down3 and up0 include the context
        # features concatenated from both frames at each level.
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        # up1..up3 consume the previous decoder output concatenated with the
        # matching encoder skip tensor (s2, s1, s0).
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)

    def forward(self, img0, img1, flow, c0, c1, flow_gt=None):
        """Warp both frames, run the encoder/decoder, and return its output.

        Args:
            img0, img1: The two input frames. After warping they are
                concatenated with ``flow`` and fed to ``conv0``, which
                expects 10 channels in total. This presumably means
                3-channel images with a 4-channel flow (TODO confirm).
            flow: Flow field. Channels ``0:2`` warp ``img0`` and channels
                ``2:4`` warp ``img1``.
            c0, c1: Per-frame context features, indexed ``[0]``..``[3]``.
                Entry ``i`` is concatenated with the encoder tensor at
                level ``i``. Channel counts must add up to the widths
                declared in ``__init__``; presumably ``c``, ``2*c``,
                ``4*c`` and ``8*c`` per frame (verify against the
                context-feature producer).
            flow_gt: Optional ground-truth flow, laid out like ``flow``.
                When ``None``, the ground-truth warps are skipped.

        Returns:
            A tuple ``(x, warped_img0, warped_img1, warped_img0_gt,
            warped_img1_gt)``. ``x`` is the 4-channel output of the final
            transposed convolution. The two ``*_gt`` entries are ``None``
            when ``flow_gt`` is ``None``.
        """
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        # Identity check, not `== None`: with a tensor argument `==` would
        # dispatch to Tensor.__eq__ instead of testing for absence.
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])

        # Encoder: each level also sees the matching context features of
        # both frames.
        x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
        s0 = self.down0(x)
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))

        # Decoder: upsample, concatenating the encoder skip tensors.
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
112
113
114class Model:

Callers 1

__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected