(N)
| 197 | # ============================================================================= |
| 198 | |
| 199 | def build_kernel(N): |
| 200 | assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}" |
| 201 | assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}" |
| 202 | k = Kernel() |
| 203 | |
| 204 | # =========================================================================== |
| 205 | # PROLOGUE: Load kernel arguments, compute tile coordinates and addresses |
| 206 | # =========================================================================== |
| 207 | k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL)) |
| 208 | k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL)) |
| 209 | k.emit(s_mov_b32(s[S_DIM_N], N)) |
| 210 | k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups |
| 211 | k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7)) |
| 212 | k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7)) |
| 213 | |
| 214 | # Lane-derived values |
| 215 | k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[V_LANE_ID])) |
| 216 | k.emit(v_lshrrev_b32_e32(v[4], 3, v[V_LANE_ID])) |
| 217 | k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[V_LANE_ID])) |
| 218 | k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4])) |
| 219 | k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8])) |
| 220 | k.waitcnt(lgkm=0) |
| 221 | |
| 222 | # Compute 8 A and B matrix tile base pointers for prefetch |
| 223 | k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset |
| 224 | for i in range(1, 8): # B: each pointer 1 row of B apart (N*4 bytes) |
| 225 | k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * N * 4)) |
| 226 | k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_KERNARG_B[1]], 0)) |
| 227 | k.emit(s_mov_b64(s[S_PREFETCH_A:S_PREFETCH_A+1], s[S_KERNARG_A[0]:S_KERNARG_A[1]])) # A[0]: no offset |
| 228 | for i in range(1, 8): # A: each pointer 16 rows of A apart (16*N*4 bytes) |
| 229 | k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * N * 64)) |
| 230 | k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0)) |
| 231 | |
| 232 | # Global prefetch addresses: B = (tile_x + lane_id) * 4, A = (tile_y*N + (lane_id/8)*N + lane_id%8) * 4 |
| 233 | k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID])) |
| 234 | k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR])) |
| 235 | k.emit(s_mul_i32(s[19], s[S_TILE_Y], N)) |
| 236 | k.emit(v_mul_lo_u32(v[V_GLOBAL_A_ADDR], v[4], N)) # (lane_id/8)*N |
| 237 | k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], v[V_LANE_ID_MOD8], v[V_GLOBAL_A_ADDR])) # + lane_id%8 |
| 238 | k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR])) |
| 239 | k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR])) |
| 240 | |
| 241 | # Do initial loads |
| 242 | for vdst, saddr_lo in INIT_PREFETCH: |
| 243 | k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1])) |
| 244 | for iter in range(6): |
| 245 | vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter] |
| 246 | k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1])) |
| 247 | k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1])) |
| 248 | |
| 249 | # =========================================================================== |
| 250 | # LDS store address computation (bank-conflict-avoiding swizzle) |
| 251 | # =========================================================================== |
| 252 | # This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts. |
| 253 | # The swizzle ensures that threads in the same wavefront write to different LDS banks. |
| 254 | # Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset |
| 255 | # where swizzle_offset depends on (lane_id >> 3) to distribute across banks. |
| 256 | k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base |
searching dependent graphs…