r"""Backward phase of segment reduction (for 'min'/'max' reduction). It computes the gradient of input feature given output gradient of the segment reduction result. Parameters ---------- feat : Tensor The output gradient arg : Tensor The ArgMin/Max tensor p
(feat, arg, m)
| 770 | |
| 771 | |
def _bwd_segment_cmp(feat, arg, m):
    r"""Backward pass of a 'min'/'max' segment reduction.

    Given the gradient flowing into the segment-reduction output, produce
    the gradient with respect to the reduction's input features.

    Parameters
    ----------
    feat : Tensor
        The output gradient.
    arg : Tensor
        The ArgMin/Max tensor produced by segment_reduce op.
    m : int
        The length of input gradients' first dimension.

    Returns
    -------
    Tensor
        The input gradient. Its shape is ``(m,)`` followed by every
        dimension of ``feat`` after the first; its dtype and context are
        taken from ``feat``.
    """
    # Only the leading dimension differs between output and input gradients.
    trailing_dims = F.shape(feat)[1:]
    grad_in = F.zeros((m,) + trailing_dims, F.dtype(feat), F.context(feat))

    # `grad_in` is handed over as the writable argument; the kernel itself is
    # not visible here -- presumably it routes each entry of `feat` to the
    # input row recorded in `arg` and leaves the remaining rows at zero.
    _CAPI_DGLKernelBwdSegmentCmp(
        to_dgl_nd(feat),
        to_dgl_nd(arg),
        to_dgl_nd_for_write(grad_in),
    )
    return grad_in
| 800 | |
| 801 | |
| 802 | def _csrmm(A, A_weights, B, B_weights, num_vtypes): |
no test coverage detected