| 72 | |
| 73 | |
def test_epoch(model, test_loader, prefix='Test', device='cuda'):
    """Evaluate ``model`` on ``test_loader`` and return batch-averaged regression losses.

    Each batch from ``test_loader`` is unpacked as ``(feature, label, label_reg)``.
    Only ``feature`` and ``label_reg`` are used; ``label`` is ignored. The model is
    called as ``pred, _ = model(feature)``. Its first output is averaged over
    dim 1 and flattened to shape ``(batch,)`` before the losses are computed.

    Args:
        model: Network under evaluation; it is switched to eval mode.
        test_loader: Iterable of batches that also supports ``len()``.
        prefix: Description shown on the tqdm progress bar.
        device: Device that ``feature`` and ``label_reg`` are moved to. The default
            ``'cuda'`` reproduces the previous hard-coded ``.cuda()`` behavior.

    Returns:
        Tuple ``(mse, mae, rmse)``. Each element is a Python float computed as the
        sum of the per-batch MSE, L1 and RMSE values divided by
        ``len(test_loader)``. ``rmse`` is the mean of per-batch RMSEs, not the
        square root of the mean MSE.

    Raises:
        ZeroDivisionError: if ``len(test_loader)`` is 0.
    """
    model.eval()
    criterion = nn.MSELoss()
    criterion_1 = nn.L1Loss()
    total_loss = 0
    total_loss_1 = 0
    total_loss_r = 0
    # Pure inference: disable autograd for the entire pass, not per batch.
    with torch.no_grad():
        for feature, _, label_reg in tqdm(test_loader, desc=prefix, total=len(test_loader)):
            feature = feature.to(device).float()
            label_reg = label_reg.to(device).float()
            pred, _ = model(feature)
            # Average over dim 1, then flatten to one value per sample.
            pred = torch.mean(pred, dim=1).view(pred.shape[0])
            loss = criterion(pred, label_reg)
            loss_r = torch.sqrt(loss)
            loss_1 = criterion_1(pred, label_reg)
            total_loss += loss.item()
            total_loss_1 += loss_1.item()
            total_loss_r += loss_r.item()
    n_batches = len(test_loader)
    loss = total_loss / n_batches
    loss_1 = total_loss_1 / n_batches
    # Average the accumulated RMSE. The previous code divided only the last
    # batch's loss_r, which discarded every other batch.
    loss_r = total_loss_r / n_batches
    return loss, loss_1, loss_r
| 97 | |
| 98 | |
| 99 | def test_epoch_inference(model, test_loader, prefix='Test'): |