(model, A, dataloader, ndata, num_classes)
| 139 | |
| 140 | |
def evaluate(model, A, dataloader, ndata, num_classes, fanouts=None):
    """Compute multiclass accuracy of ``model`` over every batch of ``dataloader``.

    The model is switched to eval mode and run with gradient tracking disabled.
    For each batch of seeds, ``multilayer_sample`` is called with ``A``,
    ``fanouts``, the seeds and ``ndata``. It returns the sampled matrices, the
    input features ``x`` and the labels ``y``. The model is applied as
    ``model(sampled_matrices, x)``. Predictions and labels from all batches
    are concatenated and scored once with ``MF.accuracy``.

    Args:
        model: Module called as ``model(sampled_matrices, x)``. It is left in
            eval mode when this function returns.
        A: Graph structure forwarded unchanged to ``multilayer_sample``.
        dataloader: Iterable yielding batches of seeds.
        ndata: Node data forwarded unchanged to ``multilayer_sample``
            (presumably node features/labels -- confirm against its definition).
        num_classes: Number of classes for the multiclass accuracy metric.
        fanouts: Per-layer sampling fanouts passed to ``multilayer_sample``.
            Defaults to ``[4000, 4000, 4000]``, the value that used to be
            hard-coded here, so existing callers see identical behavior.

    Returns:
        The accuracy value returned by ``MF.accuracy`` for all batches combined.

    Note:
        If ``dataloader`` yields no batches, ``torch.cat`` on the empty lists
        raises.
    """
    # Use None as the default (not a list literal) so no mutable default is
    # shared across calls.
    fanouts = [4000, 4000, 4000] if fanouts is None else list(fanouts)

    model.eval()
    ys = []
    y_hats = []
    # One no_grad context for the whole pass instead of re-entering it per batch.
    with torch.no_grad():
        for seeds in dataloader:
            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)
            ys.append(y)
            y_hats.append(model(sampled_matrices, x))

    # Score all batches together so the result is accuracy over every sample,
    # not an average of per-batch accuracies.
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )
| 158 | |
| 159 | |
| 160 | def validate(device, A, ndata, dataset, model, batch_size): |
no test coverage detected