github.com/dmlc/dgl

Method forward

python/dgl/nn/pytorch/glob.py:1279–1302

Compute the decoder part of Set Transformer: pool the node features of each graph in a batch into `k` vectors and return them flattened, one row per graph.

(self, graph, feat)


        self.layers = nn.ModuleList(layers)

    def forward(self, graph, feat):
        """
        Compute the decoder part of Set Transformer.

        Parameters
        ----------
        graph : DGLGraph
            The input graph.
        feat : torch.Tensor
            The input feature with shape :math:`(N, D)`, where :math:`N` is the
            total number of nodes in the (batched) graph, and :math:`D` is the
            size of features.

        Returns
        -------
        torch.Tensor
            The output feature with shape :math:`(B, K D)`, where :math:`B` refers
            to the batch size and :math:`K` (``self.k``) is the number of vectors
            each graph is pooled into; the :math:`K` vectors of size :math:`D` are
            concatenated per graph.
        """
        len_pma = graph.batch_num_nodes()
        len_sab = [self.k] * graph.batch_size
        feat = self.pma(feat, len_pma)
        for layer in self.layers:
            feat = layer(feat, len_sab)
        return feat.view(graph.batch_size, self.k * self.d_model)


class WeightAndSum(nn.Module):
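The data flow of `forward` can be traced with a dependency-free sketch. It is not DGL's implementation: `pma_pool` is a hypothetical single-head stand-in for `self.pma` that lets `k` fixed seed vectors attend over each graph's nodes with plain scaled dot-product attention. It leaves out the learned seeds, projections, multiple heads, feed-forward sublayer and normalization of a full Set Transformer PMA block, and the refinement layers are passed in as arbitrary callables. What it does mirror is the bookkeeping: the per-graph node counts (the role of `len_pma`) split the packed rows, every graph comes out as exactly `k` vectors (hence `len_sab = [k] * B`), and the final step concatenates each graph's `k` vectors into one row of length `k * d_model`, as `feat.view(...)` does.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def pma_pool(nodes: Matrix, seeds: Matrix) -> Matrix:
    """Hypothetical single-head stand-in for PMA: each seed attends over one graph's nodes."""
    d = len(seeds[0])
    pooled = []
    for s in seeds:
        # Scaled dot-product score between this seed and every node in the set.
        scores = [sum(a * b for a, b in zip(s, x)) / math.sqrt(d) for x in nodes]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(v - m) for v in scores]
        total = sum(weights)
        # Softmax-weighted average of the node features.
        pooled.append([sum(w * x[j] for w, x in zip(weights, nodes)) / total for j in range(d)])
    return pooled


def decoder_forward(
    batch_num_nodes: Sequence[int],
    feat: Matrix,
    seeds: Matrix,
    layers: Sequence[Callable[[List[Matrix]], List[Matrix]]],
) -> Matrix:
    """Sketch of the decoder's control flow on plain lists (hypothetical, not DGL code)."""
    k = len(seeds)
    batch_size = len(batch_num_nodes)
    # Split the packed (N, D) rows into one set per graph, using the per-graph node counts.
    sets, start = [], 0
    for n in batch_num_nodes:
        sets.append(feat[start:start + n])
        start += n
    # Pooling turns every graph into exactly k vectors, whatever its node count.
    out = [pma_pool(nodes, seeds) for nodes in sets]
    len_sab = [k] * batch_size
    for layer in layers:
        out = layer(out)
        assert [len(g) for g in out] == len_sab  # later layers always see k vectors per graph
    # Equivalent of feat.view(B, k * d_model): concatenate each graph's k rows.
    return [[v for row in g for v in row] for g in out]


feat = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 packed nodes, D = 2
seeds = [[1.0, 0.0], [0.0, 1.0]]             # k = 2 hypothetical seed vectors
out = decoder_forward([2, 1], feat, seeds, layers=[lambda graphs: graphs])
print(len(out), len(out[0]))  # 2 4
```

A batch of two graphs with 2 and 1 nodes yields a `(2, 4)` result, that is `(B, k * D)`, regardless of how many nodes each graph has.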

Callers

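nothing calls this directly (PyTorch invokes `forward` indirectly when the module instance is called, via `nn.Module.__call__`)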

Calls 1

batch_num_nodes · Method · 0.80
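`graph.batch_num_nodes()` returns one node count per graph in the batch, and `forward` passes those counts to `self.pma` as `len_pma` so that pooling is done per graph. A common way for attention code to consume such counts is to scatter the packed `(N, D)` rows into a zero-padded `(B, L_max, D)` batch plus a boolean mask. The plain-Python sketch below shows that conversion; `pad_by_lengths` is a hypothetical helper, not a DGL function, and whether `self.pma` uses exactly this scheme internally is an assumption.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def pad_by_lengths(feat: Matrix, lengths: Sequence[int]) -> Tuple[List[Matrix], List[List[bool]]]:
    """Hypothetical helper: packed (N, D) rows -> padded (B, L_max, D) batch and a mask."""
    if sum(lengths) != len(feat):
        raise ValueError("lengths must sum to the number of packed rows")
    d = len(feat[0]) if feat else 0
    max_len = max(lengths, default=0)
    padded, mask, start = [], [], 0
    for n in lengths:
        rows = [list(r) for r in feat[start:start + n]]
        # Zero rows fill each set up to the longest set in the batch.
        padded.append(rows + [[0.0] * d for _ in range(max_len - n)])
        # True marks a real node; attention would ignore the False (padding) positions.
        mask.append([True] * n + [False] * (max_len - n))
        start += n
    return padded, mask


padded, mask = pad_by_lengths([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 1])
print(mask)  # [[True, True], [True, False]]
```

In `forward`, only the PMA call sees uneven lengths; every layer in `self.layers` receives `len_sab = [self.k] * graph.batch_size`, so under this scheme its mask would be all `True`.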

Tested by

no test coverage detected