(model, graph, dataset, test_mask, device, device_ids)
| 16 | |
| 17 | |
def train_model(model, graph, dataset, test_mask, device, device_ids,
                *, epochs=3, patience=1, lr=0.01, weight_decay=0):
    """Train ``model`` for node classification, reload a checkpoint, and score it.

    Builds the ``"node_classification_mw"`` model wrapper and the
    ``"node_classification_dw"`` data wrapper, and runs a ``Trainer`` with
    early stopping. It then loads the weights stored at
    ``./checkpoints/model.pt`` into ``model`` and evaluates it on ``test_mask``.

    Args:
        model: Network to train. It is modified in place: trained, overwritten
            with the checkpoint weights, then moved to ``device``.
        graph: Graph passed to ``evaluate`` for the final test score.
        dataset: Dataset handed to the node-classification data wrapper.
        test_mask: Mask passed to ``evaluate`` as ``mask``.
        device: Target device for evaluation. Comparing equal to the string
            ``"cpu"`` puts the trainer in CPU mode.
        device_ids: Forwarded unchanged to ``Trainer``.
        epochs: Number of training epochs given to ``Trainer`` (default 3).
        patience: Early-stopping patience given to ``Trainer`` (default 1).
        lr: Optimizer learning rate (default 0.01).
        weight_decay: Optimizer weight decay (default 0).

    Returns:
        The value returned by ``evaluate`` for the test mask.
    """
    print(model)
    print("training...")
    mw_class = fetch_model_wrapper("node_classification_mw")
    dw_class = fetch_data_wrapper("node_classification_dw")
    optimizer_cfg = dict(lr=lr, weight_decay=weight_decay)
    model_wrapper = mw_class(model, optimizer_cfg)
    dataset_wrapper = dw_class(dataset)
    trainer = Trainer(
        epochs=epochs,
        early_stopping=True,
        patience=patience,
        cpu=device == "cpu",
        device_ids=device_ids,
    )
    trainer.run(model_wrapper, dataset_wrapper)

    # NOTE(review): this path is not passed to Trainer; it presumably matches
    # the trainer's default checkpoint location -- confirm.
    # Load onto CPU first so the device the tensors were saved from does not
    # matter here; load_state_dict copies them into ``model``, which is then
    # moved to ``device`` below.
    state_dict = torch.load("./checkpoints/model.pt", map_location="cpu")
    # strict=False tolerates key mismatches; report them instead of ignoring
    # them silently, so a partially-applied checkpoint is noticed.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"warning: checkpoint keys mismatched; "
            f"missing={incompatible.missing_keys} "
            f"unexpected={incompatible.unexpected_keys}"
        )
    model.to(device)
    test_score = evaluate(model, graph, mask=test_mask, device=device)
    return test_score
| 42 | |
| 43 | |
| 44 | def init_dataset(): |
no test coverage detected