MCPcopy
hub / github.com/THUDM/CogDL / modification

Method modification

examples/GRB/attack/modification/fga.py:44–94  ·  view source on GitHub ↗
(
        self, model, adj_origin, features_origin, labels_origin, index_target, feat_norm=None, adj_norm_func=None
    )

Source from the content-addressed store, hash-verified

42 return getGraph(adj_attack, graph.x, graph.y, device=self.device)
43
def modification(
    self, model, adj_origin, features_origin, labels_origin, index_target, feat_norm=None, adj_norm_func=None
):
    """Perturb the graph structure around target nodes with FGA (Fast Gradient Attack).

    Each iteration differentiates ``self.loss`` on the target nodes with respect to a
    dense copy of the adjacency matrix. It then flips the one symmetric entry (i, j),
    with ``i`` in ``index_target``, whose flip the first-order gradient predicts will
    increase the loss the most. Adding a missing edge helps when its gradient is
    positive; removing an existing edge helps when its gradient is negative. Both
    cases are captured by the score ``grad * (1 - 2 * A)``.

    The loop stops after ``self.n_edge_mod`` flips or after N iterations (N = number
    of columns of the adjacency), whichever comes first. If ``self.allow_isolate`` is
    false, edges whose removal would leave either endpoint with degree 0 are never
    removed.

    Args:
        model: Victim model. It is switched to eval mode and called as
            ``model(getGraph(adj, features, device=self.device))``.
        adj_origin: Clean adjacency matrix. Accepts a torch tensor (dense or
            sparse), a scipy sparse matrix, or any array-like accepted by
            ``torch.tensor``. It is never modified.
        features_origin: Node features, passed through ``feat_preprocess``.
        labels_origin: Labels; ``labels_origin[index_target]`` feeds the loss.
        index_target: Indices of the target nodes. Only their rows (and the
            mirrored columns) are flipped.
        feat_norm: Feature normalization forwarded to ``feat_preprocess``.
        adj_norm_func: Adjacency normalization forwarded to ``adj_preprocess``.

    Returns:
        scipy.sparse.csr_matrix: The perturbed adjacency matrix.
    """
    model.eval()

    # Build a fresh dense float leaf tensor to differentiate against. Detaching and
    # copying keeps adj_origin untouched and lets us set requires_grad on a leaf.
    if isinstance(adj_origin, torch.Tensor):
        adj_attack = adj_origin.detach()
        if adj_attack.is_sparse:
            adj_attack = adj_attack.to_dense()
        adj_attack = adj_attack.clone()
        if not adj_attack.is_floating_point():
            adj_attack = adj_attack.float()
    elif sp.issparse(adj_origin):
        adj_attack = torch.tensor(adj_origin.toarray(), dtype=torch.float32)
    else:
        adj_attack = torch.tensor(adj_origin, dtype=torch.float32)

    features_origin = feat_preprocess(features=features_origin, feat_norm=feat_norm, device=self.device)
    adj_attack.requires_grad = True
    n_edge_flip = 0
    for _ in tqdm(range(adj_attack.shape[1])):
        if n_edge_flip >= self.n_edge_mod:
            break
        adj_attack_tensor = adj_preprocess(adj=adj_attack, adj_norm_func=adj_norm_func, device=self.device)
        pred = model(getGraph(adj_attack_tensor, features_origin, device=self.device))
        loss = self.loss(pred[index_target], labels_origin[index_target])
        grad = torch.autograd.grad(loss, adj_attack)[0]

        with torch.no_grad():
            # grad lives on adj_attack's device. Dividing by a Python float avoids
            # mixing it with a (1,)-shaped tensor placed on self.device.
            grad = (grad + grad.T) / 2.0
            adj_now = adj_attack.detach()
            target_rows = adj_now[index_target]

            # FGA score: positive where flipping the entry is predicted to raise
            # the loss (add when A == 0 and grad > 0, remove when A == 1 and grad < 0).
            score = grad[index_target] * (1.0 - 2.0 * target_rows)

            if not self.allow_isolate:
                # Degrees come from the raw adjacency, not the normalized tensor
                # fed to the model, whose row sums are not node degrees.
                degs = adj_now.sum(dim=1)
                would_isolate = (degs[index_target] <= 1).unsqueeze(1) | (degs <= 1).unsqueeze(0)
                # Never pick an edge whose removal would isolate an endpoint.
                # Masking (instead of skipping) keeps the loop from re-selecting
                # the same blocked entry forever.
                score = score.masked_fill((target_rows != 0) & would_isolate, float("-inf"))

            best = torch.max(score, dim=1)
            row = int(torch.argmax(best.values))
            if best.values[row] == float("-inf"):
                # Every candidate entry is a protected edge: nothing can be flipped.
                break
            index_max_i = int(index_target[row])
            index_max_j = int(best.indices[row])

            # Flip the entry symmetrically: add a missing edge, remove an existing one.
            new_value = 1.0 if adj_now[index_max_i, index_max_j] == 0 else 0.0
            adj_attack[index_max_i, index_max_j] = new_value
            adj_attack[index_max_j, index_max_i] = new_value
        n_edge_flip += 1

    adj_attack = sp.csr_matrix(adj_attack.detach().cpu().numpy())
    if self.verbose:
        print("FGA attack finished. {:d} edges were flipped.".format(n_edge_flip))

    return adj_attack

Callers (1)

attack (Method) · 0.95

Calls (8)

feat_preprocess (Function) · 0.90
adj_preprocess (Function) · 0.90
getGraph (Function) · 0.90
grad (Method) · 0.80
eval (Method) · 0.45
clone (Method) · 0.45
loss (Method) · 0.45
to (Method) · 0.45

Tested by

no test coverage detected