(A:UOp)
| 100 | return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts])))) |
| 101 | |
def custom_handwritten(A:UOp) -> UOp:
  """Build a hand-assembled AMD PROGRAM that issues unrolled instructions on selected execution pipes.

  The emitted instruction stream is a counted loop of LOOP_N iterations. Each iteration emits UNROLL_N copies of a
  short instruction group for every selected pipe, in the fixed order SALU, VALU, TRANSCENDENTAL, WMMA. After the
  loop the program ends with s_endpgm. No emitted instruction loads from or stores to `A`: it is flattened and only
  `A.base` is attached to the sink. Likewise `lds` is declared in the sink, but no instruction here references it.

  Pipe selection comes from the PIPE environment variable, read through `getenv`:
    * unset or empty -> all of SALU, VALU, TRANSCENDENTAL, WMMA
    * a single name, e.g. PIPE=VALU -> only that pipe
    * a comma-separated list, e.g. PIPE=VALU,TRANSCENDENTAL -> those pipes (surrounding whitespace is ignored)

  Args:
    A: buffer UOp whose flattened `.base` is attached to the kernel sink.

  Returns:
    UOp(Ops.PROGRAM) with sources (sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR) of Ops.INS instructions).

  Raises:
    ValueError: if PIPE names anything other than the four pipes above. An unrecognized name would otherwise leave
      the timed loop with no work in it.
  """
  A = A.flatten()
  threads = UOp.special(128, "lidx0")
  wg = UOp.special(1, "gidx0")
  lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
  # Read PIPE once. Accept one name or a comma-separated list; an empty selection means "every pipe".
  all_pipes = ("SALU", "VALU", "TRANSCENDENTAL", "WMMA")
  requested = {p.strip() for p in getenv("PIPE", "").split(",") if p.strip()}
  unknown = requested - set(all_pipes)
  if unknown: raise ValueError(f"unknown PIPE {sorted(unknown)}, expected a comma-separated subset of {list(all_pipes)}")
  pipes = requested or set(all_pipes)
  k = Kernel()
  # wrap in loop to filter out icache misses
  LOOP_N, UNROLL_N = 8, 5
  k.emit(r4.s_mov_b32(s[1], LOOP_N))  # s[1] holds the remaining loop count
  k.label("loop")
  if "SALU" in pipes:
    for i in range(UNROLL_N):
      k.emit(r4.s_mov_b32(s[20+i], i))
      # NOTE(review): s_min_i32 is passed only one source operand (unlike s_mul_i32 below); confirm what the
      # assembler uses for the second source.
      k.emit(r4.s_min_i32(s[30+i], i))
      k.emit(r4.s_mov_b32(s[40+i], i))
      k.emit(r4.s_mul_i32(s[14+i], s[12+i], 32))
  if "VALU" in pipes:
    for i in range(UNROLL_N):
      k.emit(r4.v_mov_b32_e32(v[20+i], i))
      k.emit(r4.v_lshlrev_b64_e32(v[30+2*i:31+2*i], 2, v[12+i:13+i]))
      k.emit(r4.v_mad_co_u64_u32(v[40+2*i:41+2*i], NULL, v[12+i], v[13+i], v[14+i:15+i]))
  if "TRANSCENDENTAL" in pipes:
    # transcendental VALU runs on the TFU, it can run regular VALU at the same time
    for i in range(UNROLL_N):
      k.emit(r4.v_mov_b32_e32(v[20+i], i))
      k.emit(r4.v_s_rcp_f32(s[20+i], s[12+i]))
      k.emit(r4.v_rcp_f32_e32(v[30+i], v[12+i]))
      k.emit(r4.v_s_exp_f32(s[30+i], s[12+i]))
  if "WMMA" in pipes:
    # NOTE(review): base=30 overlaps v[30..49], which the VALU/TRANSCENDENTAL groups write earlier in the same
    # iteration. With several pipes selected, the first WMMAs read those registers. Confirm this cross-pipe
    # dependency is intended.
    base = 30
    for i in range(UNROLL_N):
      # Each unroll step uses its own 40-VGPR window starting at base + i*40.
      # The f16 WMMA uses [a, a+15]; the iu8 WMMA uses [a+16, a+27].
      a = base + i*40
      b, cd = a + 4, a + 8
      k.emit(r4.v_wmma_f32_16x16x16_f16(v[cd:cd+7], v[a:a+3], v[b:b+3], v[cd:cd+7]))
      a = base + i*40 + 16
      b, cd = a + 2, a + 4
      k.emit(r4.v_wmma_i32_16x16x16_iu8(v[cd:cd+7], v[a:a+1], v[b:b+1], v[cd:cd+7]))
  # Decrement the counter and branch back to "loop" while it is nonzero, giving LOOP_N iterations in total.
  # s_cmp_eq_i32 sets SCC when s[1] == 0; s_cbranch_scc0 is taken while SCC is clear.
  k.emit(r4.s_add_co_i32(s[1], s[1], -1))
  k.emit(r4.s_cmp_eq_i32(s[1], 0))
  k.emit(r4.s_cbranch_scc0(), target="loop")
  k.emit(r4.s_endpgm())
  insts = k.finalize()
  sink = UOp.sink(A.base, threads, wg, lds, arg=KernelInfo("custom_handwritten"))
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
| 147 | |
| 148 | def custom_data_deps(A:UOp) -> UOp: |
| 149 | A = A.flatten() |
nothing calls this directly
no test coverage detected
searching dependent graphs…