Repository: github.com/dmlc/dgl

Function invoke_gsddmm

python/dgl/core.py:273–308

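Invoke g-SDDMM (generalized sampled dense-dense matrix multiplication) computation on the graph. The function reads the operand features named by the built-in message function `func` from the graph's source-node, destination-node, or edge data, applies the operator from `ops` that has the same name as `func` (for example `u_add_v` or `copy_u`), and returns the per-edge result keyed by `func.out_field`. On graphs with more than one edge type, the per-type feature dicts are first converted to lists with `data_dict_to_list`.

Parameters:
graph (DGLGraph): The input graph.
func (dgl.function.BaseMessageFunction): Built-in message function.

Returns:
dict[str, Tensor]: Results from the g-SDDMM computation.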

(graph, func)
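The per-edge computation behind g-SDDMM can be illustrated with a naive loop over an edge list. The sketch below is not DGL's implementation (DGL dispatches to fused kernels over tensors); it assumes scalar features, a homogeneous graph given as parallel `src`/`dst` lists, and hypothetical names `gsddmm_reference`, `SRC`, `DST` and `EDGE`, with the target codes following the order of `alldata` in `invoke_gsddmm`.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical target codes, following the order of `alldata` in
# invoke_gsddmm: [graph.srcdata, graph.dstdata, graph.edata].
SRC, DST, EDGE = 0, 1, 2


def gsddmm_reference(
    src: Sequence[int],
    dst: Sequence[int],
    lhs: Sequence[float],
    lhs_target: int,
    rhs: Optional[Sequence[float]] = None,
    rhs_target: int = EDGE,
    op: Optional[Callable[[float, float], float]] = None,
) -> List[float]:
    """Naive g-SDDMM on an edge list: one output value per edge.

    With `op` set, this mimics a binary built-in such as u_add_v;
    with `op=None`, it mimics copy_u / copy_e and returns the lhs operand.
    """

    def pick(feat: Sequence[float], target: int, e: int) -> float:
        # Select edge e's operand from source-node, destination-node,
        # or edge features.
        if target == SRC:
            return feat[src[e]]
        if target == DST:
            return feat[dst[e]]
        return feat[e]  # EDGE: features are already stored per edge

    out = []
    for e in range(len(src)):
        x = pick(lhs, lhs_target, e)
        out.append(x if op is None else op(x, pick(rhs, rhs_target, e)))
    return out


# Toy graph with edges 0->1, 1->2, 2->0.
src, dst = [0, 1, 2], [1, 2, 0]
h = [1.0, 2.0, 3.0]     # node feature
w = [10.0, 20.0, 30.0]  # edge feature
print(gsddmm_reference(src, dst, h, SRC, h, DST, lambda a, b: a + b))  # [3.0, 5.0, 4.0]
print(gsddmm_reference(src, dst, h, SRC))                              # [1.0, 2.0, 3.0]
print(gsddmm_reference(src, dst, h, SRC, w, EDGE, lambda a, b: a * b)) # [10.0, 40.0, 90.0]
```

The three calls correspond in spirit to `u_add_v`, `copy_u` and `u_mul_e`: every result has one entry per edge, which is why `invoke_gsddmm` returns edge data keyed by `func.out_field`.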


```python
def invoke_gsddmm(graph, func):
    """Invoke g-SDDMM computation on the graph.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    func : dgl.function.BaseMessageFunction
        Built-in message function.

    Returns
    -------
    dict[str, Tensor]
        Results from the g-SDDMM computation.
    """
    # Feature stores indexed by the function's target:
    # source nodes, destination nodes, edges.
    alldata = [graph.srcdata, graph.dstdata, graph.edata]
    if isinstance(func, fn.BinaryMessageFunction):
        x = alldata[func.lhs][func.lhs_field]
        y = alldata[func.rhs][func.rhs_field]
        op = getattr(ops, func.name)
        if graph._graph.number_of_etypes() > 1:
            # e.g. "u_add_v" -> lhs target "u", rhs target "v".
            lhs_target, _, rhs_target = func.name.split("_", 2)
            x = data_dict_to_list(graph, x, func, lhs_target)
            y = data_dict_to_list(graph, y, func, rhs_target)
        z = op(graph, x, y)
    else:
        x = alldata[func.target][func.in_field]
        op = getattr(ops, func.name)
        if graph._graph.number_of_etypes() > 1:
            # Convert to list as dict is unordered.
            if func.name == "copy_u":
                x = data_dict_to_list(graph, x, func, "u")
            else:  # "copy_e"
                x = data_dict_to_list(graph, x, func, "e")
        z = op(graph, x)
    return {func.out_field: z}
```
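On graphs with more than one edge type, the binary branch derives each operand's target from the built-in's name (`"u_add_v"` splits into `"u"`, `"add"`, `"v"`), and per-type feature dicts are turned into lists so the operator sees features in a fixed type order rather than in whatever order the dict happens to hold. The sketch below illustrates both steps in plain Python. `split_binary_name` and `per_type_dict_to_list` are hypothetical names; the latter is only a stand-in for DGL's `data_dict_to_list`, assuming the order comes from a list of type names and that types without the feature map to `None`. The real helper's behavior may differ.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def split_binary_name(name: str) -> Tuple[str, str, str]:
    """Split a built-in name such as "u_add_v" into (lhs, op, rhs).

    maxsplit=2 mirrors invoke_gsddmm: at most three pieces come back.
    """
    lhs_target, op_name, rhs_target = name.split("_", 2)
    return lhs_target, op_name, rhs_target


def per_type_dict_to_list(
    data: Dict[str, List[float]], ordered_types: Sequence[str]
) -> List[Optional[List[float]]]:
    """Hypothetical helper: order per-type features by a fixed type list.

    The dict's own order need not match the graph's type order, so each
    type's features go to that type's index; missing types get None.
    """
    return [data.get(t) for t in ordered_types]


print(split_binary_name("u_dot_e"))  # ('u', 'dot', 'e')
print(per_type_dict_to_list({"item": [2.0], "user": [1.0]}, ["user", "item", "tag"]))
# [[1.0], [2.0], None]
```

Keying the list by a fixed type order, instead of by dict iteration order, is what lets a single operator call line up each type's features with the matching part of the graph.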

Callers (1): message_passing (function)

Calls (2): data_dict_to_list (function), number_of_etypes (method)

Tested by: no test coverage detected