Do what `numpy.array_split` does, but also yield the (scaled) indices covered by each split.
`array_split_idx(ary, indices_or_sections, axis=0, n_per_split=1)`
| 52 | |
| 53 | |
def array_split_idx(ary, indices_or_sections, axis=0, n_per_split=1):
    """Do what numpy.array_split does, but add indices.

    Parameters
    ----------
    ary : ndarray
        Array to split along ``axis`` with :func:`numpy.array_split`.
    indices_or_sections : int
        Number of (roughly equal) sections to split ``ary`` into. Only an
        integer is supported; it is validated with ``_ensure_int``.
    axis : int
        Axis along which ``ary`` is split.
    n_per_split : int
        Scale factor for the indices: a section covering elements
        ``start:stop`` along ``axis`` is paired with
        ``np.arange(start * n_per_split, stop * n_per_split)``.

    Returns
    -------
    splits : zip
        Lazy iterator of ``(idx, sub_ary)`` pairs, one per section, where
        ``sub_ary`` is the corresponding output of :func:`numpy.array_split`.
        Empty sections (requested when ``indices_or_sections`` exceeds
        ``ary.shape[axis]``) are paired with an empty ``idx`` array.
    """
    # this only works for indices_or_sections as int
    indices_or_sections = _ensure_int(indices_or_sections)
    ary_split = np.array_split(ary, indices_or_sections, axis=axis)
    # Derive each section's [start, stop) bounds from the section lengths.
    # Indexing sp[0] / sp[-1] of a split arange (the previous approach)
    # raised IndexError for the empty sections np.array_split produces when
    # there are more sections than elements; cumulative sums handle them.
    lengths = np.array([sub.shape[axis] for sub in ary_split])
    stops = np.cumsum(lengths)
    starts = stops - lengths
    idx_split = (
        np.arange(start * n_per_split, stop * n_per_split)
        for start, stop in zip(starts, stops)
    )
    return zip(idx_split, ary_split)
| 64 | |
| 65 | |
| 66 | def sum_squared(X): |