Invoke the C API for the backward pass of segment_mm with respect to B.
(A, dC, dB, seglen)
| 447 | |
| 448 | |
def _segment_mm_backward_B(A, dC, dB, seglen):
    """Run the segment_mm backward kernel for the B operand via the C API.

    Converts every argument to a DGL NDArray and forwards them to
    ``_CAPI_DGLKernelSEGMENTMMBackwardB``. ``dB`` is converted with
    ``to_dgl_nd_for_write``, which indicates the kernel writes its result
    into it. The same ``dB`` object that was passed in is returned.

    Parameters
    ----------
    A : tensor
        Presumably the left operand of the forward segment_mm; confirm
        against callers.
    dC : tensor
        Presumably the gradient of the forward output; confirm against
        callers.
    dB : tensor
        Preallocated buffer handed to the kernel as a writable array.
    seglen : tensor
        Presumably the per-segment lengths used by segment_mm; confirm
        against callers.

    Returns
    -------
    tensor
        ``dB`` itself, after the kernel call.
    """
    # Conversions happen in the same order as the argument list of the C call.
    a_nd = to_dgl_nd(A)
    dc_nd = to_dgl_nd(dC)
    # dB is the output buffer, so it needs the writable conversion.
    db_nd = to_dgl_nd_for_write(dB)
    seglen_nd = to_dgl_nd(seglen)
    _CAPI_DGLKernelSEGMENTMMBackwardB(a_nd, dc_nd, db_nd, seglen_nd)
    return dB
| 455 | |
| 456 | |
| 457 | def _gather_mm(A, B, out, idx_a=None, idx_b=None): |
no test coverage detected