Invoke user-defined edge function on the given edges. Parameters: graph (DGLGraph) – the input graph; eid (Tensor) – the IDs of the edges to invoke the UDF on; etype ((str, str, str)) – the edge type; func (callable) – the user-defined function; orig_eid (Tensor, optional) – original edge IDs, useful if the input graph is an extracted subgraph.
(graph, eid, etype, func, *, orig_eid=None)
| 50 | |
| 51 | |
def invoke_edge_udf(graph, eid, etype, func, *, orig_eid=None):
    """Invoke user-defined edge function on the given edges.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    eid : Tensor or ALL
        The IDs of the edges to invoke UDF on. When ``is_all(eid)`` holds,
        every edge of type ``etype`` is used and ``eid`` is replaced by the
        edge IDs returned from ``graph.edges``.
    etype : (str, str, str)
        Edge type. The edge frame and the source/destination node frames
        are all selected through this type.
    func : callable
        The user-defined function. It is called with a single ``EdgeBatch``.
    orig_eid : Tensor, optional
        Original edge IDs. Useful if the input graph is an extracted subgraph.
        When given, the ``EdgeBatch`` carries these IDs instead of ``eid``.

    Returns
    -------
    dict[str, Tensor]
        Results from running the UDF (whatever ``func`` returns).
    """
    etid = graph.get_etype_id(etype)
    # Node-type IDs of this relation's endpoints; used to pick node frames.
    stid, dtid = graph._graph.metagraph.find_edge(etid)
    # Pass ``etype`` explicitly so the endpoints come from the same relation
    # whose frames (indexed by ``etid``) are used below. Without it,
    # ``edges``/``find_edges`` fall back to the graph's default edge type,
    # which is only well-defined when the graph has a single edge type.
    if is_all(eid):
        u, v, eid = graph.edges(form="all", etype=etype)
        edata = graph._edge_frames[etid]
    else:
        u, v = graph.find_edges(eid, etype=etype)
        edata = graph._edge_frames[etid].subframe(eid)
    if len(u) == 0:
        # Not fatal: the UDF is still invoked on an empty batch.
        dgl_warning(
            "The input graph for the user-defined edge function "
            "does not contain valid edges"
        )
    srcdata = graph._node_frames[stid].subframe(u)
    dstdata = graph._node_frames[dtid].subframe(v)
    ebatch = EdgeBatch(
        graph,
        eid if orig_eid is None else orig_eid,
        etype,
        srcdata,
        edata,
        dstdata,
    )
    return func(ebatch)
| 97 | |
| 98 | |
| 99 | def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): |
no test coverage detected