(ctx:ExecContext, call:UOp, ast:UOp)
| 180 | return et |
| 181 | |
def exec_validate(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
  """Re-run the program on the CPU runtime and check its outputs against the other half of the buffers.

  For every (buffers, device_vars) pair yielded by `unwrap_multi`, the buffer list is split
  in half: the first half is passed to a CPU runtime built from `ast.src[0]`, and the second
  half (presumably the results already produced on the original device -- confirm with the
  caller) is compared against it with `np.testing.assert_allclose` (rtol=1e-3, atol=1e-3).

  Always returns None; a mismatch surfaces as the AssertionError raised by `assert_allclose`.
  """
  import numpy as np
  for all_bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
    # first half is fed to the CPU run, second half is what the CPU results are compared against
    half = len(all_bufs)//2
    ref_bufs, dev_bufs = all_bufs[:half], all_bufs[half:]
    # per-device variables override the context-wide ones
    var_vals = {**ctx.var_vals, **device_vars}
    prg = to_program(ast.src[0], Device["CPU"].renderer)
    cpu_rt = get_runtime("CPU", prg)
    global_size, local_size = prg.arg.launch_dims(var_vals)
    cpu_args = [ref_bufs[i].ensure_allocated()._buf for i in prg.arg.globals]
    cpu_rt(*cpu_args, global_size=global_size, local_size=local_size, vals=prg.arg.vals(var_vals))
    for out_idx in prg.arg.outs:
      np.testing.assert_allclose(dev_bufs[out_idx].ensure_allocated().numpy(), ref_bufs[out_idx].numpy(), rtol=1e-3, atol=1e-3)
  return None
| 192 | |
| 193 | def exec_encdec(ctx:ExecContext, call:UOp, ast:UOp) -> float|None: |
| 194 | bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(call, ctx.input_uops)] |
nothing calls this directly
no test coverage detected
searching dependent graphs…