(a, repeats, axis=None)
| 535 | |
@derived_from(np)
def repeat(a, repeats, axis=None):
    # Repeat each element of the chunked array ``a`` ``repeats`` times along
    # ``axis``, mirroring ``np.repeat`` for a scalar integer ``repeats``.
    #
    # Documented with comments rather than a docstring because
    # ``@derived_from(np)`` presumably builds ``__doc__`` from NumPy's
    # ``repeat`` -- NOTE(review): confirm how it treats a local docstring.
    #
    # Parameters:
    #   a       -- array exposing ``ndim``, ``shape``, ``chunks``, slicing and
    #              ``map_blocks``.
    #   repeats -- non-negative integer scalar (arrays of repeats unsupported).
    #   axis    -- axis to repeat along; may be omitted only for 1-D input.
    # Returns:
    #   ``a`` itself when ``repeats == 1`` or the axis is empty, a zero-length
    #   slice of ``a`` when ``repeats == 0``, otherwise a new array
    #   concatenated from per-slab ``np.repeat`` results.
    # Raises:
    #   NotImplementedError -- ``axis`` is None for a non-1-D array, or
    #                          ``repeats`` is not an integer scalar.
    #   ValueError          -- ``axis`` is out of bounds, or ``repeats`` is
    #                          negative.
    if axis is None:
        if a.ndim == 1:
            axis = 0
        else:
            raise NotImplementedError("Must supply an integer axis value")

    if not isinstance(repeats, Integral):
        raise NotImplementedError("Only integer valued repeats supported")
    if repeats < 0:
        # Fail up front with NumPy's wording instead of letting np.linspace
        # below complain about a negative number of samples.
        raise ValueError("repeats may not contain negative values.")

    if -a.ndim <= axis < 0:
        axis += a.ndim
    elif not 0 <= axis <= a.ndim - 1:
        raise ValueError(f"axis(={axis}) out of bounds")

    full = slice(None)

    def along_axis(index):
        # Full index tuple that applies ``index`` on ``axis`` and keeps
        # every other dimension whole.
        return (full,) * axis + (index,) + (full,) * (a.ndim - axis - 1)

    if repeats == 0:
        return a[along_axis(slice(0))]
    if repeats == 1 or a.shape[axis] == 0:
        # Nothing to repeat. An empty axis must short-circuit too: it would
        # produce no slabs below, leaving nothing to pass to ``concatenate``.
        return a

    # Split every chunk along ``axis`` into at most ``repeats - 1`` roughly
    # equal slabs (edges from np.linspace, rounded to whole indices). After
    # repetition each slab becomes one output block, presumably so that
    # output blocks stay near the input chunk size -- NOTE(review): confirm.
    bounds = cached_cumsum(a.chunks[axis], initial_zero=True)
    slabs = []
    for chunk_start, chunk_stop in sliding_window(2, bounds):
        edges = np.linspace(chunk_start, chunk_stop, repeats).round(0)
        for start, stop in sliding_window(2, edges):
            # Rounding can collapse neighbouring edges; skip empty pieces.
            if start != stop:
                slabs.append(a[along_axis(slice(start, stop))])

    out = []
    for slab in slabs:
        chunks = list(slab.chunks)
        # Each slab lies inside a single input chunk along ``axis``.
        assert len(chunks[axis]) == 1
        chunks[axis] = (chunks[axis][0] * repeats,)
        out.append(
            slab.map_blocks(
                np.repeat, repeats, axis=axis, chunks=tuple(chunks), dtype=slab.dtype
            )
        )

    return concatenate(out, axis=axis)
no test coverage detected
searching dependent graphs…