r""" Description ----------- Preprocess the adjacency matrix. Parameters ---------- adj : scipy.sparse.csr.csr_matrix or a tuple Adjacency matrix in form of ``N * N`` sparse matrix. adj_norm_func : func of utils.normalize, optional Function that normaliz
(adj, adj_norm_func=None, mask=None, device="cpu")
| 98 | |
| 99 | |
def adj_preprocess(adj, adj_norm_func=None, mask=None, device="cpu"):
    r"""

    Description
    -----------
    Preprocess the adjacency matrix: optionally normalize it, optionally
    restrict it to a subset of nodes, convert it to a torch tensor and move
    it to ``device``.

    Parameters
    ----------
    adj : scipy.sparse.csr.csr_matrix, torch.Tensor, or a tuple/list of them
        Adjacency matrix in form of ``N * N`` sparse matrix, or a sequence of
        such matrices.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. It is applied to ``adj`` as
        a whole (the sequence itself when ``adj`` is a tuple/list) before any
        masking or conversion. Default: ``None``.
    mask : torch.Tensor, optional
        Boolean mask over the ``N`` nodes. When given, it is applied to both
        the rows and the columns of each adjacency matrix. Default: ``None``.
    device : str, optional
        Device used to host data. Default: ``cpu``.

    Returns
    -------
    adj : torch.Tensor or list of torch.Tensor
        Adjacency matrix as a torch tensor on ``device``. When the input is a
        tuple or list, a ``list`` of tensors is returned (also for tuples).

    """

    def _convert(adj_):
        # Same order as before: mask in the input's native format first,
        # then convert non-tensors via adj_to_tensor, then move to device.
        if mask is not None:
            adj_ = adj_[mask][:, mask]
        # isinstance (not an exact type check) so tensor subclasses such as
        # torch.nn.Parameter are treated as tensors instead of being re-converted.
        if not isinstance(adj_, torch.Tensor):
            adj_ = adj_to_tensor(adj_)
        return adj_.to(device)

    if adj_norm_func is not None:
        adj = adj_norm_func(adj)

    if isinstance(adj, (tuple, list)):
        # Sequences are always returned as a list, even when a tuple was given.
        return [_convert(adj_) for adj_ in adj]

    return _convert(adj)
| 153 | |
| 154 | |
| 155 | def feat_preprocess(features, feat_norm=None, device="cpu"): |
no test coverage detected