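Invoke message passing computation on the whole graph. Parameters ---------- g : DGLGraph The input graph. mfunc : callable or dgl.function.BuiltinFunction Message function. rfunc : callable or dgl.function.BuiltinFunction Reduce function. afunc : callable or dgl.function.BuiltinFunction Apply function. Returns ------- dict[str, Tensor] Results from the message passing computation.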
(g, mfunc, rfunc, afunc)
| 370 | |
| 371 | |
def message_passing(g, mfunc, rfunc, afunc):
    """Run message, reduce and (optionally) apply over the whole graph.

    A fused path is used when both ``mfunc`` and ``rfunc`` are builtins and
    ``ops`` exposes an attribute named ``"<mfunc.name>_<rfunc.name>"``: a
    single ``invoke_gspmm`` call produces the node data. Otherwise the
    message phase and the reduce phase run separately, each through its
    builtin or user-defined-function path.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    mfunc : callable or dgl.function.BuiltinFunction
        Message function.
    rfunc : callable or dgl.function.BuiltinFunction
        Reduce function.
    afunc : callable or None
        Apply function run on the destination nodes after reduction. It is
        passed to ``invoke_node_udf``; pass ``None`` to skip the apply phase.

    Returns
    -------
    dict[str, Tensor]
        Results from the message passing computation.
    """
    # Evaluated in the same short-circuit order as a combined condition:
    # the kernel lookup only happens when both functions are builtins.
    use_fused_kernel = (
        is_builtin(mfunc)
        and is_builtin(rfunc)
        and getattr(ops, f"{mfunc.name}_{rfunc.name}", None) is not None
    )

    if use_fused_kernel:
        # One fused kernel computes the messages and reduces them.
        ndata = invoke_gspmm(g, mfunc, rfunc)
    else:
        # Message phase: builtin -> gsddmm kernel, otherwise an edge UDF.
        if not is_builtin(mfunc):
            edge_ids = g.edata.get(EID, None)
            msgdata = invoke_edge_udf(
                g, ALL, g.canonical_etypes[0], mfunc, orig_eid=edge_ids
            )
        else:
            msgdata = invoke_gsddmm(g, mfunc)

        # Reduce phase: a builtin reducer consumes the precomputed messages
        # via a copy_e message over its own message field.
        if not is_builtin(rfunc):
            dst_ids = g.dstdata.get(NID, None)
            ndata = invoke_udf_reduce(g, rfunc, msgdata, orig_nid=dst_ids)
        else:
            field = rfunc.msg_field
            ndata = invoke_gspmm(
                g, fn.copy_e(field, field), rfunc, edata=msgdata
            )

    if afunc is None:
        return ndata

    # Apply phase: expose the original destination-node features to the
    # apply function without overriding anything the reduction produced.
    for name, feat in g.dstdata.items():
        if name not in ndata:
            ndata[name] = feat
    dst_ids = g.dstdata.get(NID, None)
    return invoke_node_udf(
        g, ALL, g.dsttypes[0], afunc, ndata=ndata, orig_nid=dst_ids
    )
nothing calls this directly
no test coverage detected