| 95 | dist.barrier() |
| 96 | |
def _batch_psnr(output, target):
    """Return the per-sample PSNR (in dB) of ``output`` against ``target``.

    Both tensors must share a shape whose first dim is the batch; the MSE is
    averaged over all remaining dims, and the peak value is taken as 1.0
    (inputs are expected to be scaled to [0, 1]). All samples are reduced on
    the tensors' device and transferred with a single ``tolist()`` call,
    instead of one device sync per sample. A sample with zero MSE yields
    ``inf`` rather than crashing ``math.log10`` with a domain error.
    """
    mse = ((output - target) ** 2).flatten(1).mean(dim=1)
    return [-10 * math.log10(m) if m > 0 else float('inf') for m in mse.tolist()]


def _log_samples(writer_val, nr_eval, merged_img, pred, gt, flow, max_samples=10):
    """Write image strips and flow visualisations for the first samples of a batch.

    For each of the first ``max_samples`` samples (fewer if the batch is
    smaller), logs ``<j>/img`` as the horizontal concatenation
    teacher | student | ground truth, and ``<j>/flow`` rendered via ``flow2rgb``.
    Inputs are (N, C, H, W) tensors, as implied by ``permute(0, 2, 3, 1)``.
    """
    def to_uint8(t):
        # (N, C, H, W) in [0, 1] -> (N, H, W, C) uint8. Clip before the cast so
        # values slightly outside [0, 1] saturate instead of wrapping around.
        return np.clip(t.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype('uint8')

    merged_img, pred, gt = to_uint8(merged_img), to_uint8(pred), to_uint8(gt)
    flow = flow.permute(0, 2, 3, 1).cpu().numpy()
    # Guard against batches smaller than max_samples (previously an IndexError).
    for j in range(min(max_samples, gt.shape[0])):
        # Reverse the channel axis (presumably BGR -> RGB for display; TODO confirm).
        strip = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
        writer_val.add_image(str(j) + '/img', strip.copy(), nr_eval, dataformats='HWC')
        writer_val.add_image(str(j) + '/flow', flow2rgb(flow[j][:, :, ::-1]), nr_eval, dataformats='HWC')


def evaluate(model, val_data, nr_eval, local_rank, writer_val):
    """Evaluate ``model`` on ``val_data`` and log PSNR metrics to ``writer_val``.

    For every batch, the student prediction ``pred`` and the teacher output
    ``info['merged_tea']`` from ``model.update(imgs, gt, training=False)`` are
    compared with the ground-truth frame, and per-sample PSNRs are collected.
    On rank 0, samples from the first batch are logged as images, and the mean
    PSNRs are written as the ``psnr`` and ``psnr_teacher`` scalars at step
    ``nr_eval``. Other ranks run the forward passes but log nothing.

    Args:
        model: object exposing ``update(imgs, gt, training=False)`` that
            returns ``(pred, info)``, where ``info`` has ``'merged_tea'`` and
            ``'flow'`` tensors.
        val_data: iterable of ``(data, timestep)`` pairs. Channels 0-5 of
            ``data`` are the model input and channels 6-8 are the ground
            truth. Values are divided by 255, so a 0-255 range is assumed
            (TODO confirm).
        nr_eval: global step used for every TensorBoard entry.
        local_rank: distributed rank of this process; only rank 0 logs.
        writer_val: writer providing ``add_image`` and ``add_scalar``.
    """
    psnr_list = []
    psnr_list_teacher = []
    for i, (data_gpu, _timestep) in enumerate(val_data):
        data_gpu = data_gpu.to(device, non_blocking=True) / 255.
        imgs = data_gpu[:, :6]
        gt = data_gpu[:, 6:9]
        with torch.no_grad():
            pred, info = model.update(imgs, gt, training=False)
        merged_img = info['merged_tea']
        psnr_list.extend(_batch_psnr(pred, gt))
        psnr_list_teacher.extend(_batch_psnr(merged_img, gt))
        # Only the first batch on rank 0 is visualised, so the host-side
        # image conversions are done only there.
        if i == 0 and local_rank == 0:
            _log_samples(writer_val, nr_eval, merged_img, pred, gt, info['flow'])

    if local_rank != 0:
        return
    writer_val.add_scalar('psnr', np.array(psnr_list).mean(), nr_eval)
    writer_val.add_scalar('psnr_teacher', np.array(psnr_list_teacher).mean(), nr_eval)
| 137 | |
| 138 | if __name__ == "__main__": |
| 139 | parser = argparse.ArgumentParser() |