Segment reduction operator. It aggregates the value tensor along the first dimension by segments. The first argument ``seglen`` stores the length of each segment. Its summation must be equal to the first dimension of the ``value`` tensor. Zero-length segments are allowed.
(seglen, value, reducer="sum")
| 7 | |
| 8 | |
def segment_reduce(seglen, value, reducer="sum"):
    """Segment reduction operator.

    It aggregates the value tensor along the first dimension by segments.
    The first argument ``seglen`` stores the length of each segment. Its
    summation must be equal to the first dimension of the ``value`` tensor.
    Zero-length segments are allowed.

    Parameters
    ----------
    seglen : Tensor
        Segment lengths.
    value : Tensor
        Value to aggregate.
    reducer : str, optional
        Aggregation method. Can be 'sum', 'max', 'min', 'mean'.
        Default: 'sum'.

    Returns
    -------
    Tensor
        Aggregated tensor of shape ``(len(seglen), *value.shape[1:])``.

    Raises
    ------
    DGLError
        If ``reducer`` is not one of the supported names.

    Examples
    --------

    >>> import dgl
    >>> import torch as th
    >>> val = th.ones(10, 3)
    >>> seg = th.tensor([1, 0, 5, 4])  # 4 segments
    >>> dgl.segment_reduce(seg, val)
    tensor([[1., 1., 1.],
            [0., 0., 0.],
            [5., 5., 5.],
            [4., 4., 4.]])
    """
    # Exclusive prefix sum of the lengths: offsets[i] is the first row of
    # segment i, and offsets[-1] equals the first dimension of ``value``.
    offsets = F.cumsum(
        F.cat([F.zeros((1,), F.dtype(seglen), F.context(seglen)), seglen], 0), 0
    )
    if reducer == "mean":
        rst = F.segment_reduce("sum", value, offsets)
        rst_shape = F.shape(rst)
        # Divide by at least 1 so zero-length segments yield 0 rather than NaN.
        # The upper bound must never fall below the lower bound: when ``value``
        # is empty, len(value) == 0, and e.g. torch.clamp sets every element to
        # ``max`` when min > max, which would reintroduce a 0/0 division.
        z = F.astype(F.clamp(seglen, 1, max(len(value), 1)), F.dtype(rst))
        # Reshape the counts to (num_segments, 1, ..., 1) so they broadcast
        # across the trailing feature dimensions of ``rst``.
        z_shape = (rst_shape[0],) + (1,) * (len(rst_shape) - 1)
        return rst / F.reshape(z, z_shape)
    if reducer in ("min", "sum", "max"):
        rst = F.segment_reduce(reducer, value, offsets)
        if reducer in ("min", "max"):
            # NOTE(review): presumably the backend kernel leaves zero-length
            # segments at +/-inf; they are mapped to 0 here to match 'sum'.
            rst = F.replace_inf_with_zero(rst)
        return rst
    raise DGLError("reducer {} not recognized.".format(reducer))
| 60 | |
| 61 | |
| 62 | def segment_softmax(seglen, value): |