(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None)
| 61 | return (merged[2] + merged2[2].flip(2).flip(3)) / 2 |
| 62 | |
| 63 | def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): |
| 64 | for param_group in self.optimG.param_groups: |
| 65 | param_group["lr"] = learning_rate |
| 66 | img0 = imgs[:, :3] |
| 67 | img1 = imgs[:, 3:] |
| 68 | if training: |
| 69 | self.train() |
| 70 | else: |
| 71 | self.eval() |
| 72 | flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet( |
| 73 | torch.cat((imgs, gt), 1), scale=[4, 2, 1] |
| 74 | ) |
| 75 | loss_l1 = (self.lap(merged[2], gt)).mean() |
| 76 | loss_tea = (self.lap(merged_teacher, gt)).mean() |
| 77 | if training: |
| 78 | self.optimG.zero_grad() |
| 79 | loss_G = ( |
| 80 | loss_l1 + loss_tea + loss_distill * 0.01 |
| 81 | ) # when training RIFEm, the weight of loss_distill should be 0.005 or 0.002 |
| 82 | loss_G.backward() |
| 83 | self.optimG.step() |
| 84 | else: |
| 85 | flow_teacher = flow[2] |
| 86 | return merged[2], { |
| 87 | "merged_tea": merged_teacher, |
| 88 | "mask": mask, |
| 89 | "mask_tea": mask, |
| 90 | "flow": flow[2][:, :2], |
| 91 | "flow_tea": flow_teacher, |
| 92 | "loss_l1": loss_l1, |
| 93 | "loss_tea": loss_tea, |
| 94 | "loss_distill": loss_distill, |
| 95 | } |
no test coverage detected