(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:tuple[sint, ...], local_size:tuple[sint, ...])
| 318 | return self |
| 319 | |
def exec(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
  """Record a launch of `prg` into this queue and return self so calls can be chained.

  Emission order is significant and is, in sequence:
    kernel-arg binding, acquire_mem, optional SQTT setup, program/RSRC/TMPRING registers,
    per-XCC scratch base, RESTART/USER_DATA/RESOURCE_LIMITS/START registers, DISPATCH_DIRECT,
    an optional SQTT marker event, and finally a CS_PARTIAL_FLUSH event write.

  The user-data SGPR words written starting at COMPUTE_USER_DATA_0 are, in order:
    * 4 words of scratch (private segment) data, only if prg.enable_private_segment_sgpr;
    * 2 words holding the address of an hsa_kernel_dispatch_packet_t placed directly after the
      kernarg segment in args_state.buf, only if prg.enable_dispatch_ptr;
    * 2 words holding the kernarg buffer address (always).

  global_size is passed as-is to DISPATCH_DIRECT. local_size is written after three zero words
  starting at COMPUTE_START_X. The dispatch packet's grid_size_* fields receive the elementwise
  product global_size * local_size.
  """
  self.bind_args_state(args_state)
  self.acquire_mem(gli=0, gl2=0)

  user_sgprs = []
  if prg.enable_private_segment_sgpr:
    assert self.dev.xccs == 1, "Only architected flat scratch is supported on multi-xcc"
    scratch_words = data64_le(prg.dev.scratch.va_addr)
    # word0/word1: scratch address halves; setting bit31 of word1 enables swizzle.
    # word2 is all ones. word3 is spelled out field-by-field (== 0x20c14000).
    user_sgprs.extend([
      scratch_words[0],
      scratch_words[1] | (1 << 31),
      0xffffffff,
      (0x14 << 12) | (2 << 28) | (2 << 21) | (1 << 23),
    ])

  if prg.enable_dispatch_ptr:
    packet_t = hsa.hsa_kernel_dispatch_packet_t
    # The dispatch packet shares args_state.buf and sits right after the kernarg segment.
    packet_buf = args_state.buf.offset(prg.kernargs_segment_size)
    packet = packet_t.from_address(int(packet_buf.va_addr))

    self.bind_sints(*local_size, mem=packet_buf.cpu_view(), struct_t=packet_t, start_field='workgroup_size_x', fmt='H')
    grid_dims = [gdim * ldim for gdim, ldim in zip(global_size, local_size)]
    self.bind_sints(*grid_dims, mem=packet_buf.cpu_view(), struct_t=packet_t, start_field='grid_size_x', fmt='I')
    packet.group_segment_size = prg.group_segment_size
    packet.private_segment_size = prg.private_segment_size
    packet.kernarg_address = cast(ctypes.c_void_p, args_state.buf.va_addr)
    user_sgprs.extend(data64_le(packet_buf.va_addr))

  user_sgprs.extend(data64_le(args_state.buf.va_addr))

  if prg.dev.sqtt_enabled:
    self.sqtt_setup_exec(prg, global_size)

  gc = self.gc
  # Program address is written shifted right by 8, split across PGM_LO/PGM_HI.
  self.wreg(gc.regCOMPUTE_PGM_LO, *data64_le(prg.prog_addr >> 8))
  self.wreg(gc.regCOMPUTE_PGM_RSRC1, prg.rsrc1, prg.rsrc2)
  self.wreg(gc.regCOMPUTE_PGM_RSRC3, prg.rsrc3)
  self.wreg(gc.regCOMPUTE_TMPRING_SIZE, prg.dev.tmpring_size)

  # this is what llvm refers to as "architected flat scratch": the scratch allocation is split into
  # equal per-XCC slices, and each slice base is written under a one-hot xcc_mask pred_exec.
  for xcc in range(self.dev.xccs):
    with self.pred_exec(xcc_mask=1 << xcc):
      slice_base = prg.dev.scratch.va_addr + (prg.dev.scratch.size // self.dev.xccs) * xcc
      self.wreg(gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(slice_base >> 8))

  self.wreg(gc.regCOMPUTE_RESTART_X, 0, 0, 0)
  self.wreg(gc.regCOMPUTE_USER_DATA_0, *user_sgprs)
  self.wreg(gc.regCOMPUTE_RESOURCE_LIMITS, waves_per_sh=getenv("WAVES_PER_SH"))
  self.wreg(gc.regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0)

  # cs_w32_en is only passed when target[0] != 9 (NOTE(review): presumably that field does not exist there).
  wave_fields = {} if prg.dev.target[0] == 9 else {'cs_w32_en': int(prg.wave32)}
  initiator = gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**wave_fields, force_start_at_000=1, compute_shader_en=1)
  self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *global_size, initiator)

  pm4 = self.pm4
  if prg.dev.sqtt_enabled:
    self.pkt3(pm4.PACKET3_EVENT_WRITE, pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | pm4.EVENT_INDEX(0))
  self.pkt3(pm4.PACKET3_EVENT_WRITE, pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
  return self
| 369 | |
def wait(self, signal:AMDSignal, value:sint=0):
  """Append a wait on `signal` to this queue.

  Delegates to wait_reg_mem on the signal's value address with the given `value` and a full
  32-bit mask (0xffffffff). The comparison performed is whatever wait_reg_mem defaults to.
  Returns wait_reg_mem's result unchanged.
  """
  all_bits = 0xffffffff
  return self.wait_reg_mem(value=value, mask=all_bits, mem=signal.value_addr)
| 371 |
nothing calls this directly
no test coverage detected