()
| 6 | import os |
| 7 | |
| 8 | def test_set_best_config(): |
| 9 | args = get_default_args(task="node_classification", dataset="citeseer", model="gat") |
| 10 | args.model = args.model[0] |
| 11 | args.dataset = args.dataset[0] |
| 12 | args = set_best_config(args) |
| 13 | |
| 14 | assert args.lr == 0.005 |
| 15 | assert args.epochs == 1000 |
| 16 | assert args.weight_decay == 0.001 |
| 17 | |
| 18 | |
| 19 | def test_train(): |
no test coverage detected