(A:UOp)
| 15 | from extra.gemm.amd_asm_matmul import Kernel |
| 16 | |
def custom_add_one(A:UOp) -> UOp:
  """Build a hand-written AMD assembly PROGRAM that adds 1.0 to every element of A, in place.

  A is flattened to 1-D and one thread is launched per element (a single `lidx0` special of size A.numel()).
  Each thread loads one 32-bit float, adds 1.0 to it, and stores the result back to the address it loaded from.

  Args:
    A: buffer UOp to update; its element dtype must be float32.

  Returns:
    An Ops.PROGRAM UOp whose sources are the kernel sink, the target device ("AMD"),
    and an Ops.LINEAR of Ops.INS wrapping the instruction list.
  """
  A = A.flatten()
  # The instruction stream is hard-wired to 32-bit floats (global_load_b32 / v_add_f32_e32 / global_store_b32,
  # byte offset = index << 2, mem estimate of 4 bytes per access). Any other float width would be read and
  # written with the wrong element size, so only float32 is accepted.
  assert A.dtype.base == dtypes.float32, f"buffer dtype must be float32, got {A.dtype}"
  # NOTE(review): every element is a local thread (lidx0) of one launch; presumably A.numel() must stay within
  # the device's workgroup-size limit — TODO confirm.
  threads = UOp.special(A.numel(), "lidx0")
  insts = [
    # assumes s[0:1] starts out holding the kernel-argument pointer; load A's base address over it
    s_load_b64(s[0:1], s[0:1], soffset=NULL),
    s_waitcnt_lgkmcnt(sdst=NULL, simm16=0),
    # v[0] (presumably the thread index) << 2: element index -> byte offset for 4-byte floats
    v_lshlrev_b32_e32(v[0], 2, v[0]),
    global_load_b32(v[1], v[0], saddr=s[0:1]),
    s_waitcnt_vmcnt(sdst=NULL, simm16=0),
    v_mov_b32_e32(v[2], 1.0),
    v_add_f32_e32(v[1], v[1], v[2]),
    # same address/base as the load, so the update happens in place
    global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]),
    s_endpgm(),
  ]
  # per element: one add; 4 bytes read + 4 bytes written
  sink = UOp.sink(A.base, threads, arg=KernelInfo(f"custom_add_one_{A.numel()}", estimates=Estimates(ops=A.numel(), mem=A.numel()*4*2)))
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
| 34 | |
| 35 | def custom_add_var(A:UOp, B:UOp) -> UOp: |
| 36 | A,B = A.flatten(), B.flatten() |
nothing calls this directly
no test coverage detected
searching dependent graphs…