MCP · copy
hub / github.com/tinygrad/tinygrad / remove_bufferize

Function remove_bufferize

tinygrad/schedule/rangeify.py:243–307  ·  view source on GitHub ↗
(src:UOp, buf:UOp, idx:UOp)

Source from the content-addressed store, hash-verified

241# if a buffer is being stored just for permutes or something, remove it
242# we want to reexpress the indexes of idx2 in terms of the implied b1
243def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
244 # see if we can't do it, should this ever hit?
245 assert len(buf.src) == len(idx.src), f"index on wrong bufferize, {len(buf.src)} != {len(idx.src)}"
246 assert all(x.op in {Ops.RANGE, Ops.CONST} for x in buf.src[1:])
247
248 # if it's user contiguous, we never remove it
249 if src.op in ALWAYS_RUN_OPS or not buf.arg.removable: return None
250
251 # *** here is where we compute the cost ***
252 # if we return None, the bufferize is kept
253
254 accessed_buffers: list[UOp] = []
255 indexes: list[UOp] = []
256 reduces: list[UOp] = []
257 def red_gate(x:UOp):
258 if (x.op is Ops.STAGE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
259 accessed_buffers.append(x)
260 return False
261 if x.op is Ops.STORE:
262 # don't look inside stores, this doesn't count toward buffer accesses
263 return False
264 if x.op is Ops.PARAM:
265 accessed_buffers.append(x)
266 if x.op is Ops.INDEX:
267 indexes.append(x)
268 if x.op is Ops.REDUCE: reduces.append(x)
269 return True
270 src.toposort(gate=red_gate)
271 del red_gate
272 accessed_buffers = dedup(accessed_buffers)
273
274 # if this is generated from multiple buffers, don't remove this buffer
275 if len(accessed_buffers) > 3 and not (PCONTIG > 2): return None
276
277 # if any reduces access a buffer, don't remove this buffer
278 buffer_in_reduce = False
279 def buf_gate(x:UOp):
280 nonlocal buffer_in_reduce
281 if x.op in {Ops.PARAM, Ops.STAGE}: buffer_in_reduce = True
282 return not buffer_in_reduce
283 UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
284 del buf_gate
285 if buffer_in_reduce:
286 if PCONTIG > 2:
287 out_in_ratio = (prod(buf.shape)+1) / (sum([x.numel() for x in accessed_buffers])+1)
288 if out_in_ratio < 10: return None
289 # here we have to check the indexes, we might do a partial contig here
290 local_indexes = [x for x in indexes if x.src[0].op is Ops.STAGE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
291 exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
292 subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
293 # if it's bufferized or a reduce, it's pcontig
294 is_pcontig, is_subs = partition(subs, lambda x: x[0] in exclude_ranges or any([r.arg[-1] == AxisType.REDUCE for r in x[1].ranges]))
295 if not len(is_subs):
296 return None
297 if len(is_pcontig):
298 ret = src.substitute(dict(is_subs), extra_pm=pm_gate_substitute)
299 return ret.bufferize(*[x[0] for x in is_pcontig], arg=BufferizeOpts(None, AddrSpace.LOCAL)).index(*[x[1] for x in is_pcontig])
300 else:

Callers

nothing calls this directly

Calls 11

dedup · Function · 0.90
prod · Function · 0.90
partition · Function · 0.90
BufferizeOpts · Class · 0.90
toposort · Method · 0.80
numel · Method · 0.80
substitute · Method · 0.80
bufferize · Method · 0.80
sink · Method · 0.45
group · Method · 0.45
index · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…