github.com/dmlc/dgl / invoke_gspmm

Function invoke_gspmm

python/dgl/core.py:311–369

Invoke g-SPMM (generalized sparse-dense matrix multiplication) computation on the graph, running a built-in message function and a built-in reduce function as a single operation.

(
    graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None
)
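Conceptually, g-SPMM with a `u_mul_e` message and a `sum` reduction gives each destination node the sum, over its incoming edges, of the source-node feature times the edge feature. The sketch below spells this out in plain Python. It assumes a homogeneous graph stored as parallel `src`/`dst` edge lists with scalar features; the function name `u_mul_e_sum_reference` and that representation are illustrative assumptions, not DGL's kernel.

```python
from typing import List


def u_mul_e_sum_reference(
    src: List[int],
    dst: List[int],
    x_src: List[float],
    x_edge: List[float],
    num_dst: int,
) -> List[float]:
    """Hypothetical reference for a u_mul_e message followed by a sum reduction.

    Edge i runs from node src[i] to node dst[i] and carries feature x_edge[i].
    In this sketch, nodes with no incoming edges get 0.0 (an empty sum).
    """
    out = [0.0] * num_dst
    for u, v, w in zip(src, dst, x_edge):
        # Message on edge (u, v): source feature times edge feature.
        msg = x_src[u] * w
        # Reduce: accumulate every message that arrives at v.
        out[v] += msg
    return out


# Edges 0->2, 1->2 and 0->1.
print(u_mul_e_sum_reference([0, 1, 0], [2, 2, 1], [1.0, 2.0, 3.0], [0.5, 1.0, 2.0], 3))
# prints: [0.0, 2.0, 2.5]
```

With scalar features this equals multiplying the edge-weighted adjacency matrix (rows are destinations, columns are sources) by the source feature vector, which is where the "SpMM" in the name comes from.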


```python
def invoke_gspmm(
    graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None
):
    """Invoke g-SPMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    mfunc : dgl.function.BaseMessageFunction
        Built-in message function.
    rfunc : dgl.function.BaseReduceFunction
        Built-in reduce function.
    srcdata : dict[str, Tensor], optional
        Source node feature data. If not provided, it uses ``graph.srcdata``.
    dstdata : dict[str, Tensor], optional
        Destination node feature data. If not provided, it uses ``graph.dstdata``.
    edata : dict[str, Tensor], optional
        Edge feature data. If not provided, it uses ``graph.edata``.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SPMM computation.
    """
    # sanity check: the message must be emitted under the name the reducer reads
    if mfunc.out_field != rfunc.msg_field:
        raise DGLError(
            "Invalid message ({}) and reduce ({}) function pairs."
            " The output field of the message function must be equal to the"
            " message field of the reduce function.".format(mfunc, rfunc)
        )
    # Fall back to the graph's own feature storage.
    if edata is None:
        edata = graph.edata
    if srcdata is None:
        srcdata = graph.srcdata
    if dstdata is None:
        dstdata = graph.dstdata
    # Indexed by the message function's target: 0 = source, 1 = destination, 2 = edge.
    alldata = [srcdata, dstdata, edata]

    if isinstance(mfunc, fn.BinaryMessageFunction):
        # Binary message (e.g. u_mul_e): read both operands.
        x = alldata[mfunc.lhs][mfunc.lhs_field]
        y = alldata[mfunc.rhs][mfunc.rhs_field]
        # Op named "<message>_<reduce>", e.g. ops.u_mul_e_sum.
        op = getattr(ops, "{}_{}".format(mfunc.name, rfunc.name))
        if graph._graph.number_of_etypes() > 1:
            # Multiple edge types: "u_mul_e" -> lhs target "u", rhs target "e".
            lhs_target, _, rhs_target = mfunc.name.split("_", 2)
            x = data_dict_to_list(graph, x, mfunc, lhs_target)
            y = data_dict_to_list(graph, y, mfunc, rhs_target)
        z = op(graph, x, y)
    else:
        # Unary copy message: copy_u or copy_e.
        x = alldata[mfunc.target][mfunc.in_field]
        op = getattr(ops, "{}_{}".format(mfunc.name, rfunc.name))
        if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple):
            if mfunc.name == "copy_u":
                x = data_dict_to_list(graph, x, mfunc, "u")
            else:  # "copy_e"
                x = data_dict_to_list(graph, x, mfunc, "e")
        z = op(graph, x)
```
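The dispatch logic in `invoke_gspmm` can be modeled without DGL. The sketch below uses hypothetical stand-in dataclasses `MessageFunc` and `ReduceFunc` (not DGL's classes) that carry only the attributes the function reads, plus a stand-in `DispatchError` in place of `DGLError`. It mirrors three steps: the field-consistency check, building the op name from the message and reduce names, and splitting a binary message name into operand targets mapped to positions in `[srcdata, dstdata, edata]`.

```python
from dataclasses import dataclass
from typing import Dict


class DispatchError(Exception):
    """Stand-in for DGLError in this sketch."""


@dataclass
class MessageFunc:
    # Hypothetical stand-in for a built-in binary message such as u_mul_e.
    name: str        # e.g. "u_mul_e"
    lhs_field: str   # feature name read from the lhs target
    rhs_field: str   # feature name read from the rhs target
    out_field: str   # name under which the message is emitted


@dataclass
class ReduceFunc:
    # Hypothetical stand-in for a built-in reducer such as sum.
    name: str        # e.g. "sum"
    msg_field: str   # message name the reducer consumes
    out_field: str   # name of the reduced node feature


# Position of each target in alldata = [srcdata, dstdata, edata].
TARGET_INDEX = {"u": 0, "v": 1, "e": 2}


def plan_gspmm(mfunc: MessageFunc, rfunc: ReduceFunc) -> Dict[str, object]:
    """Return the op name and operand slots a g-SPMM call would use."""
    # 1. The message output must be exactly what the reducer consumes.
    if mfunc.out_field != rfunc.msg_field:
        raise DispatchError(
            f"Invalid message ({mfunc.name}) and reduce ({rfunc.name}) pair"
        )
    # 2. Op name, e.g. "u_mul_e" + "sum" -> "u_mul_e_sum".
    op_name = "{}_{}".format(mfunc.name, rfunc.name)
    # 3. "u_mul_e".split("_", 2) -> ["u", "mul", "e"].
    lhs_target, binary_op, rhs_target = mfunc.name.split("_", 2)
    return {
        "op": op_name,
        "binary_op": binary_op,
        "lhs": (TARGET_INDEX[lhs_target], mfunc.lhs_field),
        "rhs": (TARGET_INDEX[rhs_target], mfunc.rhs_field),
    }


plan = plan_gspmm(
    MessageFunc("u_mul_e", "h", "w", "m"), ReduceFunc("sum", "m", "h_new")
)
print(plan["op"], plan["lhs"], plan["rhs"])
# prints: u_mul_e_sum (0, 'h') (2, 'w')
```

Passing a reducer whose `msg_field` differs from the message's `out_field` (for example `"other"` instead of `"m"`) raises `DispatchError` before any op name is built, matching the order of checks in the original function.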

Callers 1

message_passing · Function · 0.85

Calls 4

DGLError · Class · 0.85
data_dict_to_list · Function · 0.85
format · Method · 0.80
number_of_etypes · Method · 0.80
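When the graph has more than one edge type, `invoke_gspmm` passes the looked-up features through `data_dict_to_list` before calling the op; on such graphs a feature lookup typically yields a dictionary keyed by node or edge type. The sketch below shows one plausible shape of that conversion, assuming features are keyed by type and each type has an integer ID giving its slot in the output list. `dict_to_type_list`, the ID maps, and the handling of targets `"u"` and `"e"` only are illustrative assumptions, not DGL's implementation of `data_dict_to_list`.

```python
from typing import Dict, List, Optional, Tuple

# Canonical edge type: (source node type, relation name, destination node type).
CanonicalEType = Tuple[str, str, str]


def dict_to_type_list(
    canonical_etypes: List[CanonicalEType],
    ntype_ids: Dict[str, int],
    etype_ids: Dict[CanonicalEType, int],
    data: Dict,
    target: str,
) -> List[Optional[object]]:
    """Hypothetical conversion of per-type features into a list ordered by type ID.

    target "u": features of each source node type, indexed by node-type ID.
    target "e": features of each canonical edge type, indexed by edge-type ID.
    Slots for types that never appear in that role stay None.
    """
    out: List[Optional[object]]
    if target == "u":
        out = [None] * len(ntype_ids)
        for srctype, _, _ in canonical_etypes:
            out[ntype_ids[srctype]] = data[srctype]
    elif target == "e":
        out = [None] * len(etype_ids)
        for rel in canonical_etypes:
            out[etype_ids[rel]] = data[rel]
    else:
        raise ValueError("this sketch only handles targets 'u' and 'e'")
    return out


# Strings stand in for feature tensors.
etypes = [("user", "follows", "user"), ("user", "clicks", "item")]
ntype_ids = {"item": 0, "user": 1}
etype_ids = {etypes[0]: 0, etypes[1]: 1}
print(dict_to_type_list(etypes, ntype_ids, etype_ids, {"user": "h_user", "item": "h_item"}, "u"))
# prints: [None, 'h_user']  ("item" is never a source type)
print(dict_to_type_list(etypes, ntype_ids, etype_ids, {etypes[0]: "w_follows", etypes[1]: "w_clicks"}, "e"))
# prints: ['w_follows', 'w_clicks']
```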

Tested by

no test coverage detected