github.com/hzwer/ECCV2022-RIFE

Method update

model/RIFE.py:69–97
Signature: update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None). The mul and flow_gt arguments are accepted but not used in the method body.
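update() does not compute a learning rate itself. On every call it writes the caller's learning_rate into each parameter group of self.optimG, so any schedule has to live in the caller. Because the default is 0, training calls are expected to pass it explicitly. The sketch below shows that pattern with plain dicts standing in for the optimizer's param groups (each torch param group is a dict with an 'lr' key). set_learning_rate and warmup_cosine_lr are hypothetical helpers, and the warmup-then-cosine shape and its constants (3e-4, 3e-6, 2000 warmup steps) are illustrative assumptions, not values taken from this method.

```python
import math


def set_learning_rate(param_groups, learning_rate):
    """Overwrite 'lr' in every param group, mirroring the loop at the top of update()."""
    for group in param_groups:
        group['lr'] = learning_rate


def warmup_cosine_lr(step, total_steps, base_lr=3e-4, min_lr=3e-6, warmup_steps=2000):
    """Hypothetical schedule: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Usage: plain dicts stand in for self.optimG.param_groups.
param_groups = [{'lr': 0.0}, {'lr': 0.0}]
for step in (0, 1000, 2000, 100_000):
    set_learning_rate(param_groups, warmup_cosine_lr(step, total_steps=100_000))
    print(step, param_groups[0]['lr'])
```

Keeping the schedule outside update() means the same method serves both training (with a real rate) and evaluation (where the rate is irrelevant).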

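    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        # Apply the caller-supplied learning rate to every optimizer param group.
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        # imgs stacks the two input frames on the channel axis (3 + 3 channels).
        # img0 and img1 are split out here but not used further in this method.
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        # Put the network in training or evaluation mode.
        if training:
            self.train()
        else:
            self.eval()
        # gt is concatenated onto the inputs so the teacher branch can use it;
        # scale=[4, 2, 1] sets the scale of each of the three refinement stages.
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        # Reconstruction losses (self.lap) for the student's final output and the teacher's output.
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            self.optimG.zero_grad()
            # when training RIFEm, the weight of loss_distill should be 0.005 or 0.002
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01
            loss_G.backward()
            self.optimG.step()
        else:
            # No training step: report the student's final flow as the teacher flow.
            flow_teacher = flow[2]
        # Return the final interpolated frame plus intermediates and losses for logging.
        return merged[2], {
            'merged_tea': merged_teacher,
            'mask': mask,
            'mask_tea': mask,
            'flow': flow[2][:, :2],  # first two channels of the final flow only
            'flow_tea': flow_teacher,
            'loss_l1': loss_l1,
            'loss_tea': loss_tea,
            'loss_distill': loss_distill,
            }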
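Stripped of tensors, the end of update() is a fixed recipe. In training mode it zeroes the gradients, sums the two reconstruction losses with the distillation loss weighted by 0.01, back-propagates, and steps the optimizer; in evaluation mode it skips all optimizer work and only reports the losses. The sketch below mirrors that branching with plain floats so it runs without PyTorch. generator_loss and training_step are hypothetical names, and the three callables are stand-ins for self.optimG.zero_grad, loss_G.backward, and self.optimG.step.

```python
from typing import Callable, Optional


def generator_loss(loss_l1: float, loss_tea: float, loss_distill: float,
                   distill_weight: float = 0.01) -> float:
    """Student and teacher reconstruction losses plus the weighted distillation loss.

    0.01 is the weight update() uses; its source comment suggests 0.005 or
    0.002 when training RIFEm.
    """
    return loss_l1 + loss_tea + distill_weight * loss_distill


def training_step(loss_l1: float, loss_tea: float, loss_distill: float,
                  training: bool,
                  zero_grad: Callable[[], None],
                  backward: Callable[[float], None],
                  step: Callable[[], None],
                  distill_weight: float = 0.01) -> Optional[float]:
    """Mirror update()'s branching: only touch the optimizer when training."""
    if not training:
        # Evaluation: no optimizer calls; the caller just reports the losses.
        return None
    zero_grad()
    loss_g = generator_loss(loss_l1, loss_tea, loss_distill, distill_weight)
    backward(loss_g)
    step()
    return loss_g


# Usage: record the order of the stand-in optimizer calls.
calls = []
total = training_step(0.5, 0.25, 4.0, training=True,
                      zero_grad=lambda: calls.append('zero_grad'),
                      backward=lambda loss: calls.append('backward'),
                      step=lambda: calls.append('step'))
print(calls)  # ['zero_grad', 'backward', 'step']
```

The loss is passed into backward here only so the stub can see it; in update() the call is loss_G.backward() on the tensor itself.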

Callers 3

train (function) · 0.45
evaluate (function) · 0.45
inference_video.py (file) · 0.45

Calls 2

train (method) · 0.95
eval (method) · 0.95

Tested by

no test coverage detected