Repository: github.com/dmlc/dgl

Function message_passing

python/dgl/core.py, lines 372–425

Invoke message passing computation on the whole graph: a message phase (mfunc), a reduce phase (rfunc), and an optional apply phase (afunc).

(g, mfunc, rfunc, afunc)
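To make the three phases concrete, here is a minimal sketch that assumes a toy graph stored as a Python edge list with one scalar feature per node. The name toy_message_passing and its data layout are hypothetical and unrelated to DGL's internals. Builtin kernels, batching, and tensors are all omitted.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical toy graph: (src, dst) edge pairs and one float per node.
# This is NOT DGL's graph representation.
Edge = Tuple[int, int]


def toy_message_passing(
    edges: List[Edge],
    node_feat: Dict[int, float],
    mfunc: Callable[[float, float], float],
    rfunc: Callable[[List[float]], float],
    afunc: Optional[Callable[[float, float], float]] = None,
) -> Dict[int, float]:
    # Message phase: compute one message per edge from its endpoints' features
    # and deliver it to the destination node's mailbox.
    mailbox: Dict[int, List[float]] = defaultdict(list)
    for src, dst in edges:
        mailbox[dst].append(mfunc(node_feat[src], node_feat[dst]))

    # Reduce phase: aggregate each destination node's incoming messages.
    # Nodes without in-edges never get a mailbox and so get no entry here.
    out = {node: rfunc(msgs) for node, msgs in mailbox.items()}

    # Apply phase (optional): combine the reduced value with the original feature.
    if afunc is not None:
        out = {node: afunc(h, node_feat[node]) for node, h in out.items()}
    return out


edges = [(0, 2), (1, 2), (2, 0)]
feat = {0: 1.0, 1: 2.0, 2: 3.0}
summed = toy_message_passing(edges, feat, lambda h_src, h_dst: h_src, sum)
updated = toy_message_passing(
    edges, feat, lambda h_src, h_dst: h_src, sum, lambda h, old: h + old
)
```

In this example mfunc copies the source feature, rfunc sums the mailbox, and afunc adds the reduced value to the node's original feature. Node 1 has no in-edges, so the toy produces no entry for it.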


```python
def message_passing(g, mfunc, rfunc, afunc):
    """Invoke message passing computation on the whole graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or dgl.function.BuiltinFunction
        Apply function.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """
    if (
        is_builtin(mfunc)
        and is_builtin(rfunc)
        and getattr(ops, "{}_{}".format(mfunc.name, rfunc.name), None)
        is not None
    ):
        # invoke fused message passing
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # invoke message passing in two separate steps
        # message phase
        if is_builtin(mfunc):
            msgdata = invoke_gsddmm(g, mfunc)
        else:
            orig_eid = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(
                g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid
            )
        # reduce phase
        if is_builtin(rfunc):
            msg = rfunc.msg_field
            ndata = invoke_gspmm(g, fn.copy_e(msg, msg), rfunc, edata=msgdata)
        else:
            orig_nid = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=orig_nid)
    # apply phase
    if afunc is not None:
        for k, v in g.dstdata.items():  # include original node features
            if k not in ndata:
                ndata[k] = v
        orig_nid = g.dstdata.get(NID, None)
        ndata = invoke_node_udf(
            g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=orig_nid
        )
    return ndata
```
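The branch at the top of the function chooses between a single fused kernel and separate message and reduce steps. The sketch below mirrors that decision using hypothetical stand-ins: Builtin replaces dgl.function.BuiltinFunction, and the FUSED_KERNELS dictionary plays the role of the getattr(ops, ...) lookup. merge_for_apply imitates how the apply phase back-fills original destination features that the reduce output did not produce. None of these names are DGL APIs, and the registry entries are illustrative only.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Union


@dataclass(frozen=True)
class Builtin:
    """Hypothetical stand-in for a builtin message/reduce function; it only carries a name."""
    name: str


# Hypothetical registry of fused kernels keyed "<message>_<reduce>",
# standing in for the getattr(ops, "{}_{}".format(...)) lookup in the source.
FUSED_KERNELS: Dict[str, str] = {
    "copy_u_sum": "fused copy_u + sum",
    "u_mul_e_sum": "fused u_mul_e + sum",
}

Func = Union[Builtin, Callable[..., Any]]


def is_builtin(func: Func) -> bool:
    return isinstance(func, Builtin)


def choose_path(mfunc: Func, rfunc: Func) -> str:
    """Return which execution path the dispatch logic would take."""
    if (
        is_builtin(mfunc)
        and is_builtin(rfunc)
        and FUSED_KERNELS.get(f"{mfunc.name}_{rfunc.name}") is not None
    ):
        return "fused"
    # No fused kernel (or a UDF is involved): run message and reduce separately.
    msg_step = "builtin message" if is_builtin(mfunc) else "edge UDF"
    red_step = "builtin reduce" if is_builtin(rfunc) else "UDF reduce"
    return f"two-step: {msg_step} -> {red_step}"


def merge_for_apply(ndata: Dict[str, Any], dstdata: Dict[str, Any]) -> Dict[str, Any]:
    """Add original destination features whose keys the reduce output lacks.

    Unlike the source, which updates ndata in place, this returns a new dict.
    """
    merged = dict(ndata)
    for key, value in dstdata.items():
        if key not in merged:
            merged[key] = value
    return merged
```

The design point is that reduced outputs win on key collisions: an original feature is only carried into the apply phase when the reduce step did not produce a field with the same name.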

Callers

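None detected by this index. Within DGL, the function is reached as core.message_passing from graph methods such as DGLGraph.update_all.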

Calls (9)

is_builtin (function) · 0.85
invoke_gspmm (function) · 0.85
invoke_gsddmm (function) · 0.85
invoke_edge_udf (function) · 0.85
invoke_udf_reduce (function) · 0.85
invoke_node_udf (function) · 0.85
format (method) · 0.80
get (method) · 0.45
items (method) · 0.45

Tested by

no test coverage detected