Invoke g-SPMM computation on the graph. Parameters ---------- graph : DGLGraph The input graph. mfunc : dgl.function.BaseMessageFunction Built-in message function. rfunc : dgl.function.BaseReduceFunction Built-in reduce function. srcdata : dict[str,
(
graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None
)
| 309 | |
| 310 | |
| 311 | def invoke_gspmm( |
| 312 | graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None |
| 313 | ): |
| 314 | """Invoke g-SPMM computation on the graph. |
| 315 | |
| 316 | Parameters |
| 317 | ---------- |
| 318 | graph : DGLGraph |
| 319 | The input graph. |
| 320 | mfunc : dgl.function.BaseMessageFunction |
| 321 | Built-in message function. |
| 322 | rfunc : dgl.function.BaseReduceFunction |
| 323 | Built-in reduce function. |
| 324 | srcdata : dict[str, Tensor], optional |
| 325 | Source node feature data. If not provided, it use ``graph.srcdata``. |
| 326 | dstdata : dict[str, Tensor], optional |
| 327 | Destination node feature data. If not provided, it use ``graph.dstdata``. |
| 328 | edata : dict[str, Tensor], optional |
| 329 | Edge feature data. If not provided, it use ``graph.edata``. |
| 330 | |
| 331 | Returns |
| 332 | ------- |
| 333 | dict[str, Tensor] |
| 334 | Results from the g-SPMM computation. |
| 335 | """ |
| 336 | # sanity check |
| 337 | if mfunc.out_field != rfunc.msg_field: |
| 338 | raise DGLError( |
| 339 | "Invalid message ({}) and reduce ({}) function pairs." |
| 340 | " The output field of the message function must be equal to the" |
| 341 | " message field of the reduce function.".format(mfunc, rfunc) |
| 342 | ) |
| 343 | if edata is None: |
| 344 | edata = graph.edata |
| 345 | if srcdata is None: |
| 346 | srcdata = graph.srcdata |
| 347 | if dstdata is None: |
| 348 | dstdata = graph.dstdata |
| 349 | alldata = [srcdata, dstdata, edata] |
| 350 | |
| 351 | if isinstance(mfunc, fn.BinaryMessageFunction): |
| 352 | x = alldata[mfunc.lhs][mfunc.lhs_field] |
| 353 | y = alldata[mfunc.rhs][mfunc.rhs_field] |
| 354 | op = getattr(ops, "{}_{}".format(mfunc.name, rfunc.name)) |
| 355 | if graph._graph.number_of_etypes() > 1: |
| 356 | lhs_target, _, rhs_target = mfunc.name.split("_", 2) |
| 357 | x = data_dict_to_list(graph, x, mfunc, lhs_target) |
| 358 | y = data_dict_to_list(graph, y, mfunc, rhs_target) |
| 359 | z = op(graph, x, y) |
| 360 | else: |
| 361 | x = alldata[mfunc.target][mfunc.in_field] |
| 362 | op = getattr(ops, "{}_{}".format(mfunc.name, rfunc.name)) |
| 363 | if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple): |
| 364 | if mfunc.name == "copy_u": |
| 365 | x = data_dict_to_list(graph, x, mfunc, "u") |
| 366 | else: # "copy_e" |
| 367 | x = data_dict_to_list(graph, x, mfunc, "e") |
| 368 | z = op(graph, x) |
no test coverage detected