def extract_node_subframes_for_block(graph, srcnodes, dstnodes):
    """Extract the input node features and output node features of the given nodes from
    :attr:`graph` and return them in frames ready for a block.

    Note that this function does not perform an actual tensor memory copy; it uses
    `Frame.subframe` to get the features. Both :attr:`srcnodes` and :attr:`dstnodes`
    must be lists; passing None is not supported, because the function iterates over
    both arguments directly.

    Parameters
    ----------
    graph : DGLGraph
        The graph to extract features from.
    srcnodes : list[Tensor]
        Input node IDs. The list length must be equal to the number of node types
        in the graph. The returned frames store the node IDs in the ``dgl.NID`` field.
    dstnodes : list[Tensor]
        Output node IDs. The list length must be equal to the number of node types
        in the graph. The returned frames store the node IDs in the ``dgl.NID`` field.

    Returns
    -------
    list[Frame]
        Extracted node frames: one input-node frame per node type, followed by
        one output-node frame per node type, each group in node-type order.
    """
    node_frames = []
    # Input (source) node frames come first, one per node type.
    for i, ind_nodes in enumerate(srcnodes):
        # Slice the features of the selected nodes without copying tensor memory.
        subf = graph._node_frames[i].subframe(ind_nodes)
        # Record the original node IDs so they can be mapped back to ``graph``.
        subf[NID] = ind_nodes
        node_frames.append(subf)
    # Output (destination) node frames follow, in the same node-type order.
    for i, ind_nodes in enumerate(dstnodes):
        subf = graph._node_frames[i].subframe(ind_nodes)
        subf[NID] = ind_nodes
        node_frames.append(subf)
    return node_frames
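To make the layout of the returned list concrete, here is a minimal sketch that mirrors the two loops above. It is an illustration under stated assumptions, not DGL code: `ToyFrame` is a hypothetical stand-in for DGL's internal `Frame`, `toy_extract_for_block` is a hypothetical helper, and `NID` is a placeholder key standing in for `dgl.NID`. Unlike the real `Frame.subframe`, which the docstring says avoids a tensor memory copy, `ToyFrame.subframe` copies the selected rows for simplicity.

```python
# Hypothetical placeholder for the dgl.NID field name (assumption).
NID = "_ID"


class ToyFrame:
    """Hypothetical stand-in for DGL's Frame: feature name -> per-node list."""

    def __init__(self, columns):
        self.columns = dict(columns)

    def subframe(self, ids):
        # Select the given rows from every feature column.
        # NOTE: this copies data, unlike DGL's Frame.subframe.
        return ToyFrame({k: [v[i] for i in ids] for k, v in self.columns.items()})

    def __setitem__(self, key, value):
        self.columns[key] = list(value)

    def __getitem__(self, key):
        return self.columns[key]


def toy_extract_for_block(node_frames, srcnodes, dstnodes):
    """Hypothetical helper: input frames first, then output frames."""
    out = []
    for nodes_per_type in (srcnodes, dstnodes):
        # i indexes the node type, matching graph._node_frames[i] above.
        for i, ind_nodes in enumerate(nodes_per_type):
            subf = node_frames[i].subframe(ind_nodes)
            subf[NID] = ind_nodes  # remember the original node IDs
            out.append(subf)
    return out


# Two node types: type 0 has 3 nodes, type 1 has 2 nodes.
frames = [ToyFrame({"h": [10, 11, 12]}), ToyFrame({"h": [20, 21]})]
block_frames = toy_extract_for_block(frames, srcnodes=[[0, 2], [1]], dstnodes=[[2], [0]])

print(len(block_frames))                            # 4
print(block_frames[0]["h"], block_frames[0][NID])   # [10, 12] [0, 2]
print(block_frames[3]["h"], block_frames[3][NID])   # [20] [0]
```

With two node types the result holds four frames: indices 0 and 1 are the input frames for each type, and indices 2 and 3 are the output frames, so the list can be split at the number of node types.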


def extract_edge_subframes(graph, edges_or_device, store_ids=True):