# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| 17 | |
| 18 | class Model: |
def __init__(self, local_rank=-1, arbitrary=False):
    """Build the flow network, its optimizer and the loss modules.

    Args:
        local_rank: Passed to DDP as both ``device_ids`` and
            ``output_device``. The default ``-1`` leaves the network
            unwrapped.
        arbitrary: When truthy, ``IFNet_m`` is used instead of ``IFNet``.
    """
    # Plain truthiness test instead of ``== True`` (PEP 8).
    self.flownet = IFNet_m() if arbitrary else IFNet()
    self.device()
    # use large weight decay may avoid NaN loss
    self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3)
    self.epe = EPE()
    self.lap = LapLoss()
    self.sobel = SOBEL()
    if local_rank != -1:
        # Wrapped last: the optimizer above was created from the
        # parameters of the still-unwrapped network.
        self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
| 31 | |
def train(self):
    """Switch ``self.flownet`` to training mode."""
    self.flownet.train()
| 34 | |
def eval(self):
    """Switch ``self.flownet`` to evaluation mode."""
    self.flownet.eval()
| 37 | |
def device(self):
    """Move ``self.flownet`` to the module-level ``device``."""
    self.flownet.to(device)
| 40 | |
def load_model(self, path, rank=0):
    """Load ``{path}/flownet.pkl`` into ``self.flownet``.

    A leading ``module.`` on a checkpoint key is stripped before loading.
    Keys without that prefix are kept as they are, so a checkpoint saved
    from an unwrapped network loads too (previously every such key was
    discarded and ``load_state_dict`` failed on the resulting empty dict).

    Args:
        path: Directory that contains ``flownet.pkl``.
        rank: The checkpoint is loaded only when ``rank <= 0``; for any
            other value the call does nothing.
    """
    prefix = "module."

    def convert(param):
        # Strip only a *leading* prefix; ``str.replace`` would also rewrite
        # keys that merely contain "module." somewhere in the middle.
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in param.items()
        }

    if rank <= 0:
        # NOTE: torch.load unpickles the file -- only load checkpoints that
        # come from a trusted source.
        # map_location="cpu" lets a checkpoint saved on a GPU be read on a
        # CPU-only machine; load_state_dict then copies the values into the
        # network's own parameters, wherever those live.
        state = torch.load('{}/flownet.pkl'.format(path), map_location="cpu")
        self.flownet.load_state_dict(convert(state))
| 51 | |
def save_model(self, path, rank=0):
    """Write the state dict of ``self.flownet`` to ``{path}/flownet.pkl``.

    Nothing is written unless ``rank`` equals 0.
    """
    if rank != 0:
        return
    target = '{}/flownet.pkl'.format(path)
    torch.save(self.flownet.state_dict(), target)
| 55 | |
def inference(self, img0, img1, scale=1, scale_list=None, TTA=False, timestep=0.5):
    """Run ``self.flownet`` on a pair of images and return its merged output.

    Args:
        img0: First input tensor.
        img1: Second input tensor; concatenated with ``img0`` along dim 1.
        scale: Divisor applied to every entry of ``scale_list``.
        scale_list: Scales handed to the network; defaults to ``[4, 2, 1]``.
            The caller's sequence is never modified.
        TTA: When truthy, the network is also run on the input flipped along
            dims 2 and 3, and the two results are averaged.
        timestep: Forwarded unchanged to ``self.flownet``.

    Returns:
        ``merged[2]`` from the network output, or the average of it and the
        un-flipped result of the flipped pass when ``TTA`` is truthy.
    """
    if scale_list is None:
        scale_list = [4, 2, 1]
    # Build a fresh list: writing the scaled values back into the caller's
    # list would compound the scaling on every subsequent call.
    scale_list = [s * 1.0 / scale for s in scale_list]
    imgs = torch.cat((img0, img1), 1)
    # The network returns a tuple; only its third element (merged) is used.
    merged = self.flownet(imgs, scale_list, timestep=timestep)[2]
    if not TTA:
        return merged[2]
    flipped = imgs.flip(2).flip(3)
    merged_flipped = self.flownet(flipped, scale_list, timestep=timestep)[2]
    # Undo the flip on the second prediction before averaging.
    return (merged[2] + merged_flipped[2].flip(2).flip(3)) / 2
| 68 | |
| 69 | def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): |
| 70 | for param_group in self.optimG.param_groups: |
| 71 | param_group['lr'] = learning_rate |
| 72 | img0 = imgs[:, :3] |
| 73 | img1 = imgs[:, 3:] |
| 74 | if training: |
| 75 | self.train() |
no outgoing calls
no test coverage detected