shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng:UOp|None=None)
| 92 | def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())]) |
| 93 | |
| 94 | def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng:UOp|None=None): |
| 95 | if (old_sz:=rng.src[0].divides(amount)) is None: |
| 96 | raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}") |
| 97 | new_rng = UOp.range(amount, next(self.opt_range), new_type) if input_new_rng is None else input_new_rng |
| 98 | replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),)) |
| 99 | sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng) |
| 100 | self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}") |
| 101 | return replaced_rng, new_rng |
| 102 | |
| 103 | def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type] |
| 104 | def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type] |
No test coverage detected.