MCPcopy
hub / github.com/tinygrad/tinygrad / build_kernel

Function build_kernel

extra/gemm/rdna4_asm_matmul.py:22–199  ·  view source on GitHub ↗
(N, arch='gfx1200')

Source from the content-addressed store, hash-verified

# Register/constant indices used by build_kernel. ACC is the base VGPR of the
# 128 accumulator registers v[ACC]..v[ACC+127] (v[60]..v[187]); build_kernel
# zeroes them pairwise with VOPD V_DUAL_MOV_B32 (src 0). DA (188) = ACC+128 is
# the first VGPR after that accumulator range.
# NOTE(review): DB/FA/FB/ET are not used in the visible excerpt -- presumably
# further VGPR bases/offsets (DB = DA+8, FA = DB+8); confirm against the full file.
20ACC, DA, DB, FA, FB, ET = 60, 188, 196, 204, 44, 10
21
22def build_kernel(N, arch='gfx1200'):
23 assert N % BLOCK_M == 0 and N >= 256
24 NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
25 I, L, B = [], {}, []
26 def e(i): I.append(i); return i
27 def label(n): L[n] = sum(i.size() for i in I)
28 def br(i, t): B.append((len(I)-1, t))
29
30 e(s_load_b128(sdata=s[4:7], sbase=s[0:1], ioffset=0, soffset=NULL))
31 e(s_load_b64(sdata=s[8:9], sbase=s[0:1], ioffset=0x10, soffset=NULL))
32 e(s_wait_kmcnt(simm16=0))
33 e(s_mov_b32(s[10], ttmp[9])); e(s_and_b32(s[11], ttmp[7], 0xFFFF))
34 e(s_lshl_b32(s[10], s[10], 7)); e(s_lshl_b32(s[11], s[11], 7))
35 e(s_mov_b32(s[12], N)); e(s_lshl_b32(s[13], s[12], 1))
36 e(s_mul_i32(s[14], s[12], BLOCK_K*ELEM))
37 e(s_add_co_i32(s[17], s[12], -2*BLOCK_K)) # loop bound
38
39 e(v_and_b32_e32(v[1], 31, v[0])); e(v_lshrrev_b32_e32(v[2], 5, v[0]))
40 e(v_and_b32_e32(v[3], 1, v[2])); e(v_lshrrev_b32_e32(v[2], 1, v[2]))
41
42 e(v_lshlrev_b32_e32(v[4], 5, v[0]))
43 # B store: transposed layout for stride-32 reads. addr = LDS_B_OFF + (tid%8)*512 + (tid/8)*32
44 e(v_and_b32_e32(v[48], 7, v[0])); e(v_lshlrev_b32_e32(v[5], 9, v[48])) # (tid%8)*512
45 e(v_lshrrev_b32_e32(v[48], 3, v[0])); e(v_lshlrev_b32_e32(v[48], 5, v[48])) # (tid/8)*32
46 e(v_add_nc_u32_e32(v[5], v[5], v[48])); e(v_add_nc_u32_e32(v[5], LDS_B_OFF, v[5]))
47
48 e(v_add_nc_u32_e32(v[48], s[11], v[0]))
49 e(v_mul_lo_u32(v[6], v[48], N*ELEM)); e(v_mov_b32_e32(v[7], 0))
50 e(v_lshrrev_b32_e32(v[48], 3, v[0])); e(v_mul_lo_u32(v[8], v[48], N*ELEM))
51 e(v_and_b32_e32(v[48], 7, v[0])); e(v_lshlrev_b32_e32(v[48], 5, v[48]))
52 e(v_add_nc_u32_e32(v[8], v[8], v[48]))
53 e(s_mul_i32(s[15], s[10], ELEM)); e(v_add_nc_u32_e32(v[8], s[15], v[8]))
54 e(v_mov_b32_e32(v[9], 0))
55
56 # LDS read addrs with padded strides (eliminates bank conflicts)
57 # A: (lane%16)*LDS_A_ROW + (lane/16)*16 + wave_m*64*LDS_A_ROW
58 # B: (lane%16)*LDS_B_ROW + (lane/16)*16 + wave_n*64*ELEM + LDS_B_OFF
59 LLA, LLB = 40, 43
60 e(v_and_b32_e32(v[50], 15, v[1])); e(v_lshrrev_b32_e32(v[51], 4, v[1]))
61 e(v_lshlrev_b32_e32(v[LLA], 5, v[50])) # (lane%16) * 32
62 e(v_lshlrev_b32_e32(v[51], 4, v[51])) # (lane/16) * 16
63 e(v_add_nc_u32_e32(v[LLA], v[LLA], v[51]))
64 e(v_lshlrev_b32_e32(v[52], 11, v[2])) # wave_m * 2048
65 e(v_add_nc_u32_e32(v[LLA], v[LLA], v[52]))
66 # B read: transposed layout. addr = LDS_B_OFF + (lane%16)*32 + (lane/16)*16 + wave_n*2*512
67 # wave_n selects column panels: wave_n*2 panels (each panel=16 cols, wave_n covers 64 cols = 4 panels)
68 # But wave_n*2*512 = wave_n*1024. Hmm, wave_n covers cols [wave_n*64 : (wave_n+1)*64].
69 # Each panel = 16 cols = 512 bytes. wave_n*64/16 = wave_n*4 panels. Offset = wave_n*4*512 = wave_n*2048.
70 e(v_lshlrev_b32_e32(v[LLB], 5, v[50])) # (lane%16) * 32 (stride 32!)
71 e(v_add_nc_u32_e32(v[LLB], v[LLB], v[51])) # + (lane/16)*16
72 e(v_lshlrev_b32_e32(v[52], 11, v[3])) # wave_n * 2048
73 e(v_add_nc_u32_e32(v[LLB], v[LLB], v[52]))
74 e(v_add_nc_u32_e32(v[LLB], LDS_B_OFF, v[LLB]))
75
76 for i in range(0, 128, 2):
77 e(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[ACC+i], vdsty=v[ACC+i+1], srcx0=0, srcy0=0))
78 e(s_mov_b32(s[16], 0))
79

Callers: 1

test_matmul · Function · 0.70

Calls: 7

getenv · Function · 0.90
label · Function · 0.85
emit_iter_body · Function · 0.85
br · Function · 0.85
e · Function · 0.70
VOPD · Class · 0.50
size · Method · 0.45

Tested by: 1

test_matmul · Function · 0.56

Used in the wild — real call sites across dependent graphs

searching dependent graphs…