r"""Segment reduction operator. It aggregates the value tensor along the first dimension by segments. The argument ``offsets`` specifies the start offset of each segment (and the upper bound of the last segment). Zero-length segments are allowed. .. math:: y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j
(op, feat, offsets)
| 639 | |
| 640 | |
| 641 | def _segment_reduce(op, feat, offsets): |
| 642 | r"""Segment reduction operator. |
| 643 | |
| 644 | It aggregates the value tensor along the first dimension by segments. |
| 645 | The argument ``offsets`` specifies the start offset of each segment (and |
| 646 | the upper bound of the last segment). Zero-length segments are allowed. |
| 647 | |
| 648 | .. math:: |
| 649 | y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j |
| 650 | |
| 651 | where :math:`\Phi` is the reduce operator. |
| 652 | |
| 653 | Parameters |
| 654 | ---------- |
| 655 | op : str |
| 656 | Aggregation method. Can be ``sum``, ``max``, ``min``. |
| 657 | x : Tensor |
| 658 | Value to aggregate. |
| 659 | offsets : Tensor |
| 660 | The start offsets of segments. |
| 661 | |
| 662 | Returns |
| 663 | ------- |
| 664 | tuple(Tensor) |
| 665 | The first tensor correspond to aggregated tensor of shape |
| 666 | ``(len(seglen), value.shape[1:])``, and the second tensor records |
| 667 | the argmin/max at each position for computing gradients. |
| 668 | |
| 669 | Notes |
| 670 | ----- |
| 671 | This function does not handle gradients. |
| 672 | """ |
| 673 | n = F.shape(offsets)[0] - 1 |
| 674 | out_shp = (n,) + F.shape(feat)[1:] |
| 675 | ctx = F.context(feat) |
| 676 | dtype = F.dtype(feat) |
| 677 | idtype = F.dtype(offsets) |
| 678 | out = F.zeros(out_shp, dtype, ctx) |
| 679 | arg = None |
| 680 | if op in ["min", "max"]: |
| 681 | arg = F.zeros(out_shp, idtype, ctx) |
| 682 | arg_nd = to_dgl_nd_for_write(arg) |
| 683 | _CAPI_DGLKernelSegmentReduce( |
| 684 | op, |
| 685 | to_dgl_nd(feat), |
| 686 | to_dgl_nd(offsets), |
| 687 | to_dgl_nd_for_write(out), |
| 688 | arg_nd, |
| 689 | ) |
| 690 | arg = None if arg is None else F.zerocopy_from_dgl_ndarray(arg_nd) |
| 691 | return out, arg |
| 692 | |
| 693 | |
| 694 | def _scatter_add(x, idx, m): |
no test coverage detected