r""" Description ----------- Attack process consists of injection and feature update. Parameters ---------- model : torch.nn.Module Model implemented based on ``torch.nn.Module``. graph : cogdl.data.Graph Graph to attack.
(self, model, graph, feat_norm=None, adj_norm_func=None)
| 91 | self.early_stop = None |
| 92 | |
def attack(self, model, graph, feat_norm=None, adj_norm_func=None):
    r"""

    Description
    -----------
    Attack process consists of injection and feature update. First
    ``self.injection`` wires ``self.n_inject_max`` new nodes into the graph
    (targeting the nodes selected by ``graph.test_mask``). Then
    ``self.update_features`` computes the features of those injected nodes.

    Parameters
    ----------
    model : torch.nn.Module
        Model implemented based on ``torch.nn.Module``. It is moved to
        ``self.device`` (in place, via ``model.to``).
    graph : cogdl.data.Graph
        Graph to attack. ``graph.test_mask`` selects the target nodes.
    feat_norm : str, optional
        Type of feature normalization, ['arctan', 'tanh']. Default: ``None``.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``None``.

    Returns
    -------
    out_graph : cogdl.data.Graph
        Graph attacked. It holds the adjacency returned by ``self.injection``
        and a feature matrix whose first rows are the preprocessed original
        features, followed by one row per injected node. Labels are taken
        from ``graph.y``.

    """
    # perf_counter is monotonic, so the reported runtime cannot be skewed by
    # wall-clock adjustments.
    time_start = time.perf_counter()
    adj = graph.to_scipy_csr()
    target_mask = graph.test_mask
    features = graph.x
    model.to(self.device)
    n_total, n_feat = features.shape
    features = feat_preprocess(features=features, feat_norm=feat_norm, device=self.device)
    adj_tensor = adj_preprocess(adj=adj, adj_norm_func=adj_norm_func, device=self.device)

    # The clean-graph prediction is only used through argmax, as reference
    # labels for the feature update. Building an autograd graph for it
    # would waste memory.
    with torch.no_grad():
        pred_origin = model(getGraph(adj_tensor, features, device=self.device))
    labels_origin = torch.argmax(pred_origin, dim=1)

    # Structure step: connect n_inject_max new nodes to the existing graph.
    adj_attack = self.injection(
        adj=adj,
        n_inject=self.n_inject_max,
        n_node=n_total,
        target_mask=target_mask,
        mode=self.inject_mode,
    )

    # Feature step: injected features start at zero and are computed by
    # update_features. NOTE(review): the result is assumed to be a torch
    # tensor on self.device, since it is concatenated with `features` below.
    features_attack = np.zeros([self.n_inject_max, n_feat])
    features_attack = self.update_features(
        model=model,
        adj_attack=adj_attack,
        features_origin=features,
        features_attack=features_attack,
        labels_origin=labels_origin,
        target_mask=target_mask,
        feat_norm=feat_norm,
        adj_norm_func=adj_norm_func,
    )
    # Injected-node rows are appended after the original rows (dim 0).
    out_features = torch.cat((features, features_attack), 0)
    time_end = time.perf_counter()
    if self.verbose:
        print("Attack runtime: {:.4f}.".format(time_end - time_start))

    out_graph = getGraph(adj_attack, out_features, graph.y, device=self.device)
    return out_graph
| 148 | |
| 149 | # flake8: noqa: C901 |
| 150 | def injection(self, adj, n_inject, n_node, target_mask, mode="random-inter"): |
nothing calls this directly
no test coverage detected