Repository: github.com/THUDM/CogDL

Function train_model

Defined in examples/GRB/test_attack_defense.py, lines 18–41
(model, graph, dataset, test_mask, device, device_ids)


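```python
import torch
from cogdl.trainer import Trainer
from cogdl.wrappers import fetch_data_wrapper, fetch_model_wrapper

# `evaluate` is a helper used elsewhere in test_attack_defense.py; its import
# lies outside the lines shown here.


def train_model(model, graph, dataset, test_mask, device, device_ids):
    print(model)
    print("training...")
    # Resolve the node-classification model/data wrapper classes by name.
    mw_class = fetch_model_wrapper("node_classification_mw")
    dw_class = fetch_data_wrapper("node_classification_dw")
    optimizer_cfg = dict(lr=0.01, weight_decay=0)
    model_wrapper1 = mw_class(model, optimizer_cfg)
    dataset_wrapper = dw_class(dataset)
    # Short run: at most 3 epochs, stopping after 1 epoch without improvement.
    trainer = Trainer(
        epochs=3,
        early_stopping=True,
        patience=1,
        cpu=(device == "cpu"),
        device_ids=device_ids,
    )
    trainer.run(model_wrapper1, dataset_wrapper)
    # Reload the checkpoint the trainer is expected to have written;
    # strict=False tolerates missing or unexpected keys in the state dict.
    model.load_state_dict(torch.load("./checkpoints/model.pt"), strict=False)
    model.to(device)
    test_score = evaluate(model, graph, mask=test_mask, device=device)
    return test_score
```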
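The `Trainer` call in `train_model` requests at most 3 epochs with `early_stopping=True` and `patience=1`, after which the function reloads `./checkpoints/model.pt`. Assuming this means training halts once validation performance fails to improve for `patience` consecutive epochs and the best weights are kept as the checkpoint (an assumption about CogDL's `Trainer`, not taken from its source), the control flow can be sketched in plain Python. `run_with_early_stopping`, the dict "states", and the loss values are hypothetical stand-ins for real model weights and validation losses.

```python
import copy
from typing import Dict, List, Optional, Tuple


def run_with_early_stopping(
    states: List[Dict[str, float]],
    val_losses: List[float],
    max_epochs: int,
    patience: int,
) -> Tuple[Dict[str, float], int]:
    """Replay precomputed epochs with patience-based early stopping.

    Hypothetical helper: states[i] stands in for the model weights after
    epoch i and val_losses[i] for that epoch's validation loss (lower is
    better). Returns a copy of the best state (the "checkpoint") and the
    number of epochs actually run.
    """
    if not states or len(states) != len(val_losses):
        raise ValueError("need one state per validation loss")
    best_loss = float("inf")
    best_state: Optional[Dict[str, float]] = None
    bad_epochs = 0
    epochs_run = 0
    for epoch in range(min(max_epochs, len(val_losses))):
        epochs_run += 1
        if val_losses[epoch] < best_loss:
            best_loss = val_losses[epoch]
            best_state = copy.deepcopy(states[epoch])  # "save checkpoint"
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # `patience` epochs in a row without improvement
    if best_state is None:
        raise ValueError("no epoch was run")
    return best_state, epochs_run


# Settings mirroring train_model: at most 3 epochs, patience=1.
best, ran = run_with_early_stopping(
    [{"w": 0.1}, {"w": 0.2}, {"w": 0.3}], [0.5, 0.6, 0.4], max_epochs=3, patience=1
)
# best == {"w": 0.1}, ran == 2: the lower loss of epoch 3 (0.4) is never reached.
```

With `patience=1`, a single non-improving epoch ends training, so the run stays short at the cost of possibly stopping on a noisy epoch.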

Callers (6)

init_surrogate_model · Function · 0.85
init_target_model · Function · 0.85
test_gcnsvd_defense · Function · 0.85
test_gatguard_defense · Function · 0.85
test_gcnguard_defense · Function · 0.85
test_robustgcn_defense · Function · 0.85

Calls (6)

run · Method · 0.95
fetch_model_wrapper · Function · 0.90
fetch_data_wrapper · Function · 0.90
Trainer · Class · 0.90
evaluate · Function · 0.90
to · Method · 0.45

Tested by

no test coverage detected