Invoke g-SDDMM computation on the graph. Parameters: `graph` (DGLGraph) — the input graph; `func` (dgl.function.BaseMessageFunction) — a built-in message function. Returns: `dict[str, Tensor]` — results from the g-SDDMM computation.
(graph, func)
| 271 | |
| 272 | |
def invoke_gsddmm(graph, func):
    """Run a built-in message function on ``graph`` as a g-SDDMM kernel.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : dgl.function.BaseMessageFunction
        Built-in message function. A ``fn.BinaryMessageFunction`` reads two
        operands (``lhs``/``rhs``). Any other function is handled as a
        single-operand copy message (``copy_u`` or ``copy_e``).

    Returns
    -------
    dict[str, Tensor]
        A one-entry dict mapping ``func.out_field`` to the kernel output.
    """
    # Indexed by the function's target code: source, destination, edge data.
    frames = (graph.srcdata, graph.dstdata, graph.edata)

    if not isinstance(func, fn.BinaryMessageFunction):
        # Unary copy message: one operand taken from the target frame.
        feat = frames[func.target][func.in_field]
        kernel = getattr(ops, func.name)
        if graph._graph.number_of_etypes() > 1:
            # Convert to list as dict is unordered. Any name other than
            # "copy_u" is treated as "copy_e".
            target = "u" if func.name == "copy_u" else "e"
            feat = data_dict_to_list(graph, feat, func, target)
        return {func.out_field: kernel(graph, feat)}

    # Binary message: two operands, possibly from different frames.
    lhs_feat = frames[func.lhs][func.lhs_field]
    rhs_feat = frames[func.rhs][func.rhs_field]
    kernel = getattr(ops, func.name)
    if graph._graph.number_of_etypes() > 1:
        # The name is split as "<lhs>_<op>_<rhs>" (e.g. "u_add_v"); the
        # outer pieces select which side each operand comes from.
        lhs_code, _, rhs_code = func.name.split("_", 2)
        lhs_feat = data_dict_to_list(graph, lhs_feat, func, lhs_code)
        rhs_feat = data_dict_to_list(graph, rhs_feat, func, rhs_code)
    return {func.out_field: kernel(graph, lhs_feat, rhs_feat)}
| 309 | |
| 310 | |
| 311 | def invoke_gspmm( |
no test coverage detected