r""" Description ----------- Attack process consists of injection and feature update. Parameters ---------- model : torch.nn.module Model implemented based on ``torch.nn.module``. graph : cogdl.data.Graph Graph to attc
(self, model, graph, feat_norm=None, adj_norm_func=None)
| 86 | self.early_stop = None |
| 87 | |
def attack(self, model, graph, feat_norm=None, adj_norm_func=None):
    r"""

    Description
    -----------
    Attack process consists of injection and feature update.

    Parameters
    ----------
    model : torch.nn.module
        Model implemented based on ``torch.nn.module``. It is moved to
        ``self.device`` in place before use.
    graph : cogdl.data.Graph
        Graph to attack. ``graph.test_mask`` is used as the target mask,
        ``graph.x`` supplies the original node features and ``graph.y`` is
        passed through to the output graph.
    feat_norm : str, optional
        Type of feature normalization, ['arctan', 'tanh']. Default: ``None``.
    adj_norm_func : func of grb_utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``None``.

    Returns
    -------
    out_graph : cogdl.data.Graph
        Graph attacked. It is built from the adjacency returned by
        ``self.injection`` and from the preprocessed original features
        concatenated (along dim 0) with the feature rows returned by
        ``self.update_features``.

    """
    # perf_counter is monotonic and intended for measuring elapsed intervals.
    time_start = time.perf_counter()
    adj = graph.to_scipy_csr()
    target_mask = graph.test_mask
    features = graph.x
    model.to(self.device)
    # Shape is read before preprocessing; n_feat sizes the injected feature rows.
    n_total, n_feat = features.shape
    features = feat_preprocess(features=features, feat_norm=feat_norm, device=self.device)
    adj_tensor = adj_preprocess(adj=adj, adj_norm_func=adj_norm_func, device=self.device)

    # The clean prediction is only used to derive reference labels via argmax,
    # so no autograd graph is needed. Without no_grad, ``pred_origin`` would keep
    # the full forward graph alive for the rest of this method (including the
    # whole update_features call).
    # NOTE(review): no .eval() is called here, so the prediction uses whatever
    # train/eval mode the caller left the model in -- confirm callers pass an
    # eval-mode model if dropout-free labels are intended.
    with torch.no_grad():
        pred_origin = model(getGraph(adj_tensor, features, device=self.device))
    labels_origin = torch.argmax(pred_origin, dim=1)

    adj_attack = self.injection(adj=adj, n_inject=self.n_inject_max, n_node=n_total, target_mask=target_mask)
    # Zero-initialised injected features: the starting point handed to update_features.
    features_attack = np.zeros((self.n_inject_max, n_feat))
    features_attack = self.update_features(
        model=model,
        adj_attack=adj_attack,
        features_origin=features,
        features_attack=features_attack,
        labels_origin=labels_origin,
        target_mask=target_mask,
        feat_norm=feat_norm,
        adj_norm_func=adj_norm_func,
    )
    # Original node rows first, injected node rows appended after them.
    out_features = torch.cat((features, features_attack), 0)
    time_end = time.perf_counter()
    if self.verbose:
        print("Attack runtime: {:.4f}.".format(time_end - time_start))

    out_graph = getGraph(adj_attack, out_features, graph.y, device=self.device)
    return out_graph
| 142 | |
| 143 | def injection(self, adj, n_inject, n_node, target_mask): |
| 144 | r""" |