Invoke the C API of segment_mm.
(A, B, out, seglen_A, b_trans=False)
| 434 | |
| 435 | |
def _segment_mm(A, B, out, seglen_A, b_trans=False):
    """Invoke the C API of segment_mm.

    ``A``, ``B`` and ``seglen_A`` are converted with ``to_dgl_nd``; ``out``
    is converted with ``to_dgl_nd_for_write`` and is the same object that
    is handed back to the caller.

    Parameters
    ----------
    A, B
        Input operands, forwarded to the kernel after conversion.
    out
        Destination passed to the kernel via ``to_dgl_nd_for_write``.
        NOTE(review): presumably the kernel fills it in place -- confirm
        against the C implementation.
    seglen_A
        Forwarded to the kernel after conversion.
        NOTE(review): presumably the per-segment row counts of ``A`` --
        confirm against the caller.
    b_trans : bool, optional
        Forwarded unchanged as the last kernel argument. Defaults to False.

    Returns
    -------
    The ``out`` argument itself.
    """
    # Convert in the same left-to-right order as the kernel's positional
    # arguments.
    lhs_nd = to_dgl_nd(A)
    rhs_nd = to_dgl_nd(B)
    dst_nd = to_dgl_nd_for_write(out)
    seglen_nd = to_dgl_nd(seglen_A)

    # The fifth argument is always False here.
    # NOTE(review): presumably the "transpose A" flag paired with b_trans --
    # confirm against the C API signature.
    _CAPI_DGLKernelSEGMENTMM(lhs_nd, rhs_nd, dst_nd, seglen_nd, False, b_trans)
    return out
| 447 | |
| 448 | |
| 449 | def _segment_mm_backward_B(A, dC, dB, seglen): |
no test coverage detected