(self, graph, x)
| 119 | return torch.cat(op_embedding, dim=1).to(device) |
| 120 | |
def preprocessing(self, graph, x, *, cache_file=None):
    """Compute the preprocessed input features for ``graph``.

    If ``graph.is_inductive()`` is true, features are computed twice with
    ``self._preprocessing``:
    - once after ``graph.train()`` with ``drop_edge=True``;
    - once after ``graph.eval()`` with ``drop_edge=False``.
    Rows indexed by ``graph.train_nid`` are taken from the first result and
    all other rows from the second. Otherwise a single ``drop_edge=False``
    pass is used.

    Args:
        graph: Graph object exposing ``is_inductive()``, ``train()``,
            ``eval()`` and ``train_nid``.
        x: Input features forwarded to ``self._preprocessing``. Its
            ``.device`` is used when loading a cached result.
        cache_file: Optional path for on-disk caching (keyword-only).
            - If given and the file exists, its contents are loaded with
              ``torch.load``, moved to ``x.device`` and returned without
              recomputation.
            - If given and the file does not exist, the computed result is
              saved there (on CPU) with ``torch.save``.
            - ``None`` (the default) disables caching.

    Returns:
        The preprocessed feature tensor.
    """
    print("Preprocessing...")
    # Caching is opt-in via an explicit path. The former check on a local
    # that had just been assigned ``None`` made both cache branches dead code.
    if cache_file is not None and os.path.exists(cache_file):
        # NOTE(review): torch.load deserializes the file; only point
        # cache_file at files you trust.
        return torch.load(cache_file).to(x.device)

    if graph.is_inductive():
        # Training-node rows come from a pass with the graph in train mode
        # and drop_edge=True. All other rows come from an eval-mode pass.
        graph.train()
        x_train = self._preprocessing(graph, x, drop_edge=True)
        graph.eval()
        x_all = self._preprocessing(graph, x, drop_edge=False)
        train_nid = graph.train_nid
        x_all[train_nid] = x_train[train_nid]
    else:
        x_all = self._preprocessing(graph, x, drop_edge=False)

    if cache_file is not None:
        torch.save(x_all.cpu(), cache_file)
    print("Preprocessing Done...")
    return x_all
| 143 | |
def reset_parameters(self):
    """Re-initialise parameters by delegating to ``self.mlp.nn.reset_parameters()``."""
    wrapped = self.mlp.nn
    wrapped.reset_parameters()
no test coverage detected