(device:str, ast:UOp, cache=True)
| 107 | |
| 108 | runtime_cache: dict[tuple[bytes, str], Any] = {} |
| 109 | def get_runtime(device:str, ast:UOp, cache=True): |
| 110 | assert ast.op is Ops.PROGRAM and isinstance(ast.arg, ProgramInfo), "get_runtime should only be called with a PROGRAM ast" |
| 111 | if (runtime:=runtime_cache.get(key:=(ast.key, device))) is None: |
| 112 | runtime = Device[device].runtime(ast.arg.function_name, ast.src[4].arg, *ast.arg.aux, runtimevars=ast.arg.runtimevars, prg=ast) |
| 113 | if cache: runtime_cache[key] = runtime |
| 114 | return runtime |
| 115 | |
# Cache keyed weakly by UOp: an entry is dropped automatically once its UOp key is garbage collected,
# so this cache does not keep ASTs alive on its own.
# NOTE(review): values are presumably graph runtimes built by get_graph_runtime -- confirm against that function.
graph_cache:weakref.WeakKeyDictionary[UOp, Any] = weakref.WeakKeyDictionary()
| 117 | def get_graph_runtime(ast:UOp, input_uops:tuple[UOp, ...]|None=None): |
searching dependent graphs…