github.com/THUDM/CogDL

Function adj_preprocess

cogdl/utils/grb_utils.py:100–152


(adj, adj_norm_func=None, mask=None, device="cpu")
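The optional `adj_norm_func` receives the raw adjacency matrix and returns the normalized one. A minimal sketch of such a callable, assuming GCN-style symmetric normalization with self-loops, D^{-1/2} (A + I) D^{-1/2}, on a scipy sparse input. `sym_norm_sketch` is a hypothetical name, not one of CogDL's `utils.normalize` functions:

```python
import numpy as np
import scipy.sparse as sp


def sym_norm_sketch(adj):
    """Hypothetical adj_norm_func: D^{-1/2} (A + I) D^{-1/2}, returned as CSR."""
    n = adj.shape[0]
    a_hat = sp.csr_matrix(adj) + sp.eye(n, format="csr")  # add self-loops
    deg = np.asarray(a_hat.sum(axis=1)).ravel()          # degree of each node in A + I
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))            # D^{-1/2}; deg >= 1 for non-negative weights
    return (d_inv_sqrt @ a_hat @ d_inv_sqrt).tocsr()


# Path graph 0 - 1 - 2.
adj = sp.csr_matrix(np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]], dtype=float))
dense = sym_norm_sketch(adj).toarray()
```

Because `adj_preprocess` calls `adj_norm_func` before it applies `mask`, a normalization like this one computes degrees on the full graph, not on the masked subgraph.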


```python
import torch

# adj_to_tensor is a CogDL helper defined elsewhere; it is called here but not shown.


def adj_preprocess(adj, adj_norm_func=None, mask=None, device="cpu"):
    r"""

    Description
    -----------
    Preprocess the adjacency matrix.

    Parameters
    ----------
    adj : scipy.sparse.csr.csr_matrix, torch.Tensor, or a tuple/list of these
        Adjacency matrix in form of ``N * N`` sparse matrix.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``None``.
    mask : torch.Tensor, optional
        Mask of nodes in form of ``N * 1`` torch bool tensor. Default: ``None``.
    device : str, optional
        Device used to host data. Default: ``cpu``.

    Returns
    -------
    adj : torch.Tensor or a list
        Adjacency matrix in form of ``N * N`` sparse tensor, or a list of
        such tensors when the input is a tuple or list.

    """

    # Normalize first; for a tuple/list the whole sequence is passed in one call.
    if adj_norm_func is not None:
        adj = adj_norm_func(adj)

    if isinstance(adj, (tuple, list)):
        # Several adjacency matrices: mask, convert, and move each one.
        if mask is not None:
            adj = [
                adj_to_tensor(adj_[mask][:, mask]).to(device)
                if not isinstance(adj_, torch.Tensor)
                else adj_[mask][:, mask].to(device)
                for adj_ in adj
            ]
        else:
            adj = [
                adj_to_tensor(adj_).to(device) if not isinstance(adj_, torch.Tensor) else adj_.to(device)
                for adj_ in adj
            ]
    else:
        if not isinstance(adj, torch.Tensor):
            # Non-tensor input (e.g. scipy sparse): mask rows and columns, then convert.
            if mask is not None:
                adj = adj_to_tensor(adj[mask][:, mask]).to(device)
            else:
                adj = adj_to_tensor(adj).to(device)
        else:
            # Already a tensor: only mask and move to the target device.
            if mask is not None:
                adj = adj[mask][:, mask].to(device)
            else:
                adj = adj.to(device)

    return adj
```
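The masking step `adj[mask][:, mask]` keeps the rows, then the columns, of the selected nodes, which yields the adjacency matrix of the node-induced subgraph; tuples and lists get the same treatment element by element and come back as a list. A torch-free sketch of that control flow, assuming a NumPy boolean mask and scipy sparse inputs. `preprocess_sketch` and its `convert` argument (standing in for `adj_to_tensor(...).to(device)`) are hypothetical:

```python
import numpy as np
import scipy.sparse as sp


def preprocess_sketch(adj, norm_func=None, mask=None, convert=lambda m: m):
    """Hypothetical mirror of adj_preprocess's control flow (no torch, no device)."""
    # 1. Normalize the full matrix (or the whole tuple/list) before masking.
    if norm_func is not None:
        adj = norm_func(adj)

    def one(m):
        # 2. Keep rows, then columns, of the masked nodes: the induced subgraph.
        if mask is not None:
            m = m[mask][:, mask]
        # 3. Convert (adj_preprocess uses adj_to_tensor(...).to(device) here).
        return convert(m)

    # Tuples and lists are processed element-wise and returned as a list.
    if isinstance(adj, (tuple, list)):
        return [one(m) for m in adj]
    return one(adj)


# 4-node cycle 0-1-2-3-0; keep nodes 0, 1, 2.
cycle = sp.csr_matrix(np.array([[0, 1, 0, 1],
                                [1, 0, 1, 0],
                                [0, 1, 0, 1],
                                [1, 0, 1, 0]], dtype=float))
keep = np.array([True, True, True, False])
sub = preprocess_sketch(cycle, mask=keep, convert=lambda m: m.toarray())
both = preprocess_sketch((cycle, cycle), mask=keep, convert=lambda m: m.toarray())
```

Dropping node 3 removes edges 2-3 and 3-0, so the cycle becomes the path 0 - 1 - 2.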

Callers 15

attack (method) · 0.90
update_features (method) · 0.90
attack (method) · 0.90
update_features (method) · 0.90
attack (method) · 0.90
update_features (method) · 0.90
attack (method) · 0.90
update_features (method) · 0.90
attack (method) · 0.90
update_features (method) · 0.90
attack (method) · 0.90
modification (method) · 0.90

Calls 2

adj_to_tensor (function) · 0.85
to (method) · 0.45
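In the source, inputs that are not already tensors go through `adj_to_tensor`, whose body is not shown on this page. Assuming it turns a scipy sparse matrix into a COO-format sparse tensor (consistent with the docstring's "sparse tensor" return type), the core step is extracting row indices, column indices, and values. A scipy-only sketch; `to_coo_parts_sketch` is hypothetical, and with PyTorch installed the three results could be passed to `torch.sparse_coo_tensor(indices, values, size)`:

```python
import numpy as np
import scipy.sparse as sp


def to_coo_parts_sketch(adj):
    """Hypothetical: split a scipy sparse matrix into (indices, values, shape)."""
    coo = sp.coo_matrix(adj)
    indices = np.vstack([coo.row, coo.col]).astype(np.int64)  # shape (2, nnz)
    values = coo.data.astype(np.float32)                      # shape (nnz,)
    return indices, values, coo.shape


adj = sp.csr_matrix(np.array([[0, 2, 0],
                              [2, 0, 0],
                              [0, 0, 0]], dtype=float))
indices, values, shape = to_coo_parts_sketch(adj)
```

The `(2, nnz)` index layout is the one `torch.sparse_coo_tensor` expects for a 2-D tensor.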

Tested by

no test coverage detected