github.com/dmlc/dgl / _segment_reduce

Function _segment_reduce

python/dgl/_sparse_ops.py:641–691

(op, feat, offsets)

```python
def _segment_reduce(op, feat, offsets):
    r"""Segment reduction operator.

    It aggregates the value tensor ``feat`` along the first dimension by
    segments. The argument ``offsets`` specifies the start offset of each
    segment (and the upper bound of the last segment). Zero-length segments
    are allowed.

    .. math::
      y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j

    where :math:`\Phi` is the reduce operator and :math:`x_j` is the
    ``j``-th row of ``feat``.

    Parameters
    ----------
    op : str
        Aggregation method. Can be ``sum``, ``max``, ``min``.
    feat : Tensor
        Value to aggregate.
    offsets : Tensor
        The start offset of each segment, followed by the end of the last
        segment, so ``n`` segments need ``n + 1`` offsets.

    Returns
    -------
    tuple(Tensor)
        The first tensor is the aggregated tensor of shape
        ``(len(offsets) - 1,) + feat.shape[1:]``, and the second tensor
        records the argmin/argmax at each position for computing gradients
        (``None`` when ``op`` is ``sum``).

    Notes
    -----
    This function does not handle gradients.
    """
    n = F.shape(offsets)[0] - 1
    out_shp = (n,) + F.shape(feat)[1:]
    ctx = F.context(feat)
    dtype = F.dtype(feat)
    idtype = F.dtype(offsets)
    out = F.zeros(out_shp, dtype, ctx)
    arg = None
    if op in ["min", "max"]:
        # Only min/max need an index tensor for the backward pass.
        arg = F.zeros(out_shp, idtype, ctx)
    # Must run for every op: the kernel call below always takes arg_nd.
    arg_nd = to_dgl_nd_for_write(arg)
    _CAPI_DGLKernelSegmentReduce(
        op,
        to_dgl_nd(feat),
        to_dgl_nd(offsets),
        to_dgl_nd_for_write(out),
        arg_nd,
    )
    arg = None if arg is None else F.zerocopy_from_dgl_ndarray(arg_nd)
    return out, arg
```
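
The formula in the docstring can be checked against a small pure-Python reference. This is a sketch under stated assumptions, not DGL code: `segment_reduce_ref` is a hypothetical name, nested lists stand in for tensors, and empty segments are assumed to yield `0` with an arg of `-1` for `max`/`min` (the real kernel's convention for empty segments may differ). Ties keep the first index because the comparison is strict.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def segment_reduce_ref(
    op: str, feat: Matrix, offsets: List[int]
) -> Tuple[Matrix, Optional[List[List[int]]]]:
    """Hypothetical reference: reduce rows offsets[i] .. offsets[i+1]-1 of feat.

    Not a DGL API; nested lists stand in for tensors.
    """
    if op not in ("sum", "max", "min"):
        raise ValueError(f"unsupported op: {op!r}")
    n = len(offsets) - 1  # number of segments
    dim = len(feat[0]) if feat else 0  # feature width (all rows assumed equal)
    out = [[0.0] * dim for _ in range(n)]  # zero-filled, like F.zeros above
    # Only max/min record an index per output element; -1 = "no row" (assumption).
    arg = None if op == "sum" else [[-1] * dim for _ in range(n)]
    for i in range(n):
        start, end = offsets[i], offsets[i + 1]
        if start == end:
            continue  # zero-length segment: output stays 0 in this sketch
        for k in range(dim):
            if op == "sum":
                out[i][k] = sum(feat[j][k] for j in range(start, end))
                continue
            best = start
            for j in range(start + 1, end):
                # Strict comparison keeps the first index on ties.
                if op == "max" and feat[j][k] > feat[best][k]:
                    best = j
                elif op == "min" and feat[j][k] < feat[best][k]:
                    best = j
            out[i][k] = feat[best][k]
            arg[i][k] = best
    return out, arg


# Segments: rows 0-1, an empty segment, rows 2-3.
feat = [[1.0, 5.0], [3.0, 2.0], [4.0, 0.0], [7.0, 7.0]]
offsets = [0, 2, 2, 4]
out, arg = segment_reduce_ref("max", feat, offsets)
# out == [[3.0, 5.0], [0.0, 0.0], [7.0, 7.0]]
# arg == [[1, 0], [-1, -1], [3, 3]]
```

As in `_segment_reduce`, the output has `len(offsets) - 1` rows and the index output exists only for `max` and `min`.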

Callers 3

forward (method) · 0.85
segment_reduce_real (function) · 0.85
forward (method) · 0.85

Calls 5

to_dgl_nd_for_write (function) · 0.85
context (method) · 0.80
to_dgl_nd (function) · 0.70
shape (method) · 0.45
dtype (method) · 0.45

Tested by

no test coverage detected
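
The Notes say `_segment_reduce` does not handle gradients; its second return value is what makes a backward pass possible for `max` and `min`. A minimal sketch of that backward rule, assuming a hypothetical helper `segment_reduce_backward_ref` (not part of DGL) over nested lists, and assuming empty segments are marked with `-1` in `arg`: for `sum`, every row in segment `i` receives `grad_out[i]`; for `max`/`min`, only the row recorded in `arg` does. DGL's own `forward`/backward code may be organized differently.

```python
from typing import List, Optional

Matrix = List[List[float]]


def segment_reduce_backward_ref(
    op: str,
    grad_out: Matrix,
    offsets: List[int],
    arg: Optional[List[List[int]]],
    num_rows: int,
) -> Matrix:
    """Hypothetical helper (not a DGL API): gradient of a segment reduction w.r.t. feat."""
    if op != "sum" and arg is None:
        raise ValueError("max/min backward needs the arg output of the forward pass")
    dim = len(grad_out[0]) if grad_out else 0
    grad_feat = [[0.0] * dim for _ in range(num_rows)]
    for i in range(len(offsets) - 1):
        for k in range(dim):
            if op == "sum":
                # dy_i/dx_j = 1 for every row j in segment i.
                for j in range(offsets[i], offsets[i + 1]):
                    grad_feat[j][k] += grad_out[i][k]
            else:
                # max/min: only the row recorded in arg receives the gradient.
                j = arg[i][k]
                if j >= 0:  # -1 marks an empty segment (assumed convention)
                    grad_feat[j][k] += grad_out[i][k]
    return grad_feat


# arg for a max reduction of feat = [[1, 5], [3, 2], [4, 0], [7, 7]] over these offsets.
offsets = [0, 2, 2, 4]
arg = [[1, 0], [-1, -1], [3, 3]]
grad_out = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
grad_feat = segment_reduce_backward_ref("max", grad_out, offsets, arg, num_rows=4)
# grad_feat == [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0], [1.0, 1.0]]
```

Routing the gradient through `arg` avoids recomputing the comparison in the backward pass, which is why the forward returns it alongside the reduced tensor.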