MCPcopy Index your code
hub / github.com/hzwer/ECCV2022-RIFE / Unet

Class Unet

model/refine.py:59–82  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

57 return [f1, f2, f3, f4]
58
class Unet(nn.Module):
    """Encoder/decoder refinement network with skip connections.

    Four encoder stages (``down0``..``down3``, built with ``Conv2``) are
    followed by four decoder stages (``up0``..``up3``, built with
    ``deconv``) and a final 3x3 ``nn.Conv2d`` to 3 channels whose output
    goes through ``torch.sigmoid``.

    All widths are multiples of the module-level constant ``c``, which is
    defined elsewhere in this file. ``Conv2`` and ``deconv`` are also
    defined elsewhere. The comments below assume ``Conv2(in, out)`` and
    ``deconv(in, out)`` map ``in`` channels to ``out`` channels; confirm
    against their definitions.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in exactly the original order and under
        # the original names, so state_dict keys, parameters() ordering and
        # seeded initialization all stay the same.

        # Encoder. Each stage's input width = the previous stage's output
        # plus the two context tensors concatenated in forward().
        self.down0 = Conv2(17, 2 * c)
        self.down1 = Conv2(4 * c, 4 * c)
        self.down2 = Conv2(8 * c, 8 * c)
        self.down3 = Conv2(16 * c, 16 * c)

        # Decoder. up0 sees the deepest encoder output plus the last context
        # level. Each later stage sees the previous decoder output plus the
        # matching encoder skip.
        self.up0 = deconv(32 * c, 8 * c)
        self.up1 = deconv(16 * c, 4 * c)
        self.up2 = deconv(8 * c, 2 * c)
        self.up3 = deconv(4 * c, c)

        # 3x3 kernel, stride 1, padding 1: spatial size is preserved and the
        # output has 3 channels.
        self.conv = nn.Conv2d(c, 3, 3, 1, 1)

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        """Run the U-Net and return a 3-channel map squashed by a sigmoid.

        Args:
            img0, img1, warped_img0, warped_img1, mask, flow: tensors
                concatenated along dim 1. Together they must provide the 17
                input channels ``down0`` is built for. This is presumably
                3+3+3+3+1+4 (two images, two warps, a mask, a 4-channel
                flow); confirm against the caller.
            c0, c1: indexable sequences holding at least four feature
                tensors each.
                - Level ``i`` (for ``i`` in 0..2) is concatenated onto the
                  input of encoder stage ``down{i+1}``.
                - Level 3 is concatenated onto the input of ``up0``.

        Returns:
            ``torch.sigmoid`` applied to the final 3-channel convolution, so
            every element lies in the [0, 1] range.

        ``torch.cat`` along dim 1 requires every other dimension to match.
        Each context level must therefore share the spatial size of the
        encoder feature it is joined with.
        """
        # Encoder: down0 consumes the raw inputs. Every later stage also
        # receives the matching level of both context pyramids.
        stem_in = torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)
        skips = [self.down0(stem_in)]
        for level, down in enumerate((self.down1, self.down2, self.down3)):
            skips.append(down(torch.cat((skips[-1], c0[level], c1[level]), 1)))

        # Bottleneck: deepest encoder output plus the last context level.
        out = self.up0(torch.cat((skips[-1], c0[3], c1[3]), 1))

        # Decoder: revisit encoder outputs from the second-deepest back to the
        # first, pairing up1/up2/up3 with s2/s1/s0 respectively.
        for up, skip in zip((self.up1, self.up2, self.up3), reversed(skips[:-1])):
            out = up(torch.cat((out, skip), 1))

        return torch.sigmoid(self.conv(out))

Callers 3

__init__ · Method · 0.70
__init__ · Method · 0.70
__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected