# Implementation adapted directly from numpy:
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/core/numeric.py#L1107-L1204
def rollaxis(a, axis, start=0):
    """Roll ``axis`` backwards until it lies in position ``start``.

    Mirrors ``numpy.rollaxis``. The moved axis is removed and re-inserted,
    so all other axes keep their relative order.

    Parameters
    ----------
    a : array-like
        Object exposing ``ndim``, ``transpose(axes)`` and ``a[...]``.
    axis : int
        Axis to move. It is normalized with ``normalize_axis_index``, so
        negative values count from the end.
    start : int, optional
        Target position. When ``start <= axis`` the axis is rolled back
        until it lies at this position. When ``start > axis`` it ends up
        just before this position. Negative values count from the end.
        Must satisfy ``-a.ndim <= start <= a.ndim``. Defaults to 0.

    Returns
    -------
    ``a[...]`` if the axis is already in place, otherwise
    ``a.transpose(axes)`` with the permuted axis order.

    Raises
    ------
    ValueError
        If ``start`` is out of range. The message reports the value the
        caller passed. An invalid ``axis`` raises whatever
        ``normalize_axis_index`` raises.
    """
    n = a.ndim
    axis = normalize_axis_index(axis, n)
    # Remember the caller's value so the error message shows what was
    # actually passed in, not the end-relative adjustment applied below.
    passed_start = start
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise ValueError(msg % ("start", -n, "start", n + 1, passed_start))
    if axis < start:
        # `axis` is removed from the list before re-inserting, which shifts
        # every later position down by one.
        start -= 1
    if axis == start:
        # Already in place: like numpy, return ``a[...]`` rather than ``a``.
        return a[...]
    axes = list(range(n))
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)