MCPcopy
hub / github.com/hzwer/ECCV2022-RIFE / train

Function train

train.py:39–95  ·  view source on GitHub ↗
(model, local_rank)

Source from the content-addressed store, hash-verified

37 return rgb_map.clip(0, 1)
38
def _log_training_images(writer, pred, gt, info, step, max_samples=5):
    """Write image, flow and mask previews for the first samples of a batch.

    Args:
        writer: TensorBoard writer exposing ``add_image`` and ``flush``.
        pred: Batch of predictions; permuted with (0, 2, 3, 1) before logging.
        gt: Batch of ground-truth frames, same layout as ``pred``.
        info: Dict returned by ``model.update``; the keys ``mask``,
            ``mask_tea``, ``merged_tea``, ``flow`` and ``flow_tea`` are read.
        step: Global step used as the TensorBoard x-axis value.
        max_samples: Upper bound on how many samples of the batch are logged.
    """
    def to_uint8_hwc(tensor):
        # Channel axis moves last; values are scaled by 255 and cast to uint8
        # (assumes the tensors hold values in [0, 1] -- TODO confirm).
        return (tensor.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')

    gt_np = to_uint8_hwc(gt)
    # The two masks are concatenated along dim 3 so they appear side by side.
    mask_np = to_uint8_hwc(torch.cat((info['mask'], info['mask_tea']), 3))
    pred_np = to_uint8_hwc(pred)
    merged_np = to_uint8_hwc(info['merged_tea'])
    flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
    flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()

    # Never index past the end of the batch: the original hard-coded 5 samples
    # and raised IndexError whenever the batch was smaller than that.
    for k in range(min(max_samples, len(pred_np))):
        # [:, :, ::-1] reverses the channel order
        # (presumably BGR -> RGB -- verify against the dataset loader).
        preview = np.concatenate((merged_np[k], pred_np[k], gt_np[k]), 1)[:, :, ::-1]
        writer.add_image(str(k) + '/img', preview, step, dataformats='HWC')
        writer.add_image(str(k) + '/flow', np.concatenate((flow2rgb(flow0[k]), flow2rgb(flow1[k])), 1), step, dataformats='HWC')
        writer.add_image(str(k) + '/mask', mask_np[k], step, dataformats='HWC')
    writer.flush()


def train(model, local_rank):
    """Run the distributed training loop.

    Reads the module-level ``args`` (``batch_size``, ``epoch``,
    ``world_size``), ``device`` and ``log_path``, and stores the number of
    batches per epoch in ``args.step_per_epoch``.

    Args:
        model: Object exposing ``update(imgs, gt, learning_rate,
            training=True)``, which returns ``(pred, info)``, and
            ``save_model(path, rank)``.
        local_rank: Rank of this process. Only rank 0 creates TensorBoard
            writers, logs scalars/images and prints progress.
    """
    if local_rank == 0:
        writer = SummaryWriter('train')
        writer_val = SummaryWriter('validate')
    else:
        writer = None
        writer_val = None
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = len(train_data)
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        # Re-seed the sampler so each epoch uses a different shuffling order.
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, timestep = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            timestep = timestep.to(device, non_blocking=True)
            # Channels 0-5 are the model inputs, channels 6-8 the target
            # (presumably two RGB frames plus the middle frame -- TODO confirm).
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            learning_rate = get_learning_rate(step) * args.world_size / 4
            pred, info = model.update(imgs, gt, learning_rate, training=True)  # pass timestep if you are training RIFEm
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if step % 200 == 1 and local_rank == 0:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss/l1', info['loss_l1'], step)
                writer.add_scalar('loss/tea', info['loss_tea'], step)
                writer.add_scalar('loss/distill', info['loss_distill'], step)
            if step % 1000 == 1 and local_rank == 0:
                # Image logging lives in a helper with its own loop variable so
                # the batch index `i` printed below is not overwritten.
                _log_training_images(writer, pred, gt, info, step)
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, info['loss_l1']))
            step += 1
        nr_eval += 1
        # Validate once every five epochs.
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank, writer_val)
        model.save_model(log_path, local_rank)
        # Keep all ranks in step before the next epoch starts.
        dist.barrier()
96

Callers 1

train.py — File · 0.85

Calls 7

VimeoDataset — Class · 0.85
get_learning_rate — Function · 0.85
flow2rgb — Function · 0.85
evaluate — Function · 0.85
__len__ — Method · 0.80
update — Method · 0.45
save_model — Method · 0.45

Tested by

no test coverage detected