(self, graph, x)
| 134 | return torch.cat(op_embedding, dim=1).to(device) |
| 135 | |
def preprocessing(self, graph, x):
    """Compute propagated node features for ``graph``, optionally cached on disk.

    If ``self.dataset_name`` is set, the result is cached in a ``.pt`` file.
    The file name is built from ``dataset_name``, ``num_propagations``,
    ``diffusion`` and the comma-joined ``adj_norm`` entries. When that file
    already exists, it is loaded, moved to ``x.device`` and returned without
    recomputation.

    For inductive graphs the features are computed twice:

    * once after ``graph.train()`` with ``drop_edge=True``;
    * once after ``graph.eval()`` with ``drop_edge=False``.

    Rows indexed by ``graph.train_nid`` come from the first result and all
    other rows from the second. The graph is left in eval mode afterwards.

    Args:
        graph: Graph object. ``is_inductive()`` is always called; ``train()``,
            ``eval()`` and ``train_nid`` are used only on the inductive path.
        x: Node feature tensor forwarded to ``self._preprocessing``.

    Returns:
        The propagated feature tensor.
    """
    print("Preprocessing...")

    cache_path = None
    if self.dataset_name is not None:
        joined_norms = ",".join(self.adj_norm)
        cache_path = f"{self.dataset_name}_{self.num_propagations}_{self.diffusion}_{joined_norms}.pt"
        if os.path.exists(cache_path):
            # The cache is written from a CPU copy; move it back to the input's device.
            return torch.load(cache_path).to(x.device)

    if not graph.is_inductive():
        features = self._preprocessing(graph, x, drop_edge=False)
    else:
        # Training rows take the train-mode, edge-dropped propagation.
        # NOTE(review): presumably to keep training nodes from seeing
        # eval-time edges; confirm.
        graph.train()
        train_view = self._preprocessing(graph, x, drop_edge=True)
        graph.eval()
        features = self._preprocessing(graph, x, drop_edge=False)
        nid = graph.train_nid
        features[nid] = train_view[nid]

    if cache_path is not None:
        torch.save(features.cpu(), cache_path)

    print("Preprocessing Done...")
    return features
| 158 | |
| 159 | def forward(self, graph): |
| 160 | if self.cache_x is None: |
no test coverage detected