(src:UOp, buf:UOp, idx:UOp)
| 241 | # if a buffer is being stored just for permutes or something, remove it |
| 242 | # we want to reexpress the indexes of idx2 in terms of the implied b1 |
| 243 | def remove_bufferize(src:UOp, buf:UOp, idx:UOp): |
| 244 | # see if we can't do it, should this ever hit? |
| 245 | assert len(buf.src) == len(idx.src), f"index on wrong bufferize, {len(buf.src)} != {len(idx.src)}" |
| 246 | assert all(x.op in {Ops.RANGE, Ops.CONST} for x in buf.src[1:]) |
| 247 | |
| 248 | # if it's user contiguous, we never remove it |
| 249 | if src.op in ALWAYS_RUN_OPS or not buf.arg.removable: return None |
| 250 | |
| 251 | # *** here is where we compute the cost *** |
| 252 | # if we return None, the bufferize is kept |
| 253 | |
| 254 | accessed_buffers: list[UOp] = [] |
| 255 | indexes: list[UOp] = [] |
| 256 | reduces: list[UOp] = [] |
| 257 | def red_gate(x:UOp): |
| 258 | if (x.op is Ops.STAGE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK: |
| 259 | accessed_buffers.append(x) |
| 260 | return False |
| 261 | if x.op is Ops.STORE: |
| 262 | # don't look inside stores, this doesn't count toward buffer accesses |
| 263 | return False |
| 264 | if x.op is Ops.PARAM: |
| 265 | accessed_buffers.append(x) |
| 266 | if x.op is Ops.INDEX: |
| 267 | indexes.append(x) |
| 268 | if x.op is Ops.REDUCE: reduces.append(x) |
| 269 | return True |
| 270 | src.toposort(gate=red_gate) |
| 271 | del red_gate |
| 272 | accessed_buffers = dedup(accessed_buffers) |
| 273 | |
| 274 | # if this is generated from multiple buffers, don't remove this buffer |
| 275 | if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None |
| 276 | |
| 277 | # if any reduces access a buffer, don't remove this buffer |
| 278 | buffer_in_reduce = False |
| 279 | def buf_gate(x:UOp): |
| 280 | nonlocal buffer_in_reduce |
| 281 | if x.op in {Ops.PARAM, Ops.STAGE}: buffer_in_reduce = True |
| 282 | return not buffer_in_reduce |
| 283 | UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate) |
| 284 | del buf_gate |
| 285 | if buffer_in_reduce: |
| 286 | if PCONTIG > 2: |
| 287 | out_in_ratio = (prod(buf.shape)+1) / (sum([x.numel() for x in accessed_buffers])+1) |
| 288 | if out_in_ratio < 10: return None |
| 289 | # here we have to check the indexes, we might do a partial contig here |
| 290 | local_indexes = [x for x in indexes if x.src[0].op is Ops.STAGE and x.src[0].arg.addrspace == AddrSpace.LOCAL] |
| 291 | exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges |
| 292 | subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST] |
| 293 | # if it's bufferized or a reduce, it's pcontig |
| 294 | is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges])) |
| 295 | if not len(is_subs): |
| 296 | return None |
| 297 | if len(is_pcontig): |
| 298 | ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute) |
| 299 | return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig]) |
| 300 | else: |
nothing calls this directly
no test coverage detected
searching dependent graphs…