Function invoke_udf_reduce

python/dgl/core.py:99–174 in github.com/dmlc/dgl

Invokes a user-defined reduce function on all the nodes in the graph. It analyzes the graph, groups nodes by their in-degrees, and applies the UDF to each group, a strategy called *degree-bucketing*.
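The grouping step of degree-bucketing can be illustrated without DGL. The sketch below is a minimal, hypothetical helper. `bucket_by_degree` is not part of DGL, and the real function delegates this step to the internal `_bucketing`, whose implementation is not shown here. The helper groups node IDs by in-degree so that each group can be processed as one dense batch.

```python
from collections import defaultdict


def bucket_by_degree(degrees):
    """Group node IDs by degree (hypothetical illustration, not DGL's _bucketing).

    Args:
        degrees: list where degrees[v] is the in-degree of node v.

    Returns:
        A list of (degree, node_ids) pairs sorted by degree. Each node_ids
        list holds every node with that degree, in ascending ID order.
    """
    buckets = defaultdict(list)
    for node, deg in enumerate(degrees):
        buckets[deg].append(node)
    return sorted(buckets.items())


# Example: in-degrees of five destination nodes.
print(bucket_by_degree([2, 0, 2, 1, 1]))
# [(0, [1]), (1, [3, 4]), (2, [0, 2])]
```

Every node in a bucket receives the same number of messages, so the bucket's messages can be stacked into one rectangular tensor. The zero-degree bucket is the one `invoke_udf_reduce` skips.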

(graph, func, msgdata, *, orig_nid=None)


```python
def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
    """Invoke user-defined reduce function on all the nodes in the graph.

    It analyzes the graph, groups nodes by their degrees and applies the UDF on each
    group -- a strategy called *degree-bucketing*.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : callable
        The user-defined function.
    msgdata : dict[str, Tensor]
        Message data.
    orig_nid : Tensor, optional
        Original node IDs. Useful if the input graph is an extracted subgraph.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF.
    """
    degs = graph.in_degrees()
    nodes = graph.dstnodes()
    if orig_nid is None:
        orig_nid = nodes
    ntype = graph.dsttypes[0]
    ntid = graph.get_ntype_id_from_dst(ntype)
    dstdata = graph._node_frames[ntid]
    msgdata = Frame(msgdata)

    # degree bucketing
    unique_degs, bucketor = _bucketing(degs)
    bkt_rsts = []
    bkt_nodes = []
    for deg, node_bkt, orig_nid_bkt in zip(
        unique_degs, bucketor(nodes), bucketor(orig_nid)
    ):
        if deg == 0:
            # skip reduce function for zero-degree nodes
            continue
        bkt_nodes.append(node_bkt)
        ndata_bkt = dstdata.subframe(node_bkt)

        # order the incoming edges per node by edge ID
        eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form="eid"))
        assert len(eid_bkt) == deg * len(node_bkt)
        eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1)
        eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten())

        msgdata_bkt = msgdata.subframe(eid_bkt)
        # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
        maildata = {}
        for k, msg in msgdata_bkt.items():
            newshape = (len(node_bkt), deg) + F.shape(msg)[1:]
            maildata[k] = F.reshape(msg, newshape)
        # invoke udf
        nbatch = NodeBatch(graph, orig_nid_bkt, ntype, ndata_bkt, msgs=maildata)
        # ... (remainder of the function, lines 157–174, not shown)
```
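The per-bucket loop above can be mirrored in plain Python. This helps show what the mailbox looks like and why skipping zero-degree nodes is safe. This is a simplified sketch under stated assumptions, not DGL code. `degree_bucketed_reduce` and `sum_reduce` are hypothetical names. Edges are `(src, dst)` pairs whose list index serves as the edge ID. Messages are plain lists of floats. The reduce function receives the mailbox directly instead of a `NodeBatch`. Node types, `orig_nid`, and the backend tensor layer are ignored.

```python
from collections import defaultdict


def degree_bucketed_reduce(num_nodes, edges, messages, reduce_fn):
    """Run a reduce UDF with degree bucketing (hypothetical pure-Python sketch).

    Assumptions: edges[eid] = (src, dst); messages[eid] is the feature
    vector carried by edge eid; reduce_fn maps a mailbox shaped
    (num_nodes_in_bucket, degree, feat_size) to one result per node.
    Returns a dict from node ID to result; zero-degree nodes are absent.
    """
    # Collect incoming edge IDs for each destination node.
    in_eids = defaultdict(list)
    for eid, (_, dst) in enumerate(edges):
        in_eids[dst].append(eid)

    # Group destination nodes by in-degree.
    buckets = defaultdict(list)
    for v in range(num_nodes):
        buckets[len(in_eids[v])].append(v)

    results = {}
    for deg, nodes in sorted(buckets.items()):
        if deg == 0:
            continue  # no incoming messages, so the UDF is not called
        # Order each node's incoming edges by edge ID, then gather messages
        # into a rectangular (len(nodes), deg, feat_size) mailbox.
        mailbox = [[messages[e] for e in sorted(in_eids[v])] for v in nodes]
        for v, out in zip(nodes, reduce_fn(mailbox)):
            results[v] = out
    return results


def sum_reduce(mailbox):
    """Sum messages over the degree axis for every node in the bucket."""
    return [[sum(col) for col in zip(*msgs)] for msgs in mailbox]


edges = [(0, 2), (1, 2), (2, 0), (0, 1), (1, 0)]
messages = [[1.0], [2.0], [3.0], [4.0], [5.0]]
print(sorted(degree_bucketed_reduce(4, edges, messages, sum_reduce).items()))
# [(0, [8.0]), (1, [4.0]), (2, [3.0])]
```

Node 3 has no incoming edges, so it lands in the zero-degree bucket and has no entry in the result. Comparing the output with a naive per-edge scatter-add is a quick way to check that bucketing changes only how messages are batched, not what the reduction computes.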

Callers (1)

message_passing · Function · 0.85

Calls (15)

subframe · Method · 0.95
update_row · Method · 0.95
Frame · Class · 0.85
_bucketing · Function · 0.85
bucketor · Function · 0.85
NodeBatch · Class · 0.85
dstnodes · Method · 0.80
get_ntype_id_from_dst · Method · 0.80
append · Method · 0.80
flatten · Method · 0.80
func · Function · 0.50
in_degrees · Method · 0.45

Tested by

no test coverage detected