()
| 533 | |
| 534 | |
def test_gcnguard_defense():
    """Train a GCNGuard target model and run the PGD modification attack on it.

    Steps, in order:
      1. Load the graph/dataset and device configuration via ``init_dataset``.
      2. Build the surrogate model via ``init_surrogate_model``.
      3. Construct a 2-layer ``GCNGuard`` target (hidden size 16, ReLU,
         layer norm, ``drop=True``) sized from the graph's feature and
         class counts.
      4. Train the target with ``train_model`` and print the returned
         pre-attack test score.
      5. Hand both models to ``apply_pgd_modification_attack``.

    Returns nothing; results are reported via ``print`` and by the helpers.
    """
    graph, dataset, test_mask, device, device_ids = init_dataset()
    model_sur = init_surrogate_model(graph, dataset, test_mask, device, device_ids)

    # The message names the model actually built below (GCNGuard); it
    # previously said "GATGuard", which mislabelled this run in the output.
    print("GCNGuard defense model...")
    model_target = GCNGuard(
        in_feats=graph.num_features,
        hidden_size=16,
        out_feats=graph.num_classes,
        num_layers=2,
        activation="relu",
        norm="layernorm",
        drop=True,
    )

    test_score = train_model(model_target, graph, dataset, test_mask, device, device_ids)
    print("Test score before attack for target model: {:.4f}.".format(test_score))

    # NOTE(review): presumably the attack is crafted against the surrogate and
    # evaluated on the target — confirm in apply_pgd_modification_attack.
    apply_pgd_modification_attack(model_sur, model_target, graph, test_mask, device)
| 550 | |
| 551 | |
| 552 | def test_robustgcn_defense(): |
no test coverage detected