MCPcopy
hub / github.com/THUDM/CogDL / attack

Method attack

examples/GRB/attack/injection/speit.py:93–147  ·  view source on GitHub ↗

r""" Description ----------- Attack process consists of injection and feature update. Parameters ---------- model : torch.nn.Module Model implemented based on ``torch.nn.Module``. graph : cogdl.data.Graph Graph to attack. …

(self, model, graph, feat_norm=None, adj_norm_func=None)

Source from the content-addressed store, hash-verified

91 self.early_stop = None
92
def attack(self, model, graph, feat_norm=None, adj_norm_func=None):
    r"""
    Description
    -----------
    Attack process consists of injection and feature update.

    The model's predictions on the clean graph (argmax per node) are passed
    to ``update_features`` as the reference labels. ``graph.test_mask`` is
    used as the target mask for both injection and feature update.

    Parameters
    ----------
    model : torch.nn.Module
        Model implemented based on ``torch.nn.Module``.
    graph : cogdl.data.Graph
        Graph to attack.
    feat_norm : str, optional
        Type of feature normalization, ['arctan', 'tanh']. Default: ``None``.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``None``.

    Returns
    -------
    out_graph : cogdl.data.Graph
        Graph attacked: adjacency produced by ``self.injection`` and features
        formed by concatenating the (preprocessed) original features with the
        injected node features along dim 0. Labels are taken from ``graph.y``.

    """
    time_start = time.time()
    adj = graph.to_scipy_csr()
    target_mask = graph.test_mask
    features = graph.x
    model.to(self.device)
    n_total, n_feat = features.shape
    features = feat_preprocess(features=features, feat_norm=feat_norm, device=self.device)
    adj_tensor = adj_preprocess(adj=adj, adj_norm_func=adj_norm_func, device=self.device)

    # Reference predictions on the clean graph. Only the argmax is used, so no
    # autograd graph is needed; without no_grad, ``pred_origin`` would keep the
    # full forward pass's autograd graph alive for the rest of the attack.
    with torch.no_grad():
        pred_origin = model(getGraph(adj_tensor, features, device=self.device))
    labels_origin = torch.argmax(pred_origin, dim=1)
    # NOTE(review): the model's train/eval mode is not set here; if the model
    # uses dropout, the reference labels depend on that mode -- confirm that
    # callers put the model in eval mode.

    adj_attack = self.injection(
        adj=adj, n_inject=self.n_inject_max, n_node=n_total, target_mask=target_mask, mode=self.inject_mode
    )

    # Injected node features start at zero and are handed to update_features,
    # whose result becomes the injected nodes' feature matrix.
    features_attack = np.zeros([self.n_inject_max, n_feat])
    features_attack = self.update_features(
        model=model,
        adj_attack=adj_attack,
        features_origin=features,
        features_attack=features_attack,
        labels_origin=labels_origin,
        target_mask=target_mask,
        feat_norm=feat_norm,
        adj_norm_func=adj_norm_func,
    )
    # Original nodes first, injected nodes appended after them.
    out_features = torch.cat((features, features_attack), 0)
    time_end = time.time()
    if self.verbose:
        print("Attack runtime: {:.4f}.".format(time_end - time_start))

    out_graph = getGraph(adj_attack, out_features, graph.y, device=self.device)
    return out_graph
148
149 # flake8: noqa: C901
150 def injection(self, adj, n_inject, n_node, target_mask, mode="random-inter"):

Callers

nothing calls this directly

Calls 7

injection · Method · 0.95
update_features · Method · 0.95
feat_preprocess · Function · 0.90
adj_preprocess · Function · 0.90
getGraph · Function · 0.90
to_scipy_csr · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected