(device, A, ndata, dataset, model, batch_size)
| 158 | |
| 159 | |
| 160 | def validate(device, A, ndata, dataset, model, batch_size): |
| 161 | inf_id = dataset.test_idx.to(device) |
| 162 | inf_dataloader = torch.utils.data.DataLoader(inf_id, batch_size=batch_size) |
| 163 | acc = evaluate(model, A, inf_dataloader, ndata, dataset.num_classes) |
| 164 | return acc |
| 165 | |
| 166 | |
| 167 | def train(device, A, ndata, dataset, model): |