(self, norm="sym")
| 258 | self.__normed__ = norm |
| 259 | |
def normalize_adj(self, norm="sym"):
    """Normalize the adjacency edge weights in place.

    Args:
        norm: Normalization scheme.
            ``"sym"``  -> ``symmetric_normalization(num_nodes, row, col, weight)``;
            ``"row"``  -> ``row_normalization(num_nodes, row, col, weight)``;
            ``"col"``  -> ``row_normalization(num_nodes, col, row, weight)``
            (row normalization with ``row``/``col`` swapped).

    Raises:
        NotImplementedError: If ``norm`` is not one of the schemes above.
            The check runs before ``self.weight`` is modified, so an invalid
            request leaves the object unchanged.

    Note:
        If ``self.__normed__`` is already truthy, this method returns
        immediately without doing anything, even when a different ``norm``
        is requested.
    """
    # Already normalized (with any scheme): keep the existing weights.
    if self.__normed__:
        return

    # Validate up front so an unsupported scheme cannot clobber self.weight
    # before the error is raised.
    if norm not in ("sym", "row", "col"):
        raise NotImplementedError(
            f"Unsupported normalization {norm!r}; expected one of 'sym', 'row', 'col'"
        )

    # (Re)initialize to all-ones weights when none are set or their length
    # does not match the edge list (self.col).
    if self.weight is None or self.weight.shape[0] != self.col.shape[0]:
        self.weight = torch.ones(self.num_edges, device=self.device)

    if norm == "sym":
        self.weight = symmetric_normalization(self.num_nodes, self.row, self.col, self.weight)
    elif norm == "row":
        self.weight = row_normalization(self.num_nodes, self.row, self.col, self.weight)
    else:  # norm == "col": row-normalize the transposed (col, row) edge list
        self.weight = row_normalization(self.num_nodes, self.col, self.row, self.weight)

    # Record which scheme was applied; also makes later calls no-ops.
    self.__normed__ = norm
| 275 | |
| 276 | def convert_csr(self): |
| 277 | self._to_csr() |
no test coverage detected