        self.layers = nn.ModuleList(layers)

    def forward(self, graph, feat):
        """
        Compute the decoder part of Set Transformer.

        Parameters
        ----------
        graph : DGLGraph
            The input graph.
        feat : torch.Tensor
            The input feature with shape :math:`(N, D)`, where :math:`N` is the
            number of nodes in the graph, and :math:`D` is the feature size.

        Returns
        -------
        torch.Tensor
            The output feature with shape ``(B, k * d_model)``, where ``B`` is
            the batch size and ``k`` and ``d_model`` are this module's
            ``self.k`` and ``self.d_model``.
        """
        # Per-graph node counts drive the pooling step.
        len_pma = graph.batch_num_nodes()
        # After pooling, every graph contributes exactly k rows.
        len_sab = [self.k] * graph.batch_size
        feat = self.pma(feat, len_pma)
        for layer in self.layers:
            feat = layer(feat, len_sab)
        # Concatenate each graph's k rows into a single row.
        return feat.view(graph.batch_size, self.k * self.d_model)


class WeightAndSum(nn.Module):
nothing calls this directly
no test coverage detected
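
The shape bookkeeping in `forward` above can be traced without DGL or PyTorch. The sketch below is only an illustration: plain lists stand in for tensors, the helper names `segment_lengths` and `flatten_per_graph` are hypothetical, and it assumes (as `len_sab = [self.k] * graph.batch_size` suggests) that the pooling step leaves exactly `k` rows per graph, stored graph by graph.

```python
# Sketch of the length/shape bookkeeping in forward(); lists stand in for tensors.
# segment_lengths and flatten_per_graph are hypothetical helpers, not DGL API.

def segment_lengths(batch_num_nodes, k):
    """Return (len_pma, len_sab) in the same way forward() builds them."""
    len_pma = list(batch_num_nodes)        # nodes per graph, used by the pooling step
    len_sab = [k] * len(batch_num_nodes)   # assumed: k rows per graph after pooling
    return len_pma, len_sab


def flatten_per_graph(rows, batch_size, k, d_model):
    """Mimic feat.view(batch_size, k * d_model) on a row-major (B*k, d_model) list."""
    assert len(rows) == batch_size * k
    assert all(len(row) == d_model for row in rows)
    out = []
    for b in range(batch_size):
        flat = []
        # Concatenate the k consecutive rows that belong to graph b.
        for row in rows[b * k:(b + 1) * k]:
            flat.extend(row)
        out.append(flat)
    return out


# Two graphs with 3 and 5 nodes, k = 2 rows per graph, d_model = 3.
len_pma, len_sab = segment_lengths([3, 5], k=2)
# Dummy pooled rows: graph g, row s is filled with the value g * 10 + s.
pooled_rows = [[g * 10 + s] * 3 for g in range(2) for s in range(2)]
result = flatten_per_graph(pooled_rows, batch_size=2, k=2, d_model=3)
```

Each output row holds one graph's `k` pooled rows laid end to end, which is why the returned width is `k * d_model` rather than the input feature size alone.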