# Split the comma-separated ``args.aggr`` string into a list of aggregation
# names (``args`` is presumably an argparse namespace — defined elsewhere).
aggrs = args.aggr.split(sep=',')
| 176 | |
def pytorch_scatter(x, index, dim_size, reduce):
    """Reduce rows of ``x`` into ``dim_size`` output rows via ``scatter_reduce_``.

    Row ``i`` of ``x`` is combined into output row ``index[i]``. Output rows
    that receive no input stay zero.

    Args:
        x: Source tensor scattered along dim 0; its last dim is kept as the
            feature size (assumes a 2-D ``(N, F)`` tensor — TODO confirm).
        index: Target output row for each row of ``x`` (flattened with
            ``view(-1, 1)`` and broadcast across the feature dim).
        dim_size: Number of output rows.
        reduce: ``'sum'``, ``'mean'``, ``'min'``, ``'max'`` or ``'mul'``;
            any other name is passed through to ``scatter_reduce_`` unchanged.

    Returns:
        A new tensor of shape ``(dim_size, x.size(-1))``.
    """
    if reduce in ('min', 'max'):
        # Map to PyTorch's reduction names `amin` / `amax`. This must use the
        # local `reduce` argument, not an outer loop variable.
        reduce = f'a{reduce}'
    elif reduce == 'mul':
        reduce = 'prod'
    out = x.new_zeros(dim_size, x.size(-1))
    # Only `sum` may include the zero-initialised `out` values. For `mean`
    # they would be counted as extra elements (dividing by count + 1). For
    # `amin` / `amax` / `prod` they would clamp or zero the result.
    include_self = reduce == 'sum'
    index = index.view(-1, 1).expand(-1, x.size(-1))
    out.scatter_reduce_(0, index, x, reduce, include_self=include_self)
    return out
| 187 | |
| 188 | def pytorch_index_add(x, index, dim_size, reduce): |
| 189 | if reduce != 'sum': |