| 69 | |
| 70 | |
def target_test(model, loader, device='cuda'):
    """Return the targeted hit rate of ``model`` over ``loader``, in percent.

    Each item yielded by ``loader`` is a 4-tuple
    ``(raw_batch, batch, raw_label, target_label)``.  A sample only counts
    towards the denominator when the model's prediction on ``raw_batch``
    equals ``raw_label``.  Among those samples, the numerator counts the
    ones whose prediction on ``batch`` equals ``target_label``.

    NOTE(review): ``batch`` is presumably a modified (e.g. perturbed or
    triggered) version of ``raw_batch`` -- confirm against the loader.

    Args:
        model: Callable module mapping an input batch to per-class scores
            with the class axis at ``dim=1``.  It is switched to eval mode
            and left in eval mode on return.
        loader: Iterable of ``(raw_batch, batch, raw_label, target_label)``
            tensor tuples.
        device: Device every tensor is moved to before the forward passes.
            Defaults to ``'cuda'``, the value that used to be hard-coded.

    Returns:
        ``100 * hits / valid`` as a float, where ``valid`` is the number of
        samples predicted correctly on the raw input and ``hits`` the number
        of those predicted as ``target_label`` on ``batch``.  Returns ``0.0``
        when ``valid`` is zero (empty loader, or no raw input classified
        correctly) instead of raising ``ZeroDivisionError``.
    """
    model.eval()

    total = 0
    top1 = 0
    with torch.no_grad():
        for raw_batch, batch, raw_label, target_label in loader:
            raw_batch = raw_batch.to(device)
            batch = batch.to(device)
            raw_label = raw_label.to(device)
            target_label = target_label.to(device)

            raw_pred = model(raw_batch).argmax(dim=1)
            pred = model(batch).argmax(dim=1)

            # Only samples the model already gets right on the raw input
            # are eligible; they form the denominator.
            raw_correct = raw_pred.eq(raw_label)
            total += int(raw_correct.sum().item())

            # A hit needs both: eligible, and predicted as the target.
            valid_target_correct = pred.eq(target_label) & raw_correct
            top1 += int(valid_target_correct.sum().item())

    if total == 0:
        # No eligible sample: the rate is undefined, report 0 rather
        # than dividing by zero.
        return 0.0
    return float(top1) / total * 100
| 93 | |
| 94 | |
| 95 | def clean_test(model, loader): |