github.com/pyg-team/pytorch_geometric

Method forward

benchmark/runtime/dgl/gat.py:91–102
(self, x)


        zeros(self.bias)

    def forward(self, x):
        x = torch.matmul(x, self.weight)
        x = x.reshape((x.size(0), self.heads, -1))  # NxHxD'
        head_x = x.transpose(0, 1)  # HxNxD'
        a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x'))
        x = self.g.ndata['x'] / self.g.ndata['z']  # NxHxD'
        return x.view(-1, self.heads * self.out_channels)

    def edge_attention(self, edge):
        a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope)
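Read together, lines 92–101 compute multi-head graph attention over the graph stored in self.g. The equations below restate that computation. They assume that edge_attention stores its score as the edge field 'a', and that edge_softmax replaces it with the exponentiated score while writing the per-destination sum into ndata['z']. Neither method is fully shown on this page; the division on line 101 is what suggests this contract. Notation (introduced here, not taken from the source): W is self.weight, a^l_h and a^r_h are head h of self.att_l and self.att_r, x'_{i,h} is node i's head-h slice of x W, D' is self.out_channels (per the final view), and an edge j → i has source j and destination i.

$$
\begin{aligned}
x'_{i,h} &= (x_i W)_{h} \in \mathbb{R}^{D'}, \qquad
s^{(1)}_{i,h} = {x'_{i,h}}^{\top} a^{l}_{h}, \qquad
s^{(2)}_{i,h} = {x'_{i,h}}^{\top} a^{r}_{h} \\
e_{ji,h} &= \operatorname{LeakyReLU}\!\left(s^{(1)}_{j,h} + s^{(2)}_{i,h}\right) \quad \text{for each edge } j \to i \\
z_{i,h} &= \sum_{k \,:\, k \to i} \exp\!\left(e_{ki,h}\right) \\
\mathrm{out}_{i,h} &= \frac{1}{z_{i,h}} \sum_{j \,:\, j \to i} \exp\!\left(e_{ji,h}\right) x'_{j,h} \\
\mathrm{out}_{i} &= \mathrm{out}_{i,1} \,\Vert\, \cdots \,\Vert\, \mathrm{out}_{i,H} \in \mathbb{R}^{H D'}
\end{aligned}
$$

Deferring the division by z until after update_all gives the same result as normalising each edge weight first, because z_{i,h} is constant across the edges entering node i.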

Callers

nothing calls this directly

Calls (6)

edge_softmax · Method · 0.95
sum · Method · 0.80
view · Method · 0.80
matmul · Method · 0.45
size · Method · 0.45
update · Method · 0.45
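The highest-confidence callee, edge_softmax, is not shown on this page. Lines 99–101 suggest that forward relies on it to leave unnormalised weights in edata['a'] and their per-destination sum in ndata['z']; update_all then sums weight × source features, and the final division completes the softmax. The sketch below reproduces that contract for a single head in plain Python so the arithmetic can be checked by hand. It is an illustration under those assumptions, not the repository's code. The names gat_aggregate and leaky_relu are hypothetical, the 0.2 default slope is an assumption, and subtracting the per-destination maximum is a standard stability step that the real edge_softmax may or may not perform. The step does not change the ratios.

```python
import math


def leaky_relu(v, negative_slope=0.2):
    # Same elementwise rule as torch.nn.functional.leaky_relu.
    return v if v >= 0 else negative_slope * v


def gat_aggregate(x, a1, a2, edges, negative_slope=0.2):
    """Hypothetical single-head version of lines 98-101.

    x:     N feature vectors (lists of floats), one head's slice of x W
    a1:    per-node scores used on the source side (length N)
    a2:    per-node scores used on the destination side (length N)
    edges: (src, dst) pairs
    """
    n, d = len(x), len(x[0])
    # Edge scores, mirroring edge_attention: leaky_relu(a1[src] + a2[dst]).
    scores = [leaky_relu(a1[s] + a2[t], negative_slope) for s, t in edges]
    # Per-destination maximum, subtracted before exp for numerical stability.
    m = [-math.inf] * n
    for (_, t), e in zip(edges, scores):
        m[t] = max(m[t], e)
    w = [math.exp(e - m[t]) for (_, t), e in zip(edges, scores)]
    # z: sum of unnormalised weights per destination (the role of ndata['z']).
    z = [0.0] * n
    for (_, t), wi in zip(edges, w):
        z[t] += wi
    # update_all(src_mul_edge, sum): accumulate w * x[src] into each destination.
    out = [[0.0] * d for _ in range(n)]
    for (s, t), wi in zip(edges, w):
        for k in range(d):
            out[t][k] += wi * x[s][k]
    # Divide by z as on line 101; this sketch leaves nodes without
    # incoming edges at zero instead of dividing by zero.
    return [[v / z[i] for v in row] if z[i] > 0 else row
            for i, row in enumerate(out)]


if __name__ == "__main__":
    # Equal scores on three edges into node 2 give a plain mean of x[0], x[1], x[2].
    out = gat_aggregate([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],
                        [0.0] * 3, [0.0] * 3, [(0, 2), (1, 2), (2, 2)])
    print(out[2])  # [1.0, 1.0]
```

With all scores equal, the attention weights are uniform and the output reduces to a neighbourhood mean. Unequal scores shift the weights toward the higher-scoring sources.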

Tested by

no test coverage detected