MCPcopy
hub / github.com/THUDM/CogDL / STACK

Class STACK

examples/GRB/attack/modification/stack.py:11–75  ·  view source on GitHub ↗

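STACK: an edge-flip modification attack. It scores the edges incident to the test (target) nodes by their estimated effect on the graph's eigenvalues, then flips the highest-scoring ones, up to a fixed budget.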

Source from the content-addressed store, hash-verified

9
10
11class STACK(ModificationAttack):
12 """
13 STACK.
14 """
15
def __init__(self, n_edge_mod, allow_isolate=True, device="cpu", verbose=True):
    """Record the attack budget and runtime options; no work is done here.

    Args:
        n_edge_mod: Edge-flip budget. ``modification`` stops flipping once
            this many flips have been counted.
        allow_isolate: If False, ``modification`` removes an existing edge
            only when both endpoints currently have degree > 1.
        device: Forwarded to ``getGraph`` when ``attack`` builds the
            perturbed graph.
        verbose: Stored on the instance. It is not referenced in the
            visible part of this class (the progress bar in
            ``modification`` is shown unconditionally) — TODO confirm
            whether it is read elsewhere.
    """
    # Pure configuration: group the attributes into two paired assignments.
    self.n_edge_mod, self.allow_isolate = n_edge_mod, allow_isolate
    self.device, self.verbose = device, verbose
21
def attack(self, graph: Graph, **kwargs):
    """Apply the STACK edge-flip perturbation to ``graph`` and return the result.

    The graph's adjacency (as a SciPy CSR matrix) is perturbed around the
    test nodes (``graph.test_nid``, moved to CPU). A new graph is then built
    from the perturbed adjacency and the original features/labels.

    Args:
        graph: Input graph. It must provide ``to_scipy_csr()``,
            ``test_nid``, ``x`` and ``y``.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        Whatever ``getGraph`` returns for the perturbed adjacency, placed
        on ``self.device``.
    """
    # Same call order as a single nested expression: adjacency first, then targets.
    csr_adj = graph.to_scipy_csr()
    target_nodes = graph.test_nid.cpu()
    perturbed_adj = self.modification(csr_adj, target_nodes)
    return getGraph(perturbed_adj, graph.x, graph.y, device=self.device)
26
27 def modification(self, adj, index_target):
28 adj_attack = adj.copy()
29 degs = adj_attack.getnnz(axis=1)
30 adj_ = adj + sp.eye(adj.shape[0])
31 eigen_vals, eigen_vecs = spl.eigh(adj_.toarray(), np.diag(adj_.getnnz(axis=1)))
32 index_i, index_j = index_target[adj[index_target].nonzero()[0]], adj[index_target].nonzero()[1]
33 edges_target = np.column_stack([index_i, index_j])
34
35 flip_indicator = 1 - 2 * np.array(adj[tuple(edges_target.T)])[0]
36 eigen_scores = np.zeros(len(edges_target))
37 sub_org = np.sqrt(np.sum(eigen_vals ** 2))
38 for x in range(len(edges_target)):
39 i, j = edges_target[x]
40 vals_est = eigen_vals + flip_indicator[x] * (
41 2 * eigen_vecs[i] * eigen_vecs[j] - eigen_vals * (eigen_vecs[i] ** 2 + eigen_vecs[j] ** 2)
42 )
43 loss_ij = np.abs(sub_org - np.sqrt(np.sum(vals_est ** 2)))
44 eigen_scores[x] = loss_ij
45 struct_scores = np.expand_dims(eigen_scores, 1)
46 flip_edges_idx = np.argsort(struct_scores, axis=0)[::-1]
47 flip_edges = edges_target[flip_edges_idx].squeeze()
48
49 n_edge_flip = 0
50 for index in tqdm(flip_edges):
51 if n_edge_flip >= self.n_edge_mod:
52 break
53 if adj_attack[index[0], index[1]] == 0:
54 adj_attack[index[0], index[1]] = 1
55 adj_attack[index[1], index[0]] = 1
56 degs[index[0]] += 1
57 degs[index[1]] += 1
58 n_edge_flip += 1
59 else:
60 if self.allow_isolate:
61 adj_attack[index[0], index[1]] = 0
62 adj_attack[index[1], index[0]] = 0
63 n_edge_flip += 1
64 else:
65 if degs[index[0]] > 1 and degs[index[1]] > 1:
66 adj_attack[index[0], index[1]] = 0
67 adj_attack[index[1], index[0]] = 0
68 degs[index[0]] -= 1

Callers 1

Calls

no outgoing calls

Tested by 1