Repository: github.com/zai-org/CogVideo

Method update

inference/gradio_composite_demo/rife/RIFE.py:63–95
Signature: update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None)
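The method slices `imgs` as `imgs[:, :3]` and `imgs[:, 3:]`, and it passes `torch.cat((imgs, gt), 1)` to `self.flownet`. This suggests that `imgs` holds two RGB frames stacked along the channel axis and that `gt` holds the target RGB frame. The following sketch assumes NCHW layout and tracks only tensor shapes, in plain Python, so the channel bookkeeping can be checked without PyTorch. `slice_channels` and `cat_shape` are hypothetical helpers, not part of RIFE or PyTorch, and the batch size and resolution are made up.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def slice_channels(shape: Shape, start: int = 0, stop: Optional[int] = None) -> Shape:
    """Shape of x[:, start:stop] for a tensor whose dim 1 is the channel axis."""
    n, c, *rest = shape
    kept = len(range(c)[start:stop])  # same semantics as basic Python slicing
    return (n, kept, *rest)


def cat_shape(shapes: Sequence[Shape], dim: int) -> Shape:
    """Shape of torch.cat(tensors, dim): sizes must match on every other dim."""
    first = shapes[0]
    for other in shapes[1:]:
        mismatched = len(other) != len(first) or any(
            a != b for i, (a, b) in enumerate(zip(first, other)) if i != dim
        )
        if mismatched:
            raise ValueError(f"cannot concatenate {first} and {other} along dim {dim}")
    total = sum(s[dim] for s in shapes)
    return first[:dim] + (total,) + first[dim + 1:]


if __name__ == "__main__":
    n, h, w = 4, 256, 448  # hypothetical batch size and resolution
    imgs = (n, 6, h, w)    # assumed: two RGB frames stacked on the channel axis
    gt = (n, 3, h, w)      # assumed: one RGB target frame
    print("img0:", slice_channels(imgs, 0, 3))  # like imgs[:, :3]
    print("img1:", slice_channels(imgs, 3))     # like imgs[:, 3:]
    print("flownet input:", cat_shape([imgs, gt], dim=1))  # like torch.cat((imgs, gt), 1)
```

Under these assumed shapes, the flow network receives a 9-channel input. `img0` and `img1` are computed in `update` but not used later in the method. The `mul` and `flow_gt` parameters are also unused in this body.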


```python
import torch  # needed for torch.cat below


def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
    # Apply the caller-supplied learning rate to every optimizer param group.
    for param_group in self.optimG.param_groups:
        param_group["lr"] = learning_rate
    # Split the stacked input into its two frames (not used further in this method).
    img0 = imgs[:, :3]
    img1 = imgs[:, 3:]
    if training:
        self.train()
    else:
        self.eval()
    # The flow network receives the input frames plus the ground-truth frame;
    # index 2 of its per-stage outputs is the last (scale 1) stage.
    flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
        torch.cat((imgs, gt), 1), scale=[4, 2, 1]
    )
    # Reconstruction losses for the final-stage output and the teacher output.
    loss_l1 = (self.lap(merged[2], gt)).mean()
    loss_tea = (self.lap(merged_teacher, gt)).mean()
    if training:
        self.optimG.zero_grad()
        loss_G = (
            loss_l1 + loss_tea + loss_distill * 0.01
        )  # when training RIFEm, the weight of loss_distill should be 0.005 or 0.002
        loss_G.backward()
        self.optimG.step()
    else:
        flow_teacher = flow[2]
    return merged[2], {
        "merged_tea": merged_teacher,
        "mask": mask,
        "mask_tea": mask,
        "flow": flow[2][:, :2],
        "flow_tea": flow_teacher,
        "loss_l1": loss_l1,
        "loss_tea": loss_tea,
        "loss_distill": loss_distill,
    }
```
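During training, `update` minimizes `loss_l1 + loss_tea + loss_distill * 0.01`. The inline comment recommends lowering the distillation weight to 0.005 or 0.002 for RIFEm. The following is a minimal sketch of that weighting, with the weight lifted into a parameter. The name `generator_loss` is hypothetical. The function uses only `+` and `*`, so it works on plain floats as shown and would also accept scalar tensors. When `training=False`, the same losses are still computed and returned, but there is no backward pass or optimizer step, and `flow_tea` is set to `flow[2]`.

```python
# Weight on the distillation term used by update(); per the source comment,
# RIFEm training should use 0.005 or 0.002 instead.
DEFAULT_DISTILL_WEIGHT = 0.01


def generator_loss(loss_l1, loss_tea, loss_distill, distill_weight=DEFAULT_DISTILL_WEIGHT):
    """Sum the two reconstruction losses plus a weighted distillation term.

    Mirrors the expression for loss_G in update(); works for floats or scalar tensors.
    """
    return loss_l1 + loss_tea + loss_distill * distill_weight


if __name__ == "__main__":
    # Hypothetical loss values, chosen only to show how the weight scales the distillation term.
    for w in (DEFAULT_DISTILL_WEIGHT, 0.005, 0.002):
        total = generator_loss(0.05, 0.04, 2.0, distill_weight=w)
        print(f"distill_weight={w}: loss_G={total:.4f}")
```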

Callers (15)

generate (function) · 0.45
generate (function) · 0.45
_preprocess_data (method) · 0.45
encode_video (function) · 0.45
main (function) · 0.45
main (function) · 0.45
log_video (method) · 0.45
url_opener (function) · 0.45
__init__ (method) · 0.45
forward (method) · 0.45
forward (method) · 0.45
download (function) · 0.45

Calls (3)

train (method) · 0.95
eval (method) · 0.95
backward (method) · 0.45

Tested by

no test coverage detected