(self, x, adj)
| 126 | |
@torch.no_grad()
def inference(self, x, adj):
    """Run a full-graph forward pass with gradient tracking disabled.

    Both ``x`` and ``adj`` are moved to ``self._device``. Each of the
    ``self.num_layers`` layers propagates features with ``spmm(adj, x)``
    and then applies ``self.lins[i]``. Every layer except the last is
    additionally followed by ``self.norms[i]`` and an in-place ReLU.

    Args:
        x: Input node-feature tensor (moved to ``self._device``).
        adj: Adjacency operand passed to ``spmm``; presumably a sparse
            matrix -- confirm against callers. The caller's object is not
            modified.

    Returns:
        A tuple ``(x, xs)``. ``x`` is the final layer's output. ``xs`` is a
        list of ``num_layers + 1`` tensors: the device-moved input followed
        by each layer's output. All of them live on ``self._device``.
    """
    x = x.to(self._device)
    # ``.to()`` produces a device-local reference and never changes the
    # caller's ``adj``, so there is nothing to move back afterwards.
    adj = adj.to(self._device)
    last = self.num_layers - 1
    xs = [x]
    for i in range(self.num_layers):
        x = spmm(adj, x)
        x = self.lins[i](x)
        if i != last:
            x = self.norms[i](x)
            # relu_ runs before this layer's output is appended to xs, so
            # previously stored entries of xs are not touched.
            x = x.relu_()
        xs.append(x)
    return x, xs
| 142 | |
| 143 | @torch.no_grad() |
| 144 | def inference_batch(self, x, test_loader): |
no test coverage detected