| 89 | zeros(self.bias) |
| 90 | |
| 91 | def forward(self, x): # multi-head attention aggregation over the graph self.g; returns a 2-D tensor of width heads * out_channels |
| 92 | x = torch.matmul(x, self.weight) # projection shared by all heads; presumably x is N x in_channels (TODO confirm) |
| 93 | x = x.reshape((x.size(0), self.heads, -1)) # NxHxD': split the projected features into per-head chunks |
| 94 | head_x = x.transpose(0, 1) # HxNxD': heads first so torch.bmm batches over heads |
| 95 | a1 = torch.bmm(head_x, self.att_l).transpose(0, 1) # NxHx1: per-node score read on the source side of each edge; assumes att_l is HxD'x1 (TODO confirm) |
| 96 | a2 = torch.bmm(head_x, self.att_r).transpose(0, 1) # NxHx1: per-node score read on the destination side of each edge; assumes att_r is HxD'x1 (TODO confirm) |
| 97 | self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2}) # store features and scores as node data for the edge and message passes that follow |
| 98 | self.g.apply_edges(self.edge_attention) # per-edge logits: leaky_relu(src a1 + dst a2) |
| 99 | self.edge_softmax() # NOTE(review): defined elsewhere; presumably writes normalized edge weights 'a' and the node normalizer 'z' read below -- confirm |
| 100 | self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x')) # message = src 'x' * edge 'a', summed at the destination; overwrites ndata['x']. NOTE(review): src_mul_edge is a legacy DGL builtin name (newer releases use fn.u_mul_e) -- verify against the installed DGL |
| 101 | x = self.g.ndata['x'] / self.g.ndata['z'] # NxHxD': normalize the aggregate by 'z'. NOTE(review): a node whose 'z' is 0 (e.g. no incoming edges, if edge_softmax sums over in-edges) would yield inf/NaN -- confirm |
| 102 | return x.view(-1, self.heads * self.out_channels) # concatenate heads into N x (heads * out_channels); assumes D' == out_channels |
| 103 | |
| 104 | def edge_attention(self, edge): |
| 105 | a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope) |