Invoke user-defined reduce function on all the nodes in the graph.

It analyzes the graph, groups nodes by their degrees and applies the UDF on each group -- a strategy called *degree-bucketing*.

Parameters
----------
graph : DGLGraph
    The input graph.
func : callable
    The user-defined function.
(graph, func, msgdata, *, orig_nid=None)
| 97 | |
| 98 | |
| 99 | def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): |
| 100 | """Invoke user-defined reduce function on all the nodes in the graph. |
| 101 | |
| 102 | It analyzes the graph, groups nodes by their degrees and applies the UDF on each |
| 103 | group -- a strategy called *degree-bucketing*. |
| 104 | |
| 105 | Parameters |
| 106 | ---------- |
| 107 | graph : DGLGraph |
| 108 | The input graph. |
| 109 | func : callable |
| 110 | The user-defined function. |
| 111 | msgdata : dict[str, Tensor] |
| 112 | Message data. |
| 113 | orig_nid : Tensor, optional |
| 114 | Original node IDs. Useful if the input graph is an extracted subgraph. |
| 115 | |
| 116 | Returns |
| 117 | ------- |
| 118 | dict[str, Tensor] |
| 119 | Results from running the UDF. |
| 120 | """ |
| 121 | degs = graph.in_degrees() |
| 122 | nodes = graph.dstnodes() |
| 123 | if orig_nid is None: |
| 124 | orig_nid = nodes |
| 125 | ntype = graph.dsttypes[0] |
| 126 | ntid = graph.get_ntype_id_from_dst(ntype) |
| 127 | dstdata = graph._node_frames[ntid] |
| 128 | msgdata = Frame(msgdata) |
| 129 | |
| 130 | # degree bucketing |
| 131 | unique_degs, bucketor = _bucketing(degs) |
| 132 | bkt_rsts = [] |
| 133 | bkt_nodes = [] |
| 134 | for deg, node_bkt, orig_nid_bkt in zip( |
| 135 | unique_degs, bucketor(nodes), bucketor(orig_nid) |
| 136 | ): |
| 137 | if deg == 0: |
| 138 | # skip reduce function for zero-degree nodes |
| 139 | continue |
| 140 | bkt_nodes.append(node_bkt) |
| 141 | ndata_bkt = dstdata.subframe(node_bkt) |
| 142 | |
| 143 | # order the incoming edges per node by edge ID |
| 144 | eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form="eid")) |
| 145 | assert len(eid_bkt) == deg * len(node_bkt) |
| 146 | eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1) |
| 147 | eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten()) |
| 148 | |
| 149 | msgdata_bkt = msgdata.subframe(eid_bkt) |
| 150 | # reshape all msg tensors to (num_nodes_bkt, degree, feat_size) |
| 151 | maildata = {} |
| 152 | for k, msg in msgdata_bkt.items(): |
| 153 | newshape = (len(node_bkt), deg) + F.shape(msg)[1:] |
| 154 | maildata[k] = F.reshape(msg, newshape) |
| 155 | # invoke udf |
| 156 | nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata) |
no test coverage detected