(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None)
| 67 | return (merged[2] + merged2[2].flip(2).flip(3)) / 2 |
| 68 | |
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None,
           distill_weight=0.01):
    """Run one forward pass of ``self.flownet`` and, when training, one optimizer step.

    Args:
        imgs: Input tensor concatenated with ``gt`` along dim 1 before being fed
            to ``self.flownet``. Presumably (N, 6, H, W) holding two RGB frames
            -- TODO confirm against callers.
        gt: Ground-truth target tensor. It is appended to ``imgs`` on dim 1 and
            is compared against the predictions via ``self.lap``.
        learning_rate: Value written into every param group of ``self.optimG``.
            It is only applied when ``training`` is True.
        mul: Unused; kept for interface compatibility.
        training: If True, call ``self.train()`` and take an optimizer step on
            the combined loss. If False, call ``self.eval()`` and run the forward
            pass without recording an autograd graph.
        flow_gt: Unused; kept for interface compatibility.
        distill_weight: Multiplier applied to ``loss_distill`` in the total
            training loss. Defaults to 0.01; for RIFEm, 0.005 or 0.002 is
            recommended.

    Returns:
        A tuple ``(merged[2], info)``. ``merged[2]`` is the third entry of the
        network's merged outputs, presumably the scale-1 prediction given
        ``scale=[4, 2, 1]``. ``info`` is a dict with these keys:

        - ``'merged_tea'``: the teacher prediction.
        - ``'mask'`` and ``'mask_tea'``: both hold the same ``mask`` tensor.
        - ``'flow'``: the first two channels of ``flow[2]``.
        - ``'flow_tea'``: the teacher flow when training, but the full
          ``flow[2]`` in evaluation.
        - ``'loss_l1'``, ``'loss_tea'`` and ``'loss_distill'``: the losses.
    """
    import contextlib

    if training:
        # Only touch the optimizer when a step is about to be taken. Otherwise
        # an evaluation call with the default learning_rate=0 would silently
        # zero the training learning rate.
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        self.train()
    else:
        self.eval()

    # Evaluation needs no autograd graph. In training, keep whatever grad mode
    # the caller is in (nullcontext) rather than forcing it on.
    grad_ctx = contextlib.nullcontext() if training else torch.no_grad()
    with grad_ctx:
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()

    if training:
        self.optimG.zero_grad()
        loss_G = loss_l1 + loss_tea + loss_distill * distill_weight
        loss_G.backward()
        self.optimG.step()
    else:
        # Outside training, report the network's own finest flow in place of
        # the teacher flow.
        flow_teacher = flow[2]

    return merged[2], {
        'merged_tea': merged_teacher,
        'mask': mask,
        'mask_tea': mask,
        'flow': flow[2][:, :2],
        'flow_tea': flow_teacher,
        'loss_l1': loss_l1,
        'loss_tea': loss_tea,
        'loss_distill': loss_distill,
    }
no test coverage detected