(ctx:ExecContext, call:UOp, device:str, bufs:list[Buffer], var_vals:dict[str, int])
# Keys (call.src[0].key) of calls that track_stats has finished accounting for (only when
# ctx.update_stats was set). Outside the JIT, a key not yet in here gets a green header in
# the DEBUG >= 2 summary line, highlighting the first execution of that call.
first_run_cache: set[bytes] = set()

@contextlib.contextmanager
def track_stats(ctx:ExecContext, call:UOp, device:str, bufs:list[Buffer], var_vals:dict[str, int]):
  """
  Wrap one execution of `call` on `device` and account for it once the body finishes.

  Before the body runs:
    - if PROFILE is set, a ProfilePointEvent ("exec") describing the call, its buffers and
      var_vals is appended to cpu_events;
    - if DEBUG >= 2, a wall-clock start time is recorded.

  Yields a one-element list initialised to [None]. The body may store the execution time at
  index 0. The value is treated as seconds, since it is added to GlobalCounters.time_sum_s
  and printed with `*1e3 ... ms`.

  After the body (and only when ctx.update_stats is truthy):
    - with DEBUG >= 2 and no time supplied, the device is synchronized and the wall-clock
      time since the start is used;
    - GlobalCounters' kernel count, op/memory estimates (from estimate_uop) and, if known,
      accumulated time are updated;
    - with DEBUG >= 2, a one-line summary (time, GFLOPS/TFLOPS, GB/s or TB/s) is printed;
    - the call's key is added to first_run_cache.

  If the body raises, the exception propagates and none of the post-body accounting runs.
  """
  if PROFILE:
    outs, ins = get_call_outs_ins(call)
    event_seq = len(cpu_events)
    event_info = {"metadata": call.arg.metadata, "var_vals": var_vals, "bufs": [buf.trace_num for buf in bufs],
                  "name": get_call_name(call, bufs, var_vals), "outputs": outs, "inputs": ins}
    cpu_events.append(ProfilePointEvent(device, "exec", event_seq, event_info))

  timing: list[float|None] = [None]
  if DEBUG >= 2: t_start = time.perf_counter()
  yield timing

  if not ctx.update_stats: return
  verbose = DEBUG >= 2

  # Fall back to host wall-clock time (after draining the device) when the body gave no time.
  if verbose and timing[0] is None:
    Device[device].synchronize()
    timing[0] = time.perf_counter() - t_start
  elapsed = timing[0]

  est = estimate_uop(call)
  GlobalCounters.kernel_count += 1
  op_est = sym_infer(est.ops, var_vals)
  GlobalCounters.global_ops += op_est
  mem_est = sym_infer(est.mem, var_vals)
  GlobalCounters.global_mem += mem_est
  if elapsed is not None: GlobalCounters.time_sum_s += elapsed

  if verbose:
    label = get_call_name(call, bufs, var_vals)
    lds_est = sym_infer(est.lds, var_vals)

    if ctx.jit: header_color = 'magenta'
    elif call.src[0].key not in first_run_cache: header_color = 'green'
    else: header_color = None

    if elapsed is None: ptm = ""
    else: ptm = colored(time_to_str(elapsed, w=9), "yellow" if elapsed > 0.01 else None)

    # 1e-20 stands in for a zero/None time so the throughput divisions never divide by zero.
    denom = elapsed or 1e-20
    flops, membw, ldsbw = op_est/denom, mem_est/denom, lds_est/denom

    if flops < 1e14: flops_str = f"{flops*1e-9:7.0f} GFLOPS"
    else: flops_str = colored(f"{flops*1e-12:7.0f} TFLOPS", 'green')

    if membw < 1e13 and ldsbw < 1e15: mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s"
    else: mem_str = colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')

    summary = colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)
    summary += f" {label+' '*(46-ansilen(label))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"
    if elapsed is not None:
      summary += f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})"
    meta = call.arg.metadata
    summary += f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in meta] if meta else ''}"
    print(summary)

  first_run_cache.add(call.src[0].key)
| 82 | |
# Module-level memo: bytes keys -> tuples of ints, per the annotation. Presumably filled by
# optimize_local_size, which follows this declaration — NOTE(review): confirm against its body.
local_size_cache: dict[bytes, tuple[int, ...]] = dict()
| 84 | def optimize_local_size(call:UOp, prg:UOp) -> UOp|None: |
no test coverage detected
searching dependent graphs…