(gidxs, weights)
| 426 | |
| 427 | |
def csrsum(gidxs, weights):
    """Run ``csrsum_real`` under ``tf.custom_gradient`` and wrap the result in a unit graph.

    Parameters
    ----------
    gidxs
        Indexable collection of graph index objects. It is forwarded
        unchanged to ``csrsum_real``; its first element is also queried
        with ``number_of_ntypes()``.
    weights
        Iterable of weight tensors. They are unpacked as the positional
        arguments of the custom-gradient function, so they are the inputs
        that TensorFlow tracks for differentiation.

    Returns
    -------
    tuple
        ``(gidxC, C_weights)``: the graph index returned by
        ``create_unitgraph_from_csr`` and the sixth value produced by the
        custom-gradient call.

    Notes
    -----
    ``csrsum_real`` presumably follows the ``tf.custom_gradient`` contract
    (returns the outputs together with a gradient function) -- confirm
    against its definition.
    ``.numpy()`` is called on the first two outputs, which presumably
    requires eager execution -- TODO confirm.
    """

    @tf.custom_gradient
    def _csrsum_with_grad(*operand_weights):
        # Only the weights are differentiable inputs; the graph indices are
        # captured from the enclosing scope.
        return csrsum_real(gidxs, operand_weights)

    outputs = _csrsum_with_grad(*weights)
    n_rows, n_cols, indptr, indices, eids, out_weights = outputs

    # Arguments are evaluated left to right, matching the original order:
    # node-type count first, then the two ``.numpy()`` conversions.
    out_gidx = create_unitgraph_from_csr(
        gidxs[0].number_of_ntypes(),
        n_rows.numpy(),
        n_cols.numpy(),
        indptr,
        indices,
        eids,
        ["coo", "csr", "csc"],
    )
    return out_gidx, out_weights
| 445 | |
| 446 | |
| 447 | def csrmask_real(gidxA, A_weights, gidxB): |
nothing calls this directly
no test coverage detected