Repository: github.com/tinygrad/tinygrad

Function split_load_store

tinygrad/codegen/late/devectorizer.py:153–197
(ctx:Renderer|None, ls:UOp, idx:UOp)
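The first half of split_load_store picks candidate chunk widths ("fold lengths") from the target and from the buffer's dtype and address space. The sketch below is a minimal pure-Python model of that branch order. It assumes plain strings and booleans in place of Renderer, DType and AddrSpace. `fold_lengths` and `FLOAT_LIKE` are hypothetical names, and passing `device=None` with `supports_float4=False` stands in for `ctx` being None.

```python
from typing import List, Optional, Tuple

# Hypothetical stand-in for (dtypes.float, dtypes.half, *dtypes.fp8s); names are assumptions.
FLOAT_LIKE = {"float", "half", "fp8e4m3", "fp8e5m2"}

def fold_lengths(device: Optional[str], base_dtype: str, is_image: bool = False,
                 is_reg: bool = False, supports_float4: bool = False,
                 arch: str = "", allow_half8: bool = False) -> Tuple[List[int], bool]:
  """Model of the fold-length branches in split_load_store; returns (lengths, must_divide)."""
  if device == "DSP":
    # DSP: wide candidates and no divisibility requirement
    return [128, 64, 32, 16, 8, 4, 1], False
  lengths: List[int] = []
  if base_dtype not in FLOAT_LIKE and not is_image:
    pass  # non-float, non-image buffers: scalar accesses only
  elif is_reg:
    pass  # register-space buffers: scalar accesses only
  elif is_image:
    lengths = [4]
  elif supports_float4:
    if base_dtype == "half" and allow_half8: lengths = [8, 4, 2]
    elif "AMX" in arch: lengths = [16, 8, 4, 2]
    else: lengths = [4, 2]
  lengths.append(1)  # width 1 is always available as a fallback
  return lengths, True
```

For example, a float buffer on a float4-capable, non-AMX target yields `([4, 2, 1], True)`, while an int buffer on the same target yields `([1], True)`.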


```python
# excerpt from tinygrad/codegen/late/devectorizer.py; dtypes, ImageDType, AddrSpace,
# UOp, Ops, Renderer and getenv come from that module's imports
# *** correct load/store ***

def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  # this splits loads and stores into multiple chunks

  # if there's only one element to load/store, no splitting needed
  if (sz:=ls.src[0].dtype.count) == 1: return None
  buf = idx.src[0]

  # determine fold lengths
  lengths = []
  must_divide = True
  if ctx is not None and ctx.target.device == "DSP":
    lengths = [128,64,32,16,8,4]
    must_divide = False
  elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
    pass
  elif buf.ptrdtype.addrspace == AddrSpace.REG:
    pass
  elif isinstance(buf.dtype, ImageDType):
    lengths = [4]
  elif ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if "AMX" in ctx.target.arch else [4,2])
  lengths.append(1) # worst case, it's not folded

  # keep only fold lengths that evenly divide the base offset
  offset, mask = idx.src[1].get_idx(), idx.src[1].get_valid()
  if must_divide: lengths = [x for x in lengths if offset.divides(x) is not None]

  # split based on the fold lengths
  global_offset = 0
  ret = []
  while global_offset < sz:
    # with 1 at the end of the lengths list, this will always hit
    for fold_length in lengths:
      if global_offset+fold_length > sz: continue
      # index at this chunk's element offset; for widths > 1, reinterpret it as a pointer to a vector of fold_length elements
      lidx = buf.index((offset + global_offset).valid(mask), ptr=True)
      if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
      # stores take the matching lanes of the value via gep; loads get a fold_length-wide result dtype
      if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))))
      else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
      global_offset += fold_length
      break

  # if it wasn't split, return None; otherwise concatenate the loads with VCAT and group the stores
  if len(ret) <= 1: return None
  return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
```
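The splitting loop in split_load_store is a greedy walk: at each position it takes the widest remaining candidate that still fits. The sketch below models that loop on a concrete integer offset rather than a symbolic UOp. `split_plan` is a hypothetical name, and `offset % x == 0` is an assumed simplification of `offset.divides(x) is not None`.

```python
from typing import List, Optional, Tuple

def split_plan(sz: int, lengths: List[int], offset: int,
               must_divide: bool = True) -> Optional[List[Tuple[int, int]]]:
  """Model of the chunking loop in split_load_store on a concrete integer offset.

  Returns (element_offset, fold_length) pairs, or None when no split is needed.
  offset % x == 0 is an assumed stand-in for the symbolic offset.divides(x) is not None.
  """
  if sz == 1: return None  # a single element never needs splitting
  if must_divide: lengths = [x for x in lengths if offset % x == 0]
  plan: List[Tuple[int, int]] = []
  global_offset = 0
  while global_offset < sz:
    for fold_length in lengths:  # widest candidate first
      if global_offset + fold_length > sz: continue
      plan.append((global_offset, fold_length))
      global_offset += fold_length
      break
    else:
      # only reachable if the caller drops the trailing 1 from lengths
      raise ValueError("no fold length fits; lengths should end with 1")
  # as in the original, a single chunk means nothing was split
  return plan if len(plan) > 1 else None
```

With `lengths=[4, 2, 1]`, a 7-element access at offset 0 becomes `[(0, 4), (4, 2), (6, 1)]`. A 4-element access at offset 2 becomes `[(0, 2), (2, 2)]` because width 4 does not divide the base offset and is filtered out. In the real function, each pair becomes one INDEX at `offset + global_offset`, cast to a vector pointer when the width exceeds 1.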

Callers

nothing calls this directly
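The absence of direct callers fits the function's shape. Its `(ctx, ls, idx)` signature appears to follow tinygrad's PatternMatcher convention, where a rewrite rule's callback receives the context and named pattern captures. Under that convention, returning None means "leave this node unchanged." The toy dispatcher below is a hypothetical illustration of that None-means-no-change protocol only. `Node`, `rewrite` and `split_wide` are invented names, not tinygrad's API.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

@dataclass(frozen=True)
class Node:
  """Hypothetical stand-in for a UOp: an op name and a vector width."""
  op: str
  width: int

Callback = Callable[[Any, Node], Optional[List[Node]]]

def rewrite(node: Node, rules: List[Tuple[Callable[[Node], bool], Callback]],
            ctx: Any = None) -> Optional[List[Node]]:
  # try rules in order; a callback returning None means "no change, try the next rule"
  for matches, callback in rules:
    if matches(node) and (ret := callback(ctx, node)) is not None:
      return ret
  return None

def split_wide(ctx: Any, ls: Node) -> Optional[List[Node]]:
  # hypothetical callback: split accesses wider than 4 lanes into chunks of at most 4
  if ls.width <= 4: return None
  return [Node(ls.op, min(4, ls.width - i)) for i in range(0, ls.width, 4)]

rules = [(lambda n: n.op in ("LOAD", "STORE"), split_wide)]
```

Here `rewrite(Node("STORE", 6), rules)` returns a 4-wide and a 2-wide store. A 4-wide load, or any non-memory op, yields None and is left alone.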

Calls 15

getenv · function · 0.90
UOp · class · 0.90
append · method · 0.80
get_idx · method · 0.80
get_valid · method · 0.80
divides · method · 0.80
scalar · method · 0.80
index · method · 0.45
valid · method · 0.45
cast · method · 0.45
ptr · method · 0.45
vec · method · 0.45

Tested by

no test coverage detected
