| 238 | self.__symmetric__ = False |
| 239 | |
def generate_normalization(self, norm="sym"):
    """Prepare a degree-based normalization for this adjacency.

    Nothing happens if ``self.__normed__`` is already truthy: an existing
    normalization is never replaced, even when a different ``norm`` is
    requested.

    Args:
        norm: Normalization mode.
            * ``"sym"``: stores ``deg ** -0.5`` (shape ``(-1, 1)``) as both
              ``self.__in_norm__`` and ``self.__out_norm__``. Both names refer
              to the same tensor object.
            * ``"row"``: stores ``deg ** -1`` (shape ``(-1, 1)``) as
              ``self.__in_norm__`` and sets ``self.__out_norm__`` to ``None``.
            * ``"col"``: sets ``self.row`` from ``csr2coo`` and replaces
              ``self.weight`` with the result of ``row_normalization`` called
              with the row/col index arguments swapped.

            Here ``deg`` is the number of stored entries per row, taken from
            consecutive ``row_ptr`` differences. Edge weights are not used,
            and rows with no entries get a factor of 0 instead of ``inf``.

    Raises:
        NotImplementedError: if ``norm`` is not one of the modes above and
            no normalization has been applied yet.
    """
    if self.__normed__:
        return
    if norm not in ("sym", "row", "col"):
        raise NotImplementedError(
            f"Unsupported normalization {norm!r}; expected 'sym', 'row' or 'col'"
        )

    if norm == "col":
        # Column normalization: materialize COO row indices, then reuse the
        # row-normalization helper with the row/col roles swapped.
        self.row, _, _ = csr2coo(self.row_ptr, self.col, self.weight)
        self.weight = row_normalization(self.num_nodes, self.col, self.row, self.weight)
    else:
        # Entries per row from consecutive row-pointer differences
        # (only needed for the "sym" and "row" modes).
        degrees = (self.row_ptr[1:] - self.row_ptr[:-1]).float()
        exponent = -0.5 if norm == "sym" else -1
        edge_norm = torch.pow(degrees, exponent).to(self.device)
        # Empty rows yield 0 ** negative == inf; neutralize them to 0.
        edge_norm[torch.isinf(edge_norm)] = 0
        edge_norm = edge_norm.view(-1, 1)
        if norm == "sym":
            self.__out_norm__ = self.__in_norm__ = edge_norm
        else:
            self.__out_norm__ = None
            self.__in_norm__ = edge_norm

    self.__normed__ = norm
| 259 | |
| 260 | def normalize_adj(self, norm="sym"): |
| 261 | if self.__normed__: |