(args)
| 113 | |
| 114 | |
def main(args):
    """Train the network and print its final test accuracy.

    Loads the dataset via ``load_data()``, builds a ``Net`` sized from the
    feature dimension and number of classes, and optimizes it with Adam
    (lr=0.001) for ``args.epochs`` epochs. After every epoch it evaluates on
    the validation and test masks and shows loss / validation accuracy /
    test accuracy in the tqdm progress bar. When training finishes it prints
    the test accuracy measured after the LAST epoch (not a best-validation
    checkpoint).

    Args:
        args: Parsed command-line namespace; only ``args.epochs`` is read.

    Raises:
        ValueError: If ``args.epochs`` is less than 1. In that case no
            epoch would run and there would be no accuracy to report.
    """
    # Fail fast, before the (potentially expensive) data load. With zero or
    # negative epochs the loop below never executes, and ``test_acc`` would be
    # unbound at the final print, surfacing as a confusing UnboundLocalError.
    if args.epochs < 1:
        raise ValueError(
            f"epochs must be a positive integer, got {args.epochs}"
        )

    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
    model = Net(X.shape[1], num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    with tqdm.trange(args.epochs) as tq:
        for _ in tq:
            loss = train(model, optimizer, H, X, Y, train_mask)
            val_acc, test_acc = evaluate(
                model, H, X, Y, val_mask, test_mask, num_classes
            )
            # refresh=False: don't force an immediate redraw; the new postfix
            # is shown when the bar next updates on the following iteration.
            tq.set_postfix(
                {
                    "Loss": f"{loss:.5f}",
                    "Val acc": f"{val_acc:.5f}",
                    "Test acc": f"{test_acc:.5f}",
                },
                refresh=False,
            )

    # Reports the test accuracy from the final epoch only.
    print(f"Test acc: {test_acc:.3f}")
| 136 | |
| 137 | |
| 138 | if __name__ == "__main__": |
no test coverage detected