github.com/dmlc/dgl

Function segment_reduce

python/dgl/ops/segment.py:9–59

Segment reduction operator. It aggregates the ``value`` tensor along its first dimension by segments. The first argument, ``seglen``, stores the length of each segment; these lengths must sum to the first dimension of ``value``. Zero-length segments are allowed and produce zeros in the output for every reducer.
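To make these semantics concrete, here is a minimal pure-Python sketch of what the operator computes, assuming a 2-D ``value`` given as a list of rows. ``segment_reduce_ref`` is a hypothetical helper written for illustration, not part of DGL, and it raises ``ValueError`` where DGL raises ``DGLError``.

```python
from itertools import accumulate
from typing import List, Sequence

Row = List[float]


def segment_reduce_ref(seglen: Sequence[int], value: Sequence[Row],
                       reducer: str = "sum") -> List[Row]:
    """Hypothetical reference for segment_reduce semantics on a 2-D list of rows."""
    if reducer not in ("sum", "mean", "max", "min"):
        raise ValueError(f"reducer {reducer} not recognized.")
    if sum(seglen) != len(value):
        raise ValueError("sum(seglen) must equal len(value)")
    width = len(value[0]) if value else 0
    # Segment i covers rows offsets[i]:offsets[i + 1].
    offsets = [0, *accumulate(seglen)]
    out: List[Row] = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        rows = value[start:end]
        if not rows:
            # Zero-length segment: an all-zero row, whatever the reducer.
            out.append([0.0] * width)
            continue
        columns = list(zip(*rows))  # one tuple per feature column
        if reducer == "sum":
            out.append([float(sum(c)) for c in columns])
        elif reducer == "mean":
            out.append([sum(c) / len(rows) for c in columns])
        elif reducer == "max":
            out.append([float(max(c)) for c in columns])
        else:  # "min"
            out.append([float(min(c)) for c in columns])
    return out


if __name__ == "__main__":
    ones = [[1.0, 1.0, 1.0] for _ in range(10)]
    print(segment_reduce_ref([1, 0, 5, 4], ones))
    # [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [5.0, 5.0, 5.0], [4.0, 4.0, 4.0]]
```

The printed result has the same values as the docstring example in the source, where segment 1 is empty and therefore yields a row of zeros.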

(seglen, value, reducer="sum")


# Module-level imports from the top of segment.py, shown so the excerpt is self-contained.
from .. import backend as F
from ..base import DGLError


def segment_reduce(seglen, value, reducer="sum"):
    """Segment reduction operator.

    It aggregates the ``value`` tensor along its first dimension by segments.
    The first argument, ``seglen``, stores the length of each segment. The
    lengths must sum to the first dimension of the ``value`` tensor.
    Zero-length segments are allowed and produce zeros in the output.

    Parameters
    ----------
    seglen : Tensor
        Segment lengths.
    value : Tensor
        Values to aggregate; ``value.shape[0]`` must equal ``sum(seglen)``.
    reducer : str, optional
        Aggregation method: one of 'sum', 'max', 'min' or 'mean'.
        Default: 'sum'.

    Returns
    -------
    Tensor
        Aggregated tensor of shape ``(len(seglen), *value.shape[1:])``.

    Examples
    --------

    >>> import dgl
    >>> import torch as th
    >>> val = th.ones(10, 3)
    >>> seg = th.tensor([1, 0, 5, 4])  # 4 segments
    >>> dgl.segment_reduce(seg, val)
    tensor([[1., 1., 1.],
            [0., 0., 0.],
            [5., 5., 5.],
            [4., 4., 4.]])
    """
    # offsets = cumsum([0, *seglen]); segment i spans value[offsets[i]:offsets[i + 1]].
    offsets = F.cumsum(
        F.cat([F.zeros((1,), F.dtype(seglen), F.context(seglen)), seglen], 0), 0
    )
    if reducer == "mean":
        # Mean = segment sum / segment length. Clamping lengths to at least 1
        # makes zero-length segments yield 0 instead of 0 / 0.
        rst = F.segment_reduce("sum", value, offsets)
        rst_shape = F.shape(rst)
        z = F.astype(F.clamp(seglen, 1, len(value)), F.dtype(rst))
        # Reshape lengths to (num_segments, 1, ..., 1) so they broadcast over features.
        z_shape = (rst_shape[0],) + (1,) * (len(rst_shape) - 1)
        return rst / F.reshape(z, z_shape)
    elif reducer in ["min", "sum", "max"]:
        rst = F.segment_reduce(reducer, value, offsets)
        if reducer in ["min", "max"]:
            # Zero-length segments come back as +/-inf for min/max; map them to 0.
            rst = F.replace_inf_with_zero(rst)
        return rst
    else:
        raise DGLError("reducer {} not recognized.".format(reducer))
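The implementation does not loop over segments itself: it turns ``seglen`` into an offsets array with a cumulative sum, delegates to a backend segment kernel, derives 'mean' from 'sum', and post-processes 'min'/'max'. The sketch below mirrors that pattern in pure Python on a 1-D list of scalars. The helper names ``offsets_from_seglen`` and ``segment_reduce_1d`` are hypothetical, and the ±inf starting values for empty min/max segments are an assumption inferred from the call to ``F.replace_inf_with_zero``, not read from DGL's kernels.

```python
import math
from itertools import accumulate
from typing import List, Sequence


def offsets_from_seglen(seglen: Sequence[int]) -> List[int]:
    """Mirror of cumsum(cat([0], seglen)): segment i spans offsets[i]:offsets[i + 1]."""
    return list(accumulate([0, *seglen]))


def segment_reduce_1d(seglen: Sequence[int], value: Sequence[float],
                      reducer: str = "sum") -> List[float]:
    """Hypothetical 1-D mirror of the offsets + kernel + post-processing pattern."""
    if reducer == "mean":
        sums = segment_reduce_1d(seglen, value, "sum")
        # Dividing by max(n, 1) turns empty segments into 0.0 instead of 0 / 0.
        # (DGL also clamps from above by len(value); that bound never lowers a
        # positive length, since each length is at most sum(seglen) == len(value).)
        return [s / max(n, 1) for s, n in zip(sums, seglen)]
    # Assumed identities: an empty "max" segment stays at -inf, an empty "min" at +inf.
    identity = {"sum": 0.0, "max": -math.inf, "min": math.inf}
    combine = {"sum": lambda a, b: a + b, "max": max, "min": min}
    if reducer not in identity:
        raise ValueError(f"reducer {reducer} not recognized.")
    offsets = offsets_from_seglen(seglen)
    out = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        acc = identity[reducer]
        for x in value[start:end]:
            acc = combine[reducer](acc, x)
        out.append(acc)
    if reducer in ("min", "max"):
        # Counterpart of F.replace_inf_with_zero: leftover identities become 0.0.
        out = [0.0 if math.isinf(x) else x for x in out]
    return out
```

For example, ``offsets_from_seglen([1, 0, 5, 4])`` gives ``[0, 1, 1, 6, 10]``, where the repeated ``1`` marks the empty segment. Because the final filter cannot tell a leftover identity from real data, this sketch also zeroes an infinity that genuinely appears in the input; any implementation that post-filters the result this way shares that behaviour.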

Callers (2)

test_segment_reduce (function) · 0.90
segment_softmax (function) · 0.70

Calls (6)

DGLError (class) · 0.85
context (method) · 0.80
format (method) · 0.80
dtype (method) · 0.45
shape (method) · 0.45
astype (method) · 0.45

Tested by (1)

test_segment_reduce (function) · 0.72