Repository: github.com/dmlc/dgl

Function csrsum

Defined in python/dgl/backend/tensorflow/sparse.py, lines 428–444
Signature: csrsum(gidxs, weights)


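```python
import tensorflow as tf

# Assumes the surrounding module scope: `csrsum_real` is defined elsewhere in
# this file and `create_unitgraph_from_csr` is imported by it.


def csrsum(gidxs, weights):
    # Wrap the native sum so TensorFlow uses a hand-written backward pass;
    # tf.custom_gradient requires csrsum_real to return (outputs, grad_fn).
    @tf.custom_gradient
    def _lambda(*weights):
        return csrsum_real(gidxs, weights)

    # Shape and CSR arrays of the summed matrix C, plus its per-edge weights.
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(*weights)
    # The result takes its node-type count from the first input graph.
    num_vtypes = gidxs[0].number_of_ntypes()
    # .numpy() needs eager execution; it turns the shape tensors into values.
    gidxC = create_unitgraph_from_csr(
        num_vtypes,
        nrows.numpy(),
        ncols.numpy(),
        C_indptr,
        C_indices,
        C_eids,
        ["coo", "csr", "csc"],  # sparse formats the new graph index may use
    )
    return gidxC, C_weights
```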
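The page gives no description of the operation. A minimal pure-Python sketch follows, assuming (from the name and from the single graph index plus one weight tensor that `csrsum` returns) that it adds the weighted adjacency matrices of same-shaped input graphs into one matrix C; it is presumably the operation behind `dgl.adj_sum_graph`. The helpers `csr_sum` and `csr_sum_backward` and the `(indptr, indices, data)` tuple layout are hypothetical and are not DGL's API. The backward half illustrates why a `tf.custom_gradient` wrapper fits: each output weight is a plain sum, so every input nonzero receives the upstream gradient of the output entry it was added into.

```python
# Hypothetical sketch of a CSR sum C = A_1 + ... + A_n (not DGL's kernel).
from typing import Dict, List, Tuple

CSR = Tuple[List[int], List[int], List[float]]  # (indptr, indices, data)


def csr_sum(mats: List[CSR], nrows: int) -> Tuple[CSR, List[List[int]]]:
    """Sum same-shaped CSR matrices.

    Also returns, for each input, the index in the output `data` that each of
    its nonzeros was accumulated into (what a backward pass needs).
    """
    indptr_c = [0]
    indices_c: List[int] = []
    data_c: List[float] = []
    positions: List[List[int]] = [[0] * len(m[1]) for m in mats]
    for row in range(nrows):
        slot: Dict[int, int] = {}  # column -> index into data_c for this row
        for i, (indptr, indices, data) in enumerate(mats):
            for e in range(indptr[row], indptr[row + 1]):
                col = indices[e]
                if col not in slot:  # first time this (row, col) appears
                    slot[col] = len(data_c)
                    indices_c.append(col)
                    data_c.append(0.0)
                data_c[slot[col]] += data[e]
                positions[i][e] = slot[col]
        indptr_c.append(len(data_c))
    return (indptr_c, indices_c, data_c), positions


def csr_sum_backward(grad_c: List[float], positions: List[List[int]]) -> List[List[float]]:
    """d(loss)/d(A_i.data[e]) = d(loss)/d(C.data[positions[i][e]])."""
    return [[grad_c[p] for p in pos] for pos in positions]


if __name__ == "__main__":
    A = ([0, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0])  # 2x3: (0,0)=1, (0,2)=2, (1,1)=3
    B = ([0, 1, 2], [2, 0], [4.0, 5.0])          # 2x3: (0,2)=4, (1,0)=5
    (indptr, indices, data), pos = csr_sum([A, B], nrows=2)
    print(indptr, indices, data)  # [0, 2, 4] [0, 2, 1, 0] [1.0, 6.0, 3.0, 5.0]
    print(csr_sum_backward([10.0, 20.0, 30.0, 40.0], pos))
```

In this sketch, columns within a row come out in first-seen order rather than sorted; a kernel that sorts them would also have to remap the position lists.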

Callers

nothing calls this directly

Calls 3

_lambda · function · 0.85
number_of_ntypes · method · 0.80

Tested by

no test coverage detected