MCPcopy
hub / github.com/THUDM/CogDL / attack

Method attack

examples/GRB/attack/injection/fgsm.py:88–141  ·  view source on GitHub ↗

r""" Description ----------- Attack process consists of injection and feature update. Parameters ---------- model : torch.nn.module Model implemented based on ``torch.nn.module``. graph : cogdl.data.Graph Graph to attc

(self, model, graph, feat_norm=None, adj_norm_func=None)

Source from the content-addressed store, hash-verified

86 self.early_stop = None
87
def attack(self, model, graph, feat_norm=None, adj_norm_func=None):
    r"""

    Description
    -----------
    Attack process consists of injection and feature update.

    Parameters
    ----------
    model : torch.nn.Module
        Model implemented based on ``torch.nn.Module``.
    graph : cogdl.data.Graph
        Graph to attack. Its ``test_mask`` selects the target nodes.
    feat_norm : str, optional
        Type of feature normalization, ['arctan', 'tanh']. Default: ``None``.
    adj_norm_func : func of grb_utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``None``.

    Returns
    -------
    out_graph : cogdl.data.Graph
        Graph attacked: the attacked adjacency from ``self.injection`` with the
        injected nodes' features appended after the original node features.

    """
    time_start = time.perf_counter()
    adj = graph.to_scipy_csr()
    # Nodes whose predictions the attack targets.
    target_mask = graph.test_mask
    features = graph.x
    model.to(self.device)
    n_total, n_feat = features.shape
    features = feat_preprocess(features=features, feat_norm=feat_norm, device=self.device)
    adj_tensor = adj_preprocess(adj=adj, adj_norm_func=adj_norm_func, device=self.device)

    # The clean prediction is only used for its argmax labels, so no autograd
    # graph needs to be recorded for this forward pass.
    with torch.no_grad():
        pred_origin = model(getGraph(adj_tensor, features, device=self.device))
    labels_origin = torch.argmax(pred_origin, dim=1)

    # Stage 1 (injection): build the attacked adjacency; presumably this adds
    # ``self.n_inject_max`` nodes after the ``n_total`` existing ones — see
    # ``self.injection`` to confirm.
    adj_attack = self.injection(
        adj=adj,
        n_inject=self.n_inject_max,
        n_node=n_total,
        target_mask=target_mask,
    )

    # Stage 2 (feature update): optimize the injected nodes' features,
    # starting from an all-zero (n_inject_max, n_feat) matrix.
    features_attack = np.zeros((self.n_inject_max, n_feat))
    features_attack = self.update_features(
        model=model,
        adj_attack=adj_attack,
        features_origin=features,
        features_attack=features_attack,
        labels_origin=labels_origin,
        target_mask=target_mask,
        feat_norm=feat_norm,
        adj_norm_func=adj_norm_func,
    )
    # Injected-node rows are stacked after the original node rows.
    out_features = torch.cat((features, features_attack), 0)
    time_end = time.perf_counter()
    if self.verbose:
        print("Attack runtime: {:.4f}.".format(time_end - time_start))

    # NOTE(review): graph.y only covers the original nodes, while adj_attack and
    # out_features also include the injected ones; presumably getGraph tolerates
    # the shorter label vector — confirm.
    out_graph = getGraph(adj_attack, out_features, graph.y, device=self.device)
    return out_graph
142
143 def injection(self, adj, n_inject, n_node, target_mask):
144 r"""

Callers 2

test_adversarial_trainFunction · 0.95

Calls 7

injectionMethod · 0.95
update_featuresMethod · 0.95
feat_preprocessFunction · 0.90
adj_preprocessFunction · 0.90
getGraphFunction · 0.90
to_scipy_csrMethod · 0.45
toMethod · 0.45

Tested by 2

test_adversarial_trainFunction · 0.76