(g, pred)
| 43 | |
| 44 | |
def evaluate(g, pred):
    """Return ``(val_acc, test_acc)`` for node predictions ``pred`` on ``g``.

    Reads ``g.ndata["label"]``, ``g.ndata["val_mask"]`` and
    ``g.ndata["test_mask"]``. Each accuracy is the mean of
    ``pred == label`` over the nodes selected by the corresponding mask,
    returned as a tensor (result of ``.float().mean()``).

    NOTE(review): assumes the masks are boolean (or index) tensors usable
    to index both ``pred`` and the label tensor; confirm against the caller.
    """
    node_data = g.ndata
    labels = node_data["label"]
    val_mask = node_data["val_mask"]
    test_mask = node_data["test_mask"]

    def _masked_accuracy(mask):
        # Fraction of selected nodes whose prediction matches the label.
        correct = pred[mask] == labels[mask]
        return correct.float().mean()

    val_acc, test_acc = (_masked_accuracy(m) for m in (val_mask, test_mask))
    return val_acc, test_acc
| 54 | |
| 55 | |
| 56 | def train(model, g, A_hat, X): |