(device, A, ndata, dataset, model)
| 165 | |
| 166 | |
def train(device, A, ndata, dataset, model):
    """Run the mini-batch training loop and print per-epoch metrics.

    For each of 20 epochs, iterates over shuffled batches of training seed
    indices, samples per-batch inputs with ``multilayer_sample``, optimizes
    ``model`` with Adam on a cross-entropy loss, then calls ``evaluate`` on
    the validation indices and prints the mean training loss and the
    validation accuracy.

    Args:
        device: Device the train/validation index tensors are moved to.
        A: Passed through unchanged to ``multilayer_sample`` and
            ``evaluate``. NOTE(review): presumably the graph adjacency
            matrix -- confirm against those helpers.
        ndata: Passed through unchanged to ``multilayer_sample`` and
            ``evaluate``. NOTE(review): presumably node features/labels --
            confirm against those helpers.
        dataset: Object providing ``train_idx``, ``val_idx`` (tensors) and
            ``num_classes``.
        model: Module called as ``model(sampled_matrices, x)``; its
            parameters are updated in place.

    Returns:
        None. Results are reported via ``print``.
    """
    # Create sampler & dataloader.
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)

    train_dataloader = torch.utils.data.DataLoader(
        train_idx, batch_size=1024, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(val_idx, batch_size=1024)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    # One entry per sampling layer handed to multilayer_sample.
    # NOTE(review): presumably the per-layer neighbor fanout -- confirm.
    fanouts = [4000, 4000, 4000]
    for epoch in range(20):
        model.train()
        total_loss = 0
        # Count batches explicitly instead of reusing the loop index after
        # the loop: the index is unbound when the dataloader is empty.
        num_batches = 0
        for seeds in train_dataloader:
            sampled_matrices, x, y = multilayer_sample(A, fanouts, seeds, ndata)
            y_hat = model(sampled_matrices, x)
            loss = F.cross_entropy(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        acc = evaluate(model, A, val_dataloader, ndata, dataset.num_classes)
        # With no training batches there is no loss to average; report NaN
        # rather than raising.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, mean_loss, acc.item()
            )
        )
| 198 | |
| 199 | |
| 200 | if __name__ == "__main__": |
no test coverage detected