(linear:UOp, output_bufs:List[Buffer])
| 18 | return (u for u in linear.toposort(gate=lambda x: x.op not in _KERNEL_ASTS) if u.op is Ops.CALL and u.src[0].op in _KERNEL_ASTS) |
| 19 | |
def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], List, Dict[str,Tuple[int,DType,tuple]], Dict[str,Buffer]]:
  """Compile every kernel CALL found in `linear` and collect what an exporter needs to emit it.

  Args:
    linear: graph walked by `iter_kernel_calls`; kernels are processed in the order it yields them.
    output_bufs: buffers to be named "output0", "output1", ... (matched by object identity).

  Returns:
    functions: kernel function name -> `prg.src[3].arg` of its compiled program
      (presumably the rendered kernel source -- confirm against `to_program`).
    statements: one (function_name, cargs, global_size, local_size) tuple per kernel call, in call order;
      `cargs` holds the argument names followed by the program's DEFINE_VAR variables.
    bufs: argument name -> (size in bytes, dtype, dedup key). The key is ("in", param_arg) for PARAM
      inputs and (id(base), offset, size, dtype) for concrete buffers -- a tuple in both cases.
    bufs_to_save: name -> Buffer for every concrete buffer that is not in `output_bufs` and whose first
      appearance is not as a kernel's first argument.
  """
  output_name = {id(b): f"output{i}" for i, b in enumerate(output_bufs)}
  functions, bufs, bufs_to_save, statements, n = {}, {}, {}, [], 0

  def name_of(bu:UOp, is_out:bool) -> str:
    """Return the name for kernel argument `bu`, registering it in `bufs` the first time it is seen."""
    nonlocal n
    if bu.op is Ops.PARAM:
      # PARAM inputs are named after their arg; size comes from the UOp's shape and dtype.
      key, name, size = ("in", bu.arg), f"input{bu.arg}", prod(bu.shape)*bu.dtype.itemsize
    else:
      b = bu.buffer
      # Buffers sharing base/offset/size/dtype collapse to one key and therefore one name.
      key, size = (id(b.base), b.offset, b.size, b.dtype), b.size*b.dtype.itemsize
      if key in bufs: return bufs[key][0]
      if (name := output_name.get(id(b))) is None:
        name, n = f"buf_{n}", n+1
        # A buffer first seen as a kernel input was not written by an earlier kernel here;
        # presumably its contents come from outside the graph (e.g. weights), so hand it back to be saved.
        if not is_out: bufs_to_save[name] = b
    bufs[key] = (name, size, bu.dtype, key)
    return name

  for call in iter_kernel_calls(linear):
    # Skip BIND sources; only the remaining sources are named through name_of.
    arg_uops = [b for b in call.src[1:] if b.op is not Ops.BIND]
    # The device of the first argument selects the renderer used to compile the kernel.
    prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
    info = prg.arg
    # NOTE: assumes kernels with the same function name have the same source; a later one silently overwrites.
    functions[info.function_name] = prg.src[3].arg
    # The first (non-BIND) argument is treated as the kernel's output.
    cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
    statements.append((info.function_name, cargs, info.global_size, info.local_size))

  return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
| 46 | |
| 47 | def jit_model(model, *args) -> Tuple[UOp, List[Buffer]]: |
| 48 | assert hasattr(model, "forward") or callable(model), "model needs a forward function" |
no test coverage detected
searching dependent graphs…