(linear:UOp, max_batch_size:int=0)
| 29 | return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata)) |
| 30 | |
def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
  """Batch consecutive graphable items of `linear.src` into graph calls.

  Items are walked in order:
  - Items whose first source is a BUFFER_VIEW are dropped from the result.
  - Every other item is graphable when the graph class of its first device accepts it via `supports_uop`.
  - Runs of graphable items are collected into a batch. The batch is emitted early when the next item is not
    compatible with the batch's accumulated devices, or when the size cap is reached.
  - An emitted batch becomes a single `create_graph_call(batch)` when it holds more than one item, or always when
    GRAPH_ONE_KERNEL is set. Otherwise its lone item is passed through unchanged.
  - Non-graphable items are passed through unchanged.

  Args:
    linear: the LINEAR uop whose `src` items are rewritten.
    max_batch_size: cap on items per batch; 0 means no cap. The cap doubles after every emitted graph call.

  Returns:
    `linear` with `src` replaced by the rewritten item tuple.

  NOTE: `devs[0]` raises IndexError for an item with no non-BIND sources after src[0].
  """
  out: list[UOp] = []
  batch: list[UOp] = []
  batch_devs: list[Compiled] = []

  def emit_batch() -> None:
    # rebinds (not mutates) batch/batch_devs, so a list handed to create_graph_call is never touched afterwards
    nonlocal batch, batch_devs, max_batch_size
    should_graph = len(batch) > 1 or getenv("GRAPH_ONE_KERNEL")
    if should_graph:
      out.append(create_graph_call(batch))
      max_batch_size *= 2  # each successful graph lets the next batch grow twice as large
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(batch)} kernels")
    else:
      out.extend(batch)
    batch, batch_devs = [], []

  for item in linear.src:
    # buffer views are not re-emitted at all
    if item.src[0].op is Ops.BUFFER_VIEW: continue

    # every device touched by the item's non-BIND sources (multi-device sources expand to all their devices)
    item_devs = dedup([Device[name] for s in item.src[1:] if s.op is not Ops.BIND
                       for name in (s.device if isinstance(s.device, tuple) else (s.device,))])
    first_dev = item_devs[0]
    graph_t = None if first_dev.graph is None else graph_class(first_dev)
    can_graph = graph_t is not None and graph_t.supports_uop(item_devs, item)

    can_extend = False
    if can_graph and graph_t is not None:
      # the item must be compatible with the devices already in the batch, and the batch must be under the cap
      joins_batch = not batch_devs or graph_t.supports_uop(batch_devs, item)
      can_extend = joins_batch and (max_batch_size == 0 or len(batch) < max_batch_size)
    if batch and not can_extend: emit_batch()

    if can_graph:
      batch.append(item)
      batch_devs = dedup(batch_devs + item_devs)
    else:
      out.append(item)
      batch_devs = []

  # emit whatever is still pending at the end
  if batch: emit_batch()
  return linear.replace(src=tuple(out))
| 61 | |
| 62 | def _copy_input(u:UOp) -> UOp: |
| 63 | run_linear(UOp(Ops.LINEAR, src=(u.copy_to_device(u.device).call(new:=UOp.new_buffer(u.device, u.arg, u.dtype), u, metadata=()),))) |
no test coverage detected
searching dependent graphs…