| 54 | |
| 55 | |
def train(model, g, A_hat, X, *, num_epochs=50, lr=1e-2, weight_decay=5e-4):
    """Train ``model`` full-batch with Adam and report accuracy every epoch.

    Each epoch runs one forward/backward pass over the whole graph. The
    cross-entropy loss is computed only on the nodes selected by
    ``g.ndata["train_mask"]``, and the optimizer takes one step. The model is
    then re-run in eval mode, and its argmax predictions are passed to
    ``evaluate`` to obtain validation/test accuracy. That accuracy is printed
    together with the training loss.

    Args:
        model: Module called as ``model(A_hat, X)`` that returns per-node
            logits (assumed shape ``(num_nodes, num_classes)`` -- TODO confirm).
        g: Graph whose ``ndata`` provides ``"label"`` and ``"train_mask"``
            (presumably a DGL graph with a boolean train mask -- verify).
        A_hat: First model input. The name suggests a normalized adjacency
            matrix; confirm against the caller.
        X: Second model input, presumably node features.
        num_epochs: Number of training epochs (default 50, the previous
            hard-coded value).
        lr: Adam learning rate (default 1e-2).
        weight_decay: Adam weight decay (default 5e-4).
    """
    # Local import: the block must not depend on module-level imports it cannot see.
    import torch

    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        # Forward pass in training mode.
        model.train()
        logits = model(A_hat, X)

        # Compute loss with nodes in training set only.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute prediction in eval mode. no_grad() avoids building an
        # autograd graph for a forward pass that is never back-propagated.
        model.eval()
        with torch.no_grad():
            logits = model(A_hat, X)
        pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test"
            f" acc: {test_acc:.3f}"
        )
| 85 | |
| 86 | |
| 87 | if __name__ == "__main__": |