hub / github.com/zai-org/CogVideo / forward

Method forward

inference/gradio_composite_demo/rife/IFNet_HDv3.py:100–138
(self, x, scale_list=[4, 2, 1], training=False)
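The signature takes a single tensor `x` rather than two separate frames. Going by the method body, with `training=False` the tensor is split in half along the channel axis into `img0` and `img1`. The following is a minimal sketch of that input layout, assuming two RGB frames of shape `(N, 3, H, W)`; the names `frame0` and `frame1` and the 32x32 size are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical inputs: two RGB frames, batch of 1, 32x32 pixels.
frame0 = torch.rand(1, 3, 32, 32)
frame1 = torch.rand(1, 3, 32, 32)

# Pack both frames along the channel axis into one tensor.
x = torch.cat((frame0, frame1), dim=1)  # shape (1, 6, 32, 32)

# Mirror the split performed at the top of forward() when training is False.
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]

# Zero-initialised flow (4 channels) and mask (1 channel), built the same
# way as in forward(): slice x, detach, multiply by zero.
flow = x[:, :4].detach() * 0
mask = x[:, :1].detach() * 0
```

Deriving `flow` and `mask` from slices of `x` means they inherit its batch size, spatial size, dtype, and device without any explicit arguments.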

Source from the content-addressed store, hash-verified

98           # self.unet = Unet()
99
100      def forward(self, x, scale_list=[4, 2, 1], training=False):
101          if training == False:
102              channel = x.shape[1] // 2
103              img0 = x[:, :channel]
104              img1 = x[:, channel:]
105          flow_list = []
106          merged = []
107          mask_list = []
108          warped_img0 = img0
109          warped_img1 = img1
110          flow = (x[:, :4]).detach() * 0
111          mask = (x[:, :1]).detach() * 0
112          loss_cons = 0
113          block = [self.block0, self.block1, self.block2]
114          for i in range(3):
115              f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
116              f1, m1 = block[i](
117                  torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
118                  torch.cat((flow[:, 2:4], flow[:, :2]), 1),
119                  scale=scale_list[i],
120              )
121              flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
122              mask = mask + (m0 + (-m1)) / 2
123              mask_list.append(mask)
124              flow_list.append(flow)
125              warped_img0 = warp(img0, flow[:, :2])
126              warped_img1 = warp(img1, flow[:, 2:4])
127              merged.append((warped_img0, warped_img1))
128          """
129          c0 = self.contextnet(img0, flow[:, :2])
130          c1 = self.contextnet(img1, flow[:, 2:4])
131          tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
132          res = tmp[:, 1:4] * 2 - 1
133          """
134          for i in range(3):
135              mask_list[i] = torch.sigmoid(mask_list[i])
136              merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
137              # merged[i] = torch.clamp(merged[i] + res, 0, 1)
138          return flow_list, mask_list[2], merged
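Two steps in the listing are easy to misread: the symmetric update inside the first loop, where each block runs once per frame order and the second result has its two flow halves swapped back before averaging, and the final sigmoid blend. The sketch below isolates both steps on small tensors. The helper names `swap_halves`, `symmetric_update`, and `blend` are hypothetical and do not appear in the repository, and the block outputs `f0`, `m0`, `f1`, `m1` are hand-made stand-ins rather than real network outputs.

```python
import torch

def swap_halves(flow):
    """Exchange the two 2-channel halves of a 4-channel flow tensor."""
    return torch.cat((flow[:, 2:4], flow[:, :2]), dim=1)

def symmetric_update(flow, mask, f0, m0, f1, m1):
    """Average the forward estimate with the re-ordered reverse estimate."""
    flow = flow + (f0 + swap_halves(f1)) / 2
    mask = mask + (m0 - m1) / 2
    return flow, mask

def blend(warped0, warped1, mask_logits):
    """Mix two warped frames with a sigmoid-activated single-channel mask."""
    m = torch.sigmoid(mask_logits)
    return warped0 * m + warped1 * (1 - m)

# Stand-in block outputs: the reverse pass is made exactly consistent with
# the forward pass, so averaging should reproduce the forward estimate.
f0 = torch.arange(1.0, 5.0).view(1, 4, 1, 1).expand(1, 4, 2, 2)
f1 = swap_halves(f0)
m0 = torch.ones(1, 1, 2, 2)
m1 = -m0

new_flow, new_mask = symmetric_update(
    torch.zeros(1, 4, 2, 2), torch.zeros(1, 1, 2, 2), f0, m0, f1, m1
)

# The 1-channel mask broadcasts across the 3 colour channels.
out = blend(torch.ones(1, 3, 2, 2), torch.zeros(1, 3, 2, 2), new_mask)
```

When the two passes agree, as in this constructed case, the averaged update equals the forward estimate; when they disagree, the average sits between them.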

Callers

nothing calls this directly

Calls 1

warp · Function · 0.90
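The body of `warp` is not shown on this page. As a rough illustration of what a call such as `warp(img0, flow[:, :2])` typically does, here is a sketch of backward warping built on `torch.nn.functional.grid_sample`. This is an assumption about the general technique, not the repository's implementation: the name `backward_warp`, the border padding, and the `align_corners=True` convention are all choices made for this example.

```python
import torch
import torch.nn.functional as F

def backward_warp(img, flow):
    """Sample img at positions displaced by flow (in pixels, x then y).

    Hypothetical stand-in for the repository's warp(); details may differ.
    """
    n, _, h, w = img.shape
    # Base sampling grid in normalized [-1, 1] coordinates.
    xs = torch.linspace(-1.0, 1.0, w, device=img.device, dtype=img.dtype)
    ys = torch.linspace(-1.0, 1.0, h, device=img.device, dtype=img.dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    base = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).expand(n, h, w, 2)
    # With align_corners=True, one pixel equals 2 / (size - 1) normalized units.
    scale = torch.tensor(
        [2.0 / max(w - 1, 1), 2.0 / max(h - 1, 1)],
        device=img.device, dtype=img.dtype,
    )
    offset = flow.permute(0, 2, 3, 1) * scale
    return F.grid_sample(
        img, base + offset, mode="bilinear",
        padding_mode="border", align_corners=True,
    )

img = torch.rand(1, 3, 5, 5)

# Zero flow leaves the image unchanged.
same = backward_warp(img, torch.zeros(1, 2, 5, 5))

# A flow of +1 pixel in x makes each output pixel read its right-hand neighbour.
shift = torch.zeros(1, 2, 5, 5)
shift[:, 0] = 1.0
shifted = backward_warp(img, shift)
```

In `forward`, the first two flow channels warp `img0` and the last two warp `img1`, so a function of this kind would be called once per frame at every scale level.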

Tested by

no test coverage detected