(g, features, labels, mask, model)
| 51 | |
| 52 | |
def evaluate(g, features, labels, mask, model):
    """Return the model's accuracy on the entries selected by ``mask``.

    The model is switched to evaluation mode and run once on
    ``(g, features)`` with gradient tracking disabled. The predicted class
    for each selected row is the position of the largest value along
    dimension 1 of the model output.

    Args:
        g: Passed through unchanged as the first argument of ``model``.
        features: Passed through unchanged as the second argument of
            ``model``.
        labels: Ground-truth class indices; indexed with ``mask``.
        mask: Index applied to both the model output and ``labels``
            (presumably a boolean or integer index tensor -- confirm
            against the caller).
        model: Callable module exposing ``eval()``; called as
            ``model(g, features)``.

    Returns:
        float: number of correct predictions divided by the number of
        selected labels.

    Raises:
        ZeroDivisionError: if ``mask`` selects no labels.
    """
    # Note: the model is left in eval mode; callers that keep training
    # afterwards must call ``model.train()`` themselves.
    model.eval()
    with torch.no_grad():
        selected_logits = model(g, features)[mask]
        selected_labels = labels[mask]
        # torch.max over a dim returns (values, indices); only the indices
        # (predicted classes) are needed here.
        predicted = selected_logits.max(dim=1)[1]
        num_correct = (predicted == selected_labels).sum().item()
        return float(num_correct) / len(selected_labels)
| 62 | |
| 63 | |
| 64 | def train(g, features, labels, masks, model): |
no outgoing calls
no test coverage detected