(prg:UOp, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test", dev_timeout=False)
| 35 | return test_global_size, input_size / prod(test_global_size) |
| 36 | |
def _time_program(prg:UOp, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
                  allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test", dev_timeout=False) -> list[float]:
  """Run `prg` on `rawbufs` up to `cnt` times and return the scaled timings.

  Args:
    prg: program UOp; its `arg.global_size` may be swapped for a reduced test size.
    var_vals: variable bindings, forwarded to `get_test_global_size` and `time_call`.
    rawbufs: buffers wrapped with `UOp.from_buffer` and passed to `prg.call`.
    early_stop: when not None, timing stops as soon as the best time so far exceeds it.
    allow_test_size: when truthy (and `max_global_size` is not None), time a reduced
      global size and scale the measurements by the factor `get_test_global_size` returns.
    max_global_size: limit handed to `get_test_global_size`; None disables the reduction.
    clear_l2: forwarded unchanged to `time_call`.
    cnt: maximum number of timed runs.
    name: accepted for interface compatibility; not used in this function body.
    dev_timeout: when truthy and `early_stop` is finite, `time_call` receives
      `int(early_stop * 1e3)` as its timeout (presumably seconds -> milliseconds;
      confirm against `time_call`).

  Returns:
    One entry per completed run (each multiplied by the test-size factor), or
    `[math.inf] * cnt` if any run raises AssertionError.
  """
  # A timeout is only derived from a finite early_stop, and only when requested.
  timeout = None
  if dev_timeout and early_stop is not None and early_stop < math.inf:
    timeout = int(early_stop * 1e3)

  # Optionally shrink the launch size; `scale` converts a measurement back afterwards.
  scale = 1
  if allow_test_size and max_global_size is not None:
    test_size, scale = get_test_global_size(prg.arg.global_size, max_global_size, var_vals)
    prg = prg.replace(arg=replace(prg.arg, global_size=tuple(test_size)))

  buffer_uops = [UOp.from_buffer(rb) for rb in rawbufs]
  call = prg.call(*buffer_uops)

  timings = []
  for _ in range(cnt):
    try:
      timings.append(time_call(call, var_vals, timeout=timeout, clear_l2=clear_l2) * scale)
    except AssertionError:
      # A failed run invalidates the whole measurement: report every run as infinitely slow.
      return [math.inf] * cnt
    # Stop once even the best run so far is slower than the early_stop bound.
    if early_stop is not None and early_stop < min(timings):
      break
  return timings
| 51 | |
class TimeoutException(Exception):
  """Plain `Exception` subclass with no added behavior.

  NOTE(review): presumably raised when a timed operation runs past its limit;
  confirm at the raise sites.
  """
| 53 | def timeout_handler(signum, frame): |
no test coverage detected
searching dependent graphs…