MCPcopy
hub / github.com/hzwer/ECCV2022-RIFE / Model

Class Model

model/RIFE.py:18–97  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

# Target device for the model: the current CUDA device when PyTorch reports CUDA
# as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

18class Model:
def __init__(self, local_rank=-1, arbitrary=False):
    """Create the flow network, its AdamW optimizer and the loss modules.

    Args:
        local_rank: distributed rank. Any value other than ``-1`` wraps the
            network in ``DDP`` with ``device_ids=[local_rank]`` and
            ``output_device=local_rank``.
        arbitrary: selects ``IFNet_m`` when it compares equal to ``True``,
            otherwise ``IFNet``.
    """
    # ``== True`` (rather than plain truthiness) is kept on purpose so the
    # network selection is unchanged for every input value.
    network_cls = IFNet_m if arbitrary == True else IFNet
    self.flownet = network_cls()
    # Move the network onto the module-level ``device`` before the optimizer
    # is built over its parameters.
    self.device()
    # The initial lr is only a placeholder: ``update`` overwrites it on every
    # call. The comparatively large weight decay may help avoid NaN losses.
    self.optimG = AdamW(
        self.flownet.parameters(),
        lr=1e-6,
        weight_decay=1e-3,
    )
    self.epe, self.lap, self.sobel = EPE(), LapLoss(), SOBEL()
    if local_rank != -1:
        self.flownet = DDP(
            self.flownet,
            device_ids=[local_rank],
            output_device=local_rank,
        )

def train(self):
    """Put the flow network into training mode (``train(True)``)."""
    self.flownet.train(True)

def eval(self):
    """Put the flow network into evaluation mode (``train(False)``)."""
    self.flownet.train(False)

def device(self):
    """Move the flow network onto the module-level ``device``."""
    self.flownet.to(device=device)

def load_model(self, path, rank=0):
    """Load ``{path}/flownet.pkl`` into ``self.flownet``.

    A checkpoint written from a ``DistributedDataParallel``-wrapped network
    has a leading ``module.`` on every key. That prefix is stripped so the
    weights fit the bare network. Keys without the prefix are kept unchanged,
    so checkpoints saved from an unwrapped network load as well; previously
    those keys were all dropped and the strict load failed.

    Args:
        path: directory containing ``flownet.pkl``.
        rank: process rank; only ranks ``<= 0`` load (``-1`` is the
            non-distributed value used as ``__init__``'s default).
    """
    prefix = "module."

    def convert(param):
        # Strip only a *leading* "module."; names that merely contain it
        # elsewhere are left intact.
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in param.items()
        }

    if rank <= 0:
        # Deserialize onto the CPU so a GPU-saved checkpoint also loads on a
        # CPU-only host; load_state_dict then copies the tensors onto whatever
        # device the network already lives on.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load('{}/flownet.pkl'.format(path), map_location="cpu")
        self.flownet.load_state_dict(convert(state))

def save_model(self, path, rank=0):
    """Write the flow network's ``state_dict`` to ``{path}/flownet.pkl``.

    Args:
        path: directory to write into (must already exist).
        rank: process rank; only ranks ``<= 0`` write.
    """
    # Same ``rank <= 0`` gate as load_model. With the previous ``rank == 0``
    # check, a non-distributed caller passing -1 silently saved nothing.
    if rank <= 0:
        torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

def inference(self, img0, img1, scale=1, scale_list=None, TTA=False, timestep=0.5):
    """Run the flow network on a frame pair and return its ``merged[2]`` output.

    ``merged[2]`` is presumably the interpolated frame at ``timestep`` -- TODO
    confirm against the network definition.

    Args:
        img0, img1: input frames, concatenated along dim 1 before being fed to
            the network (assumed ``(N, C, H, W)`` -- TODO confirm).
        scale: divisor applied to every entry of ``scale_list``.
        scale_list: per-level scales passed to the network; defaults to
            ``[4, 2, 1]``. A caller-supplied list is never modified.
        TTA: when true, also run on the inputs flipped along dims 2 and 3 and
            average the two predictions (test-time augmentation).
        timestep: forwarded unchanged to the network.

    Returns:
        ``merged[2]``, averaged with the un-flipped result of the flipped run
        when ``TTA`` is set.
    """
    if scale_list is None:
        scale_list = [4, 2, 1]
    # Build a new list rather than rescaling in place. The in-place version
    # mutated the caller's list, so repeated calls compounded the division.
    scale_list = [s * 1.0 / scale for s in scale_list]
    imgs = torch.cat((img0, img1), 1)
    _, _, merged, _, _, _ = self.flownet(imgs, scale_list, timestep=timestep)
    if not TTA:
        return merged[2]
    _, _, merged_flipped, _, _, _ = self.flownet(
        imgs.flip(2).flip(3), scale_list, timestep=timestep
    )
    return (merged[2] + merged_flipped[2].flip(2).flip(3)) / 2

69 def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
70 for param_group in self.optimG.param_groups:
71 param_group['lr'] = learning_rate
72 img0 = imgs[:, :3]
73 img1 = imgs[:, 3:]
74 if training:
75 self.train()

Callers 10

train.py — File · 0.90
inference_img.py — File · 0.90
inference_video.py — File · 0.90
HD_multi_4X.py — File · 0.90
testtime.py — File · 0.90
HD.py — File · 0.90
ATD12K.py — File · 0.90
UCF101.py — File · 0.90
Vimeo90K.py — File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected