class NodeBatch(object):
    """The class to represent a batch of nodes.

    Parameters
    ----------
    graph : DGLGraph
        Graph object.
    nodes : Tensor
        Node ids.
    ntype : str, optional
        The node type of this node batch.
    data : dict[str, Tensor]
        Node feature data.
    msgs : dict[str, Tensor], optional
        Message data.
    """

    def __init__(self, graph, nodes, ntype, data, msgs=None):
        self._graph = graph
        self._nodes = nodes
        self._ntype = ntype
        self._data = data
        self._msgs = msgs

    @property
    def data(self):
        """Return a view of the node features for the nodes in the batch.

        Examples
        --------
        The following example uses the PyTorch backend.

        >>> import dgl
        >>> import torch

        >>> # Instantiate a graph with edges 0->1, 1->1 and 1->0,
        >>> # and set a feature 'h'.
        >>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
        >>> g.ndata['h'] = torch.ones(2, 1)

        >>> # Define a UDF that computes the sum of the messages received and
        >>> # the original feature for each node.
        >>> def node_udf(nodes):
        ...     # nodes.data['h'] is a tensor of shape (N, 1) and
        ...     # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
        ...     # where N is the number of nodes in the batch and D is the
        ...     # number of messages each of those nodes received (the same
        ...     # for every node in this batch).
        ...     return {'h': nodes.data['h'] + nodes.mailbox['m'].sum(1)}

        >>> # Use the node UDF in message passing.
        >>> import dgl.function as fn
        >>> g.update_all(fn.copy_u('h', 'm'), node_udf)
        >>> g.ndata['h']
        tensor([[2.],
                [3.]])
        """
        return self._data

    @property
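The `(N, D, 1)` mailbox shape in the docstring means every node in one batch received the same number of messages. The pure-Python toy below illustrates that idea without DGL; it is a sketch, not DGL's implementation. `ToyNodeBatch` and `toy_update_all` are hypothetical names, features are plain scalars instead of `(N, 1)` tensors, and destination nodes are assumed to be grouped by in-degree so each UDF call sees a rectangular mailbox. On the graph from the docstring example it yields the same values, 2.0 and 3.0.

```python
from collections import defaultdict


class ToyNodeBatch:
    """Hypothetical stand-in for NodeBatch; not DGL's class."""

    def __init__(self, nodes, data, mailbox):
        self.nodes = nodes      # node ids in this batch
        self.data = data        # {feature name: [value per node]}
        self.mailbox = mailbox  # {message name: [[msg_1, ..., msg_D] per node]}


def toy_update_all(src, dst, ndata, src_field, msg_name, node_udf):
    """Hypothetical helper mimicking copy_u(src_field, msg_name) plus a node UDF.

    Assumption: destination nodes are grouped by in-degree D, so every node
    in a batch has exactly D messages and node_udf runs once per group.
    """
    num_nodes = len(ndata[src_field])
    inbox = defaultdict(list)
    for u, v in zip(src, dst):
        inbox[v].append(ndata[src_field][u])  # copy_u: message = source feature

    buckets = defaultdict(list)
    for v, msgs in inbox.items():
        buckets[len(msgs)].append(v)          # bucket key = in-degree D

    result = {name: list(vals) for name, vals in ndata.items()}
    for degree in sorted(buckets):
        nodes = sorted(buckets[degree])
        batch = ToyNodeBatch(
            nodes,
            {name: [vals[v] for v in nodes] for name, vals in ndata.items()},
            {msg_name: [inbox[v] for v in nodes]},
        )
        # The UDF reads the pre-update features in ndata, not partial results.
        for name, vals in node_udf(batch).items():
            column = result.setdefault(name, [None] * num_nodes)
            for v, val in zip(nodes, vals):
                column[v] = val
    # Nodes with no incoming messages keep their old values here; this is a
    # simplification and does not model DGL's zero-in-degree handling.
    return result


def node_udf(nodes):
    # Original feature plus the sum of the D messages each node received.
    return {'h': [h + sum(m) for h, m in zip(nodes.data['h'], nodes.mailbox['m'])]}


# Same graph as the docstring example: edges 0->1, 1->1, 1->0.
out = toy_update_all([0, 1, 1], [1, 1, 0], {'h': [1.0, 1.0]}, 'h', 'm', node_udf)
print(out['h'])  # [2.0, 3.0]
```

Because each toy batch has a rectangular mailbox, the UDF can reduce over the message dimension without padding, which is what lets the docstring's UDF call `nodes.mailbox['m'].sum(1)` on a dense tensor.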