Function
evaluate
(model, H, X, Y, val_mask, test_mask, num_classes)
Source from the content-addressed store, hash-verified
| 75 | |
| 76 | |
def evaluate(model, H, X, Y, val_mask, test_mask, num_classes):
    """Score ``model`` on the validation and test subsets with one forward pass.

    The model is switched to evaluation mode (and is left in that mode on
    return), run once on ``(H, X)``, and the resulting predictions are
    scored separately on the entries selected by ``val_mask`` and
    ``test_mask``.

    Args:
        model: Module invoked as ``model(H, X)``; must provide ``.eval()``.
        H: First positional input to ``model`` (presumably the graph /
            hypergraph structure -- TODO confirm against caller).
        X: Second positional input to ``model`` (presumably node features --
            TODO confirm).
        Y: Ground-truth labels, indexed with the same masks as the
            predictions.
        val_mask: Index or mask selecting the validation entries.
        test_mask: Index or mask selecting the test entries.
        num_classes: Number of classes, forwarded to ``accuracy``.

    Returns:
        ``(val_acc, test_acc)``, each exactly as returned by ``accuracy``.
    """
    import torch  # local import keeps this block self-contained

    model.eval()
    # eval() only changes layer behaviour (dropout, norm statistics); it does
    # not stop autograd. Evaluation never back-propagates, so disable graph
    # construction to avoid holding activations for the full forward pass.
    with torch.no_grad():
        Y_hat = model(H, X)
        val_acc = accuracy(
            Y_hat[val_mask],
            Y[val_mask],
            task="multiclass",
            num_classes=num_classes,
        )
        test_acc = accuracy(
            Y_hat[test_mask],
            Y[test_mask],
            task="multiclass",
            num_classes=num_classes,
        )
    return val_acc, test_acc
| 90 | |
| 91 | |
| 92 | def load_data(): |
Tested by
no test coverage detected