MCPcopy
hub / github.com/tinygrad/tinygrad / build_kernel

Function build_kernel

extra/gemm/amd_asm_matmul.py:199–432  ·  view source on GitHub ↗
(N)

Source from the content-addressed store, hash-verified

197# =============================================================================
198
199def build_kernel(N):
200 assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
201 assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
202 k = Kernel()
203
204 # ===========================================================================
205 # PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
206 # ===========================================================================
207 k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL))
208 k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
209 k.emit(s_mov_b32(s[S_DIM_N], N))
210 k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups
211 k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
212 k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))
213
214 # Lane-derived values
215 k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[V_LANE_ID]))
216 k.emit(v_lshrrev_b32_e32(v[4], 3, v[V_LANE_ID]))
217 k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[V_LANE_ID]))
218 k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4]))
219 k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8]))
220 k.waitcnt(lgkm=0)
221
222 # Compute 8 A and B matrix tile base pointers for prefetch
223 k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset
224 for i in range(1, 8): # B: each pointer 1 row of B apart (N*4 bytes)
225 k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * N * 4))
226 k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_KERNARG_B[1]], 0))
227 k.emit(s_mov_b64(s[S_PREFETCH_A:S_PREFETCH_A+1], s[S_KERNARG_A[0]:S_KERNARG_A[1]])) # A[0]: no offset
228 for i in range(1, 8): # A: each pointer 16 rows of A apart (16*N*4 bytes)
229 k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * N * 64))
230 k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0))
231
232 # Global prefetch addresses: B = (tile_x + lane_id) * 4, A = (tile_y*N + (lane_id/8)*N + lane_id%8) * 4
233 k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID]))
234 k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
235 k.emit(s_mul_i32(s[19], s[S_TILE_Y], N))
236 k.emit(v_mul_lo_u32(v[V_GLOBAL_A_ADDR], v[4], N)) # (lane_id/8)*N
237 k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], v[V_LANE_ID_MOD8], v[V_GLOBAL_A_ADDR])) # + lane_id%8
238 k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
239 k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))
240
241 # Do initial loads
242 for vdst, saddr_lo in INIT_PREFETCH:
243 k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1]))
244 for iter in range(6):
245 vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
246 k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
247 k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))
248
249 # ===========================================================================
250 # LDS store address computation (bank-conflict-avoiding swizzle)
251 # ===========================================================================
252 # This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts.
253 # The swizzle ensures that threads in the same wavefront write to different LDS banks.
254 # Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset
255 # where swizzle_offset depends on (lane_id >> 3) to distribute across banks.
256 k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base

Callers 1

test_matmul · Function · 0.70

Calls 7

emit · Method · 0.95
waitcnt · Method · 0.95
label · Method · 0.95
finalize · Method · 0.95
getenv · Function · 0.90
Kernel · Class · 0.70
VOPD · Class · 0.50

Tested by 1

test_matmul · Function · 0.56

Used in the wild real call sites across dependent graphs

searching dependent graphs…