MCPcopy Index your code
hub / github.com/hzwer/ECCV2022-RIFE / evaluate

Function evaluate

train.py:97–136  ·  view source on GitHub ↗
(model, val_data, nr_eval, local_rank, writer_val)

Source from the content-addressed store, hash-verified

95 dist.barrier()
96
def _batch_psnr(a, b, eps=1e-10):
    """Return a list with the per-sample PSNR (dB, peak value 1.0) of two equally shaped batches.

    The squared error is averaged over every non-batch dimension. This is
    the same quantity as a per-sample ``torch.mean`` of the squared difference.
    """
    mse_per_sample = (a - b).pow(2).flatten(1).mean(dim=1).cpu().tolist()
    # Floor the MSE. Without it, a pixel-perfect sample would call
    # math.log10(0) and raise ValueError. eps=1e-10 caps PSNR at 100 dB.
    return [-10 * math.log10(max(mse, eps)) for mse in mse_per_sample]


def _to_uint8_hwc(t):
    """Convert an (N, C, H, W) float tensor in [0, 1] to an (N, H, W, C) uint8 numpy array."""
    # Clip before the cast so values slightly outside [0, 1] saturate
    # instead of wrapping around in the uint8 conversion.
    return (t.permute(0, 2, 3, 1).cpu().numpy().clip(0, 1) * 255).astype('uint8')


def _log_first_batch(writer_val, nr_eval, pred, merged_tea, gt, flow, nr_vis):
    """Log image strips (teacher | student | ground truth) and flow images for up to nr_vis samples."""
    merged_np = _to_uint8_hwc(merged_tea)
    pred_np = _to_uint8_hwc(pred)
    gt_np = _to_uint8_hwc(gt)
    flow_np = flow.permute(0, 2, 3, 1).cpu().numpy()
    # Cap at the batch size: a batch smaller than nr_vis must not IndexError.
    for j in range(min(nr_vis, gt_np.shape[0])):
        # Concatenate along width (axis 1 of HWC), then reverse the channel
        # order. Presumably the frames are stored BGR; confirm against the dataset.
        strip = np.concatenate((merged_np[j], pred_np[j], gt_np[j]), 1)[:, :, ::-1]
        # .copy() makes the negative-stride view contiguous for the writer.
        writer_val.add_image(str(j) + '/img', strip.copy(), nr_eval, dataformats='HWC')
        writer_val.add_image(str(j) + '/flow', flow2rgb(flow_np[j][:, :, ::-1]), nr_eval, dataformats='HWC')


def evaluate(model, val_data, nr_eval, local_rank, writer_val, nr_vis=10):
    """Run one validation pass and log mean PSNR plus sample visualisations.

    For every batch, ``model.update(imgs, gt, training=False)`` is called
    without gradient tracking. Per-sample PSNR is accumulated for the model
    prediction and for ``info['merged_tea']``. On rank 0 only:
      - the first batch is visualised as image strips and flow images;
      - the means are written as the ``psnr`` and ``psnr_teacher`` scalars
        at step ``nr_eval``. They are skipped if ``val_data`` yielded
        nothing, to avoid logging NaN.

    Args:
        model: object whose ``update(imgs, gt, training=False)`` returns
            ``(pred, info)``. ``info`` must contain ``'merged_tea'`` and
            ``'flow'``.
        val_data: iterable of ``(frames, timestep)`` pairs. Channels 0:6 of
            ``frames`` are the model input and channels 6:9 are the ground
            truth. ``frames`` is divided by 255, so it presumably holds
            0-255 pixel values (confirm against the dataset). ``timestep``
            is unused here.
        nr_eval: global step used for every TensorBoard entry.
        local_rank: distributed rank. Only rank 0 writes to ``writer_val``.
        writer_val: summary writer providing ``add_image`` / ``add_scalar``.
        nr_vis: maximum number of first-batch samples to visualise. Capped
            at the batch size.

    Returns:
        None.
    """
    psnr_list = []
    psnr_list_teacher = []
    with torch.no_grad():
        for i, (frames, _timestep) in enumerate(val_data):
            frames = frames.to(device, non_blocking=True) / 255.
            imgs = frames[:, :6]
            gt = frames[:, 6:9]
            pred, info = model.update(imgs, gt, training=False)
            merged_tea = info['merged_tea']
            psnr_list.extend(_batch_psnr(gt, pred))
            psnr_list_teacher.extend(_batch_psnr(merged_tea, gt))
            # Host copies for visualisation are made only when they are used.
            if i == 0 and local_rank == 0:
                _log_first_batch(writer_val, nr_eval, pred, merged_tea, gt, info['flow'], nr_vis)

    if local_rank != 0 or not psnr_list:
        return
    writer_val.add_scalar('psnr', np.array(psnr_list).mean(), nr_eval)
    writer_val.add_scalar('psnr_teacher', np.array(psnr_list_teacher).mean(), nr_eval)
137
138if __name__ == "__main__":
139 parser = argparse.ArgumentParser()

Callers 1

train (function) · 0.85

Calls 2

flow2rgb (function) · 0.85
update (method) · 0.45

Tested by

no test coverage detected