| 83 | |
| 84 | |
def train(model, g, A_hat, X, num_epochs=50, lr=1e-2, weight_decay=5e-4):
    """Train ``model`` full-batch on the training nodes of ``g``.

    Each epoch performs one forward/backward/Adam step over the whole graph,
    computing the loss only on nodes selected by ``g.ndata["train_mask"]``.
    It then re-runs the model in eval mode, takes the arg-max prediction per
    row, and prints the loss together with the validation/test accuracy
    returned by ``evaluate(g, pred)``.

    Args:
        model: Module called as ``model(A_hat, X)``; its output is treated as
            a 2-D logits tensor (``argmax(dim=1)`` is taken per row).
        g: Graph whose ``ndata`` provides ``"label"`` and ``"train_mask"``
            (plus whatever ``evaluate`` reads — defined elsewhere).
        A_hat: Passed straight through to ``model``; presumably the
            normalized adjacency matrix — confirm with the caller.
        X: Node features, passed straight through to ``model``.
        num_epochs: Number of training epochs (default 50, the previous
            hard-coded value).
        lr: Adam learning rate (default 1e-2).
        weight_decay: Adam weight decay (default 5e-4).
    """
    import torch  # local import: needed only for no_grad during evaluation

    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        # Forward pass in training mode.
        model.train()
        logits = model(A_hat, X)

        # Compute the loss on training-set nodes only.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute predictions in eval mode. No gradients are needed here, so
        # skip building an autograd graph that would only waste memory/time.
        model.eval()
        with torch.no_grad():
            logits = model(A_hat, X)
        pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"In epoch {epoch}, loss: {loss.item():.3f}, val acc: {val_acc:.3f}, test"
            f" acc: {test_acc:.3f}"
        )
| 114 | |
| 115 | |
| 116 | if __name__ == "__main__": |