(self, g, X, pos_enc)
| 98 | ) |
| 99 | |
def forward(self, g, X, pos_enc):
    """Encode node inputs, run every layer over the graph, then pool and predict.

    Args:
        g: graph whose edge list defines the sparse adjacency and which is
            also handed to ``self.pooler``.
        X: node input features passed to ``self.atom_encoder``.
        pos_enc: positional encodings passed to ``self.pos_linear``.

    Returns:
        Whatever ``self.predictor`` produces from the pooled representation.
    """
    # g.edges() yields the (src, dst) index tensors; stacking them gives the
    # index tensor expected by spmatrix (presumably shape (2, E) — confirm).
    edge_index = torch.stack(g.edges())
    num_nodes = g.num_nodes()
    adjacency = dglsp.spmatrix(edge_index, shape=(num_nodes, num_nodes))

    # Initial node representation: encoded inputs plus projected positions.
    hidden = self.atom_encoder(X) + self.pos_linear(pos_enc)

    # Each layer consumes the shared adjacency and the current node states.
    for block in self.layers:
        hidden = block(adjacency, hidden)

    # Readout over g (assumed to pool node states per graph — confirm), then predict.
    pooled = self.pooler(g, hidden)
    return self.predictor(pooled)
| 110 | |
| 111 | |
| 112 | @torch.no_grad() |