Compute the Encoder part of Set Transformer.

Parameters
----------
graph : DGLGraph
    The input graph.
feat : torch.Tensor
    The input feature with shape :math:`(N, D)`, where :math:`N` is the number of nodes in the graph.
(self, graph, feat)
        self.layers = nn.ModuleList(layers)

    def forward(self, graph, feat):
        """
        Compute the Encoder part of Set Transformer.

        Parameters
        ----------
        graph : DGLGraph
            The input graph.
        feat : torch.Tensor
            The input feature with shape :math:`(N, D)`, where :math:`N` is the
            number of nodes in the graph.

        Returns
        -------
        torch.Tensor
            The output feature with shape :math:`(N, D)`.
        """
        lengths = graph.batch_num_nodes()
        for layer in self.layers:
            feat = layer(feat, lengths)
        return feat


class SetTransformerDecoder(nn.Module):
nothing calls this directly
no test coverage detected
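
The `forward` loop passes the flat `(N, D)` feature matrix through every layer together with `lengths`, the per-graph node counts returned by `graph.batch_num_nodes()`. The listing does not show what each layer does with `lengths`, so the following is only a minimal pure-Python sketch under one assumption: that `lengths` marks consecutive row blocks, one per graph, and that each layer processes every block independently. The helpers `split_by_lengths`, `center_set`, and `encode` are hypothetical names for illustration, and `center_set` is a stand-in for a real attention block, not DGL's implementation.

```python
from typing import Callable, List, Sequence

Vector = List[float]
SetLayer = Callable[[List[Vector]], List[Vector]]


def split_by_lengths(feat: Sequence[Vector], lengths: Sequence[int]) -> List[List[Vector]]:
    """Split a flat (N, D) feature list into one block of rows per graph."""
    if sum(lengths) != len(feat):
        raise ValueError("lengths must sum to the number of rows in feat")
    blocks, start = [], 0
    for n in lengths:
        blocks.append([list(row) for row in feat[start:start + n]])
        start += n
    return blocks


def center_set(rows: List[Vector]) -> List[Vector]:
    """Hypothetical stand-in layer: subtract the per-set mean from each row."""
    if not rows:
        return []
    dim = len(rows[0])
    mean = [sum(r[d] for r in rows) / len(rows) for d in range(dim)]
    return [[r[d] - mean[d] for d in range(dim)] for r in rows]


def encode(feat: Sequence[Vector], lengths: Sequence[int], layers: Sequence[SetLayer]) -> List[Vector]:
    """Mirror the loop in forward: every layer sees feat together with lengths."""
    current = [list(row) for row in feat]
    for layer in layers:
        out: List[Vector] = []
        for block in split_by_lengths(current, lengths):
            out.extend(layer(block))  # rows never mix across graphs
        current = out
    return current


# Two graphs batched together: the first has 2 nodes, the second has 1.
feat = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
lengths = [2, 1]
out = encode(feat, lengths, [center_set])
```

In this sketch the output keeps the same number of rows and columns as the input, matching the `(N, D)` to `(N, D)` contract stated in the docstring, and the single node of the second graph is unaffected by the rows of the first.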