(self, x)
| 42 | return {'x': x} |
| 43 | |
def forward(self, x):
    """Apply one multi-head graph-attention propagation step over ``self.g``.

    Steps:
        1. Project the node features with ``self.weight``.
        2. Split the result into ``self.heads`` groups of ``self.out_channels``.
        3. Propagate with ``self.gat_msg`` / ``self.gat_reduce`` via
           ``self.g.update_all``.
        4. Flatten the heads back together and shift by ``self.bias``.

    Args:
        x: Node-feature matrix. ``torch.mm`` requires it to be 2-D, with as
            many columns as ``self.weight`` has rows.

    Returns:
        Tensor viewed as ``(-1, self.heads * self.out_channels)``, plus
        ``self.bias``.
    """
    projected = torch.mm(x, self.weight)
    num_heads = self.heads
    head_dim = self.out_channels
    graph = self.g
    # Store the per-head features as node data 'x' so that update_all's
    # message/reduce functions can operate on them.
    graph.ndata['x'] = projected.view(-1, num_heads, head_dim)
    graph.update_all(self.gat_msg, self.gat_reduce)
    # pop() returns the aggregated features and also removes 'x' from the
    # graph's node data.
    aggregated = graph.ndata.pop('x')
    flattened = aggregated.view(-1, num_heads * head_dim)
    return flattened + self.bias
| 52 | |
| 53 | |
| 54 | class GAT(torch.nn.Module): |