(g, pred)
| 72 | |
| 73 | |
def evaluate(g, pred):
    """Score predictions on the validation and test splits.

    Args:
        g: Object exposing an ``ndata`` mapping with the entries
            ``"label"``, ``"val_mask"`` and ``"test_mask"``.
            NOTE(review): the masks are presumably boolean node masks, so that
            indexing selects nodes rather than gathering indices — confirm
            against the dataset loader.
        pred: Predicted labels, indexable by the same masks as
            ``g.ndata["label"]``.

    Returns:
        ``(val_acc, test_acc)``: for each split, the fraction of masked
        entries where ``pred`` equals the label, as produced by
        ``.float().mean()``.
    """
    labels = g.ndata["label"]
    split_masks = (g.ndata["val_mask"], g.ndata["test_mask"])

    # Same hit-rate computation for each split, in (validation, test) order.
    val_acc, test_acc = (
        (pred[mask] == labels[mask]).float().mean() for mask in split_masks
    )
    return val_acc, test_acc
| 83 | |
| 84 | |
| 85 | def train(model, g, A_hat, X): |