(model, local_rank)
| 37 | return rgb_map.clip(0, 1) |
| 38 | |
def _to_uint8_hwc(tensor):
    """Convert an NCHW tensor to a NHWC uint8 NumPy array, scaling values by 255."""
    return (tensor.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')


def train(model, local_rank):
    """Run the training loop for ``args.epoch`` epochs.

    Batches come from ``VimeoDataset('train')`` through a ``DistributedSampler``.
    Every process calls ``model.update`` on each batch; only the process with
    ``local_rank == 0`` creates summary writers, logs scalars/images and prints
    progress.

    Args:
        model: Object exposing ``update(imgs, gt, learning_rate, training=True)``
            returning ``(pred, info)``, and ``save_model(path, rank)``.
        local_rank: Rank of this process; rank 0 does the logging.

    Side effects:
        Sets ``args.step_per_epoch`` to the number of batches per epoch, writes
        summaries to the 'train' / 'validate' writers, calls
        ``model.save_model(log_path, local_rank)`` and ``dist.barrier()``.

    NOTE(review): the available source had lost its indentation. ``step += 1``
    is placed in the batch loop; the ``nr_eval`` / evaluate / save / barrier
    statements are placed in the epoch loop (evaluate every 5th epoch, save and
    synchronize every epoch). Confirm against the original file.
    """
    is_main = local_rank == 0
    if is_main:
        writer = SummaryWriter('train')
        writer_val = SummaryWriter('validate')
    else:
        writer = None
        writer_val = None

    step = 0
    nr_eval = 0

    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = len(train_data)
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)

    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        # Give the sampler the epoch number so it can vary its ordering per epoch.
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()

            data_gpu, timestep = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            timestep = timestep.to(device, non_blocking=True)
            # Channels 0-5 are the model inputs, channels 6-8 the target.
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]

            # Schedule value is scaled by world_size / 4.
            # NOTE(review): presumably the base schedule was tuned for 4 processes - confirm.
            learning_rate = get_learning_rate(step) * args.world_size / 4
            pred, info = model.update(imgs, gt, learning_rate, training=True)  # pass timestep if you are training RIFEm

            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()

            if step % 200 == 1 and is_main:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss/l1', info['loss_l1'], step)
                writer.add_scalar('loss/tea', info['loss_tea'], step)
                writer.add_scalar('loss/distill', info['loss_distill'], step)

            if step % 1000 == 1 and is_main:
                # Use dedicated names so the training tensors (gt, pred, imgs)
                # are not replaced by NumPy arrays.
                gt_vis = _to_uint8_hwc(gt)
                mask_vis = _to_uint8_hwc(torch.cat((info['mask'], info['mask_tea']), 3))
                pred_vis = _to_uint8_hwc(pred)
                merged_vis = _to_uint8_hwc(info['merged_tea'])
                flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
                flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()

                # Log at most 5 samples; a smaller batch must not raise IndexError.
                num_vis = min(5, pred_vis.shape[0])
                # The index is named ``k`` so it does not clobber the batch
                # index ``i`` that the progress print below reports.
                for k in range(num_vis):
                    # Side-by-side strip; [:, :, ::-1] reverses the channel order.
                    strip = np.concatenate((merged_vis[k], pred_vis[k], gt_vis[k]), 1)[:, :, ::-1]
                    writer.add_image(str(k) + '/img', strip, step, dataformats='HWC')
                    writer.add_image(str(k) + '/flow', np.concatenate((flow2rgb(flow0[k]), flow2rgb(flow1[k])), 1), step, dataformats='HWC')
                    writer.add_image(str(k) + '/mask', mask_vis[k], step, dataformats='HWC')
                writer.flush()

            if is_main:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, info['loss_l1']))
            step += 1

        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank, writer_val)
        model.save_model(log_path, local_rank)
        # Wait for all processes before starting the next epoch.
        dist.barrier()
| 96 |
no test coverage detected