(a, repeats, axis=None)
| 859 | |
@derived_from(np)
def repeat(a, repeats, axis=None):
    # Repeat each element of ``a`` ``repeats`` times along ``axis``, following
    # numpy.repeat semantics but restricted to a scalar integer ``repeats``.
    #
    # Parameters
    #   a       : chunked array exposing ``.ndim``, ``.chunks``, indexing and
    #             ``.map_blocks``.
    #   repeats : non-negative integer repeat count.
    #   axis    : axis to repeat along (negative values count from the end).
    #             ``None`` is accepted only for 1-D input, where it means 0.
    # Returns
    #   The repeated array. For ``repeats == 1`` (or an empty axis), ``a``
    #   itself is returned.
    # Raises
    #   NotImplementedError : ``axis`` is None for non-1-D input, or
    #                         ``repeats`` is not an integer.
    #   ValueError          : ``axis`` is out of bounds, or ``repeats`` < 0.
    if axis is None:
        if a.ndim == 1:
            axis = 0
        else:
            raise NotImplementedError("Must supply an integer axis value")

    if not isinstance(repeats, Integral):
        raise NotImplementedError("Only integer valued repeats supported")

    if -a.ndim <= axis < 0:
        axis += a.ndim
    elif not 0 <= axis <= a.ndim - 1:
        raise ValueError(f"axis(={axis}) out of bounds")

    if repeats < 0:
        # Mirror numpy, which rejects negative counts. Without this check the
        # failure surfaced as an opaque error from np.linspace below.
        raise ValueError("repeats may not contain negative values")

    def _along_axis(index):
        # Full N-D index: ``index`` on ``axis``, everything on other dims.
        return tuple(index if d == axis else slice(None) for d in range(a.ndim))

    if repeats == 0:
        # Zero repeats: empty selection along ``axis``, other dims untouched.
        return a[_along_axis(slice(0))]
    elif repeats == 1:
        return a

    # Split every input chunk along ``axis`` into up to ``repeats - 1`` pieces
    # of roughly equal length. Repeating one piece ``repeats`` times then gives
    # an output block of about repeats / (repeats - 1) times the input chunk
    # length, instead of ``repeats`` times larger.
    cchunks = cached_cumsum(a.chunks[axis], initial_zero=True)
    pieces = []
    for c_start, c_stop in sliding_window(2, cchunks):
        bounds = np.linspace(c_start, c_stop, repeats).round(0)
        for lo, hi in sliding_window(2, bounds):
            # Chunks shorter than ``repeats - 1`` yield some empty intervals.
            if lo != hi:
                # ``int`` makes the integer-valued float bounds explicit and
                # still raises on NaN (unknown chunk sizes).
                pieces.append(slice(int(lo), int(hi)))

    if not pieces:
        # The axis has length 0, so the repeated result has the same (empty)
        # shape as ``a``. ``concatenate`` would reject an empty list.
        return a

    out = []
    for piece in pieces:
        slab = a[_along_axis(piece)]
        chunks = list(slab.chunks)
        # Each piece lies inside a single input chunk, so the slab has exactly
        # one block along ``axis``.
        assert len(chunks[axis]) == 1
        chunks[axis] = (chunks[axis][0] * repeats,)
        out.append(
            slab.map_blocks(
                np.repeat,
                repeats,
                axis=axis,
                chunks=tuple(chunks),
                dtype=slab.dtype,
            )
        )

    return concatenate(out, axis=axis)
| 908 | |
| 909 | |
| 910 | @derived_from(np) |
searching dependent graphs…