(graph: BaseGraph, dataloader: Iterable, method: str='psnr')
| 63 | |
| 64 | |
def evaluation(graph: BaseGraph, dataloader: Iterable, method: str='psnr'):
    """Execute ``graph`` over ``dataloader`` and return the mean PSNR or SSIM.

    Each item yielded by ``dataloader`` is unpacked as ``(lr_img, hr_img)``.
    ``lr_img`` is moved to CUDA and run through a ``TorchExecutor`` built from
    ``graph``. The executor's first output is compared against ``hr_img``
    using the selected metric, with ``input_order='CHW'``.

    Args:
        graph: The graph to execute.
        dataloader: Iterable of ``(lr_img, hr_img)`` pairs. The leading
            dimension of both prediction and target is squeezed away, so this
            presumably expects batch size 1 and values in [0, 1].
            NOTE(review): confirm shapes/range against the dataset loader.
        method: ``'psnr'`` or ``'ssim'``.

    Returns:
        The arithmetic mean of the per-sample metric values.

    Raises:
        ValueError: If ``method`` is not ``'psnr'`` or ``'ssim'``, or if
            ``dataloader`` yields no samples.
    """
    # ValueError subclasses Exception, so existing broad handlers still work.
    if method not in {'psnr', 'ssim'}:
        raise ValueError('Evaluation method not understood.')
    # Pick the metric once instead of branching on every sample.
    metric = psnr if method == 'psnr' else ssim
    executor = TorchExecutor(graph)
    ret_collector = []

    for lr_img, hr_img in tqdm(dataloader):
        pred = executor.forward(lr_img.cuda())[0]
        real = hr_img

        # post processing: clamp the network output to [0, 1] before scaling
        # and rounding. Unclamped over/undershoot would otherwise yield pixel
        # values outside [0, 255] and distort the metric.
        pred = convert_any_to_numpy((pred.squeeze(0).clamp(0, 1) * 255).round())
        real = convert_any_to_numpy((real.squeeze(0) * 255).round())

        ret_collector.append(metric(img1=real, img2=pred, input_order='CHW'))

    # Fail with a clear message rather than a bare ZeroDivisionError.
    if not ret_collector:
        raise ValueError('Evaluation dataloader yielded no samples.')
    return sum(ret_collector) / len(ret_collector)
| 83 | |
| 84 | calib_loader = load_div2k_dataset( |
| 85 | lr_folder = TRAIN_LR_DIR, |
no test coverage detected