(N, arch='gfx1200')
| 20 | ACC, DA, DB, FA, FB, ET = 60, 188, 196, 204, 44, 10 |
| 21 | |
| 22 | def build_kernel(N, arch='gfx1200'): |
| 23 | assert N % BLOCK_M == 0 and N >= 256 |
| 24 | NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0) |
| 25 | I, L, B = [], {}, [] |
| 26 | def e(i): I.append(i); return i |
| 27 | def label(n): L[n] = sum(i.size() for i in I) |
| 28 | def br(i, t): B.append((len(I)-1, t)) |
| 29 | |
| 30 | e(s_load_b128(sdata=s[4:7], sbase=s[0:1], ioffset=0, soffset=NULL)) |
| 31 | e(s_load_b64(sdata=s[8:9], sbase=s[0:1], ioffset=0x10, soffset=NULL)) |
| 32 | e(s_wait_kmcnt(simm16=0)) |
| 33 | e(s_mov_b32(s[10], ttmp[9])); e(s_and_b32(s[11], ttmp[7], 0xFFFF)) |
| 34 | e(s_lshl_b32(s[10], s[10], 7)); e(s_lshl_b32(s[11], s[11], 7)) |
| 35 | e(s_mov_b32(s[12], N)); e(s_lshl_b32(s[13], s[12], 1)) |
| 36 | e(s_mul_i32(s[14], s[12], BLOCK_K*ELEM)) |
| 37 | e(s_add_co_i32(s[17], s[12], -2*BLOCK_K)) # loop bound |
| 38 | |
| 39 | e(v_and_b32_e32(v[1], 31, v[0])); e(v_lshrrev_b32_e32(v[2], 5, v[0])) |
| 40 | e(v_and_b32_e32(v[3], 1, v[2])); e(v_lshrrev_b32_e32(v[2], 1, v[2])) |
| 41 | |
| 42 | e(v_lshlrev_b32_e32(v[4], 5, v[0])) |
| 43 | # B store: transposed layout for stride-32 reads. addr = LDS_B_OFF + (tid%8)*512 + (tid/8)*32 |
| 44 | e(v_and_b32_e32(v[48], 7, v[0])); e(v_lshlrev_b32_e32(v[5], 9, v[48])) # (tid%8)*512 |
| 45 | e(v_lshrrev_b32_e32(v[48], 3, v[0])); e(v_lshlrev_b32_e32(v[48], 5, v[48])) # (tid/8)*32 |
| 46 | e(v_add_nc_u32_e32(v[5], v[5], v[48])); e(v_add_nc_u32_e32(v[5], LDS_B_OFF, v[5])) |
| 47 | |
| 48 | e(v_add_nc_u32_e32(v[48], s[11], v[0])) |
| 49 | e(v_mul_lo_u32(v[6], v[48], N*ELEM)); e(v_mov_b32_e32(v[7], 0)) |
| 50 | e(v_lshrrev_b32_e32(v[48], 3, v[0])); e(v_mul_lo_u32(v[8], v[48], N*ELEM)) |
| 51 | e(v_and_b32_e32(v[48], 7, v[0])); e(v_lshlrev_b32_e32(v[48], 5, v[48])) |
| 52 | e(v_add_nc_u32_e32(v[8], v[8], v[48])) |
| 53 | e(s_mul_i32(s[15], s[10], ELEM)); e(v_add_nc_u32_e32(v[8], s[15], v[8])) |
| 54 | e(v_mov_b32_e32(v[9], 0)) |
| 55 | |
| 56 | # LDS read addrs with padded strides (eliminates bank conflicts) |
| 57 | # A: (lane%16)*LDS_A_ROW + (lane/16)*16 + wave_m*64*LDS_A_ROW |
| 58 | # B: (lane%16)*LDS_B_ROW + (lane/16)*16 + wave_n*64*ELEM + LDS_B_OFF |
| 59 | LLA, LLB = 40, 43 |
| 60 | e(v_and_b32_e32(v[50], 15, v[1])); e(v_lshrrev_b32_e32(v[51], 4, v[1])) |
| 61 | e(v_lshlrev_b32_e32(v[LLA], 5, v[50])) # (lane%16) * 32 |
| 62 | e(v_lshlrev_b32_e32(v[51], 4, v[51])) # (lane/16) * 16 |
| 63 | e(v_add_nc_u32_e32(v[LLA], v[LLA], v[51])) |
| 64 | e(v_lshlrev_b32_e32(v[52], 11, v[2])) # wave_m * 2048 |
| 65 | e(v_add_nc_u32_e32(v[LLA], v[LLA], v[52])) |
| 66 | # B read: transposed layout. addr = LDS_B_OFF + (lane%16)*32 + (lane/16)*16 + wave_n*2*512 |
| 67 | # wave_n selects column panels: wave_n*2 panels (each panel=16 cols, wave_n covers 64 cols = 4 panels) |
| 68 | # But wave_n*2*512 = wave_n*1024. Hmm, wave_n covers cols [wave_n*64 : (wave_n+1)*64]. |
| 69 | # Each panel = 16 cols = 512 bytes. wave_n*64/16 = wave_n*4 panels. Offset = wave_n*4*512 = wave_n*2048. |
| 70 | e(v_lshlrev_b32_e32(v[LLB], 5, v[50])) # (lane%16) * 32 (stride 32!) |
| 71 | e(v_add_nc_u32_e32(v[LLB], v[LLB], v[51])) # + (lane/16)*16 |
| 72 | e(v_lshlrev_b32_e32(v[52], 11, v[3])) # wave_n * 2048 |
| 73 | e(v_add_nc_u32_e32(v[LLB], v[LLB], v[52])) |
| 74 | e(v_add_nc_u32_e32(v[LLB], LDS_B_OFF, v[LLB])) |
| 75 | |
| 76 | for i in range(0, 128, 2): |
| 77 | e(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[ACC+i], vdsty=v[ACC+i+1], srcx0=0, srcy0=0)) |
| 78 | e(s_mov_b32(s[16], 0)) |
| 79 |
searching dependent graphs…