MCPcopy
hub / github.com/zai-org/CogVideo / forward

Method forward

inference/gradio_composite_demo/rife/IFNet.py:70–123  ·  view source on GitHub ↗
(self, x, scale=[4, 2, 1], timestep=0.5)

Source from the content-addressed store, hash-verified

68 self.unet = Unet()
69
def forward(self, x, scale=(4, 2, 1), timestep=0.5):
    """Estimate bidirectional flow between two frames and blend an interpolated frame.

    Three student stages (``block0``, ``block1``, ``block2``) run in sequence.
    The first stage predicts a flow and a pre-sigmoid fusion mask from the raw
    frame pair. Each later stage predicts residual updates that are added to
    the running flow and mask. When ``x`` also carries a ground-truth frame,
    a teacher stage (``block_tea``) is run and a distillation loss is
    accumulated. The final blend is refined with a residual produced by
    ``contextnet`` + ``unet``.

    Args:
        x: Tensor with img0 in channels 0-2 and img1 in channels 3-5 of dim 1,
            optionally followed by a 3-channel ground-truth frame in channels
            6-8. Presumably (N, C, H, W) image tensors -- TODO confirm.
        scale: Per-stage factor forwarded as ``scale=`` to each student block.
            Needs at least three entries. The default is an immutable tuple so
            it cannot be mutated across calls; lists are accepted too.
        timestep: Not used by this implementation. It is kept so that existing
            callers which pass it keep working.

    Returns:
        A 6-tuple ``(flow_list, mask, merged, flow_teacher, merged_teacher,
        loss_distill)``:
            flow_list: running flow after each student stage. Channels 0-1
                warp img0 and channels 2-3 warp img1.
            mask: sigmoid of the last stage's mask.
            merged: per-stage blends ``w0 * m + w1 * (1 - m)``. The last entry
                additionally has the UNet residual added and is clamped to
                [0, 1].
            flow_teacher, merged_teacher: teacher flow and blend, or ``None``
                when no ground-truth frame is present.
            loss_distill: distillation loss summed over stages, or the int
                ``0`` when no ground-truth frame is present.
    """
    img0 = x[:, :3]
    img1 = x[:, 3:6]
    # Channels 6+ hold the ground-truth middle frame during training. At
    # inference x has only 6 channels, so gt is an empty slice
    # (gt.shape[1] == 0), not None.
    gt = x[:, 6:]
    has_gt = gt.shape[1] == 3
    stu = [self.block0, self.block1, self.block2]
    flow_list = []
    mask_list = []
    warped_pairs = []
    warped_img0 = img0
    warped_img1 = img1
    flow = None
    mask = None
    loss_distill = 0
    for i, block in enumerate(stu):
        if flow is None:
            # First stage: predict flow and mask from the raw frame pair only.
            flow, mask = block(torch.cat((img0, img1), 1), None, scale=scale[i])
        else:
            # Later stages also see the current warps and mask and predict
            # residual updates.
            flow_d, mask_d = block(
                torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
            )
            # Out-of-place adds on purpose: flow_list already holds the previous
            # `flow` tensor, and an in-place update would overwrite that entry.
            flow = flow + flow_d
            mask = mask + mask_d
        mask_list.append(torch.sigmoid(mask))
        flow_list.append(flow)
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        # Blending is done after the teacher pass below; keep the warped pair.
        warped_pairs.append((warped_img0, warped_img1))

    if has_gt:
        # The teacher refines the final student flow/mask with access to gt.
        flow_d, mask_d = self.block_tea(
            torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
        )
        flow_teacher = flow + flow_d
        warped_img0_teacher = warp(img0, flow_teacher[:, :2])
        warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
        mask_teacher = torch.sigmoid(mask + mask_d)
        merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
    else:
        flow_teacher = None
        merged_teacher = None

    merged = []
    for (w0, w1), stage_mask, stage_flow in zip(warped_pairs, mask_list, flow_list):
        blend = w0 * stage_mask + w1 * (1 - stage_mask)
        merged.append(blend)
        if has_gt:
            # Only distill at pixels where the student's photometric error
            # exceeds the teacher's by more than 0.01.
            loss_mask = (
                ((blend - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
                .float()
                .detach()
            )
            # Per-pixel RMS over flow channels between the (detached) teacher
            # flow and this stage's flow, restricted to loss_mask.
            loss_distill += (((flow_teacher.detach() - stage_flow) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()

    c0 = self.contextnet(img0, flow[:, :2])
    c1 = self.contextnet(img1, flow[:, 2:4])
    tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
    # NOTE(review): presumably the UNet emits values in [0, 1]; this maps the
    # first 3 channels to a residual in [-1, 1] -- TODO confirm.
    res = tmp[:, :3] * 2 - 1
    merged[-1] = torch.clamp(merged[-1] + res, 0, 1)
    return flow_list, mask_list[-1], merged, flow_teacher, merged_teacher, loss_distill

Callers

nothing calls this directly

Calls 1

warp (function) · 0.85

Tested by

no test coverage detected