hub / github.com/tinygrad/tinygrad / shift_to

Method shift_to

tinygrad/codegen/opt/postrange.py:94–101
(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng:UOp|None=None)

Source from the content-addressed store, hash-verified

92    def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
93
94    def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng:UOp|None=None):
95      if (old_sz:=rng.src[0].divides(amount)) is None:
96        raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
97      new_rng = UOp.range(amount, next(self.opt_range), new_type) if input_new_rng is None else input_new_rng
98      replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
99      sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
100     self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}")
101     return replaced_rng, new_rng
102
103   def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
104   def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]
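
The index arithmetic in shift_to can be illustrated without tinygrad. The sketch below is not part of the repository: it models the range size as a plain Python int (the real code works on UOp expressions), and the helper name split_index is hypothetical. It shows that a range of length size is split into a shrunken range of length old_sz = size // amount and a new range of length amount, and that top only decides which of the two becomes the high-order part of the original index.

```python
from itertools import product

def split_index(size: int, amount: int, top: bool = False) -> dict[tuple[int, int], int]:
    """Hypothetical plain-int model of the sub_axis expression in shift_to.

    Returns {(replaced, new): original_index} for every pair of loop values.
    Assumes a constant integer range size; tinygrad operates on UOps instead.
    """
    if size % amount != 0:
        # mirrors the KernelOptError raised when amount does not divide the range
        raise ValueError(f"{amount} can't divide {size}")
    old_sz = size // amount  # length of the shrunken (replaced) range
    mapping = {}
    for replaced, new in product(range(old_sz), range(amount)):
        if top:
            # new range is the slow-moving, high-order part
            idx = new * old_sz + replaced
        else:
            # new range is the fast-moving, low-order part
            idx = replaced * amount + new
        mapping[(replaced, new)] = idx
    return mapping

bottom = split_index(12, 4)            # replaced in 0..2, new in 0..3
upper = split_index(12, 4, top=True)
```

In both modes every original index 0..size-1 is produced exactly once, so the substitution only reorders how the iteration space is walked; it does not change which elements are visited.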

Callers 2

apply_opt · Method · 0.95
_apply_tc_opt · Method · 0.95

Calls 8

colored_shape · Method · 0.95
KernelOptError · Class · 0.90
divides · Method · 0.80
substitute · Method · 0.80
split · Method · 0.80
range · Method · 0.45
replace · Method · 0.45
const · Method · 0.45

Tested by

no test coverage detected