(ctx:Renderer|None, ls:UOp, idx:UOp)
| 151 | # *** correct load/store *** |
| 152 | |
def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  """Break a vector LOAD/STORE into a sequence of narrower accesses.

  The element count is read from ``ls.src[0].dtype.count``. Candidate chunk
  widths are picked from ``ctx`` and from the dtype/address space of the
  buffer ``idx.src[0]``; a width of 1 is always appended as the fallback.
  Unless the DSP path was taken, widths for which ``offset.divides(width)``
  returns None are dropped.

  Returns None when there is a single element or when only one chunk was
  produced. Otherwise returns a VCAT of the chunk loads (for a LOAD) or a
  ``UOp.group`` of the chunk stores.
  """
  total = ls.src[0].dtype.count
  # a single element leaves nothing to split
  if total == 1: return None
  buf = idx.src[0]

  # choose the candidate chunk widths, widest first
  widths = []
  must_divide = True
  if ctx is not None and ctx.target.device == "DSP":
    widths, must_divide = [128,64,32,16,8,4], False
  elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
    pass  # neither a float-like base nor an image: no candidate widths
  elif buf.ptrdtype.addrspace == AddrSpace.REG:
    pass  # REG address space: no candidate widths
  elif isinstance(buf.dtype, ImageDType):
    widths = [4]
  elif ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8"): widths = [8,4,2]
    elif "AMX" in ctx.target.arch: widths = [16,8,4,2]
    else: widths = [4,2]
  # scalar access is the last resort when nothing wider applies
  widths = [*widths, 1]

  # keep only the widths that the offset expression divides by
  index_src = idx.src[1]
  offset, mask = index_src.get_idx(), index_src.get_valid()
  if must_divide:
    widths = [w for w in widths if offset.divides(w) is not None]

  # walk the elements, emitting the first (widest) width that still fits
  is_store = ls.op is Ops.STORE
  pieces = []
  pos = 0
  while pos < total:
    # the trailing 1 in widths means some width always fits
    for width in widths:
      if pos + width <= total:
        ptr = buf.index((offset + pos).valid(mask), ptr=True)
        if width > 1:
          # view the pointer as a pointer to a `width`-wide vector of the base type
          pdt = buf.ptrdtype
          ptr = ptr.cast(pdt.base.vec(width).ptr(size=pdt.size, addrspace=pdt.addrspace))
        if is_store:
          lanes = tuple(range(pos, pos + width))
          pieces.append(ls.replace(src=(ptr, ls.src[1].gep(lanes))))
        else:
          pieces.append(ls.replace(src=(ptr,) + ls.src[1:], dtype=ls.dtype.scalar().vec(width)))
        pos += width
        break

  # a single piece means nothing was actually split
  if len(pieces) <= 1: return None
  if ls.op is Ops.LOAD: return UOp(Ops.VCAT, ls.dtype, tuple(pieces))
  return UOp.group(*pieces)
| 198 | |
| 199 | def get_image_idx(idx:UOp, width:int): |
| 200 | x, valid = idx.src[1].get_idx(), idx.src[1].get_valid() |
nothing calls this directly
no test coverage detected
searching dependent graphs…