MCPcopy
hub / github.com/tinygrad/tinygrad / exec

Method exec

tinygrad/runtime/ops_amd.py:320–368  ·  view source on GitHub ↗
(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:tuple[sint, ...], local_size:tuple[sint, ...])

Source from the content-addressed store, hash-verified

318 return self
319
def exec(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:tuple[sint, ...], local_size:tuple[sint, ...]):
  """
  Emit the register writes and packets that dispatch `prg`.

  Binds `args_state`, assembles the values written starting at COMPUTE_USER_DATA_0 (optional private-segment words,
  optional dispatch-packet address, then the kernarg buffer address), programs the COMPUTE_* registers, and issues
  PACKET3_DISPATCH_DIRECT with `global_size`, followed by a CS_PARTIAL_FLUSH event. Returns `self`.
  """
  self.bind_args_state(args_state)
  self.acquire_mem(gli=0, gl2=0)

  # values destined for COMPUTE_USER_DATA_0 onwards; the order in which they are appended matters
  sgpr_words = []
  if prg.enable_private_segment_sgpr:
    assert self.dev.xccs == 1, "Only architected flat scratch is supported on multi-xcc"
    scratch_words = data64_le(prg.dev.scratch.va_addr)
    # sgpr word1 bit31 enables swizzle
    # sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23
    sgpr_words = [scratch_words[0], scratch_words[1] | (1 << 31), 0xffffffff, 0x20c14000]

  if prg.enable_dispatch_ptr:
    # the dispatch packet is placed in the args buffer at offset kernargs_segment_size
    pkt_t = hsa.hsa_kernel_dispatch_packet_t
    pkt_buf = args_state.buf.offset(prg.kernargs_segment_size)
    pkt = pkt_t.from_address(int(pkt_buf.va_addr))

    self.bind_sints(*local_size, mem=pkt_buf.cpu_view(), struct_t=pkt_t, start_field='workgroup_size_x', fmt='H')
    # grid size is the per-dimension product of global_size and local_size
    grid_dims = [g*l for g,l in zip(global_size, local_size)]
    self.bind_sints(*grid_dims, mem=pkt_buf.cpu_view(), struct_t=pkt_t, start_field='grid_size_x', fmt='I')
    pkt.group_segment_size = prg.group_segment_size
    pkt.private_segment_size = prg.private_segment_size
    pkt.kernarg_address = cast(ctypes.c_void_p, args_state.buf.va_addr)
    sgpr_words.extend(data64_le(pkt_buf.va_addr))

  # the kernarg buffer address always comes last
  sgpr_words.extend(data64_le(args_state.buf.va_addr))

  if prg.dev.sqtt_enabled:
    self.sqtt_setup_exec(prg, global_size)

  self.wreg(self.gc.regCOMPUTE_PGM_LO, *data64_le(prg.prog_addr >> 8))
  self.wreg(self.gc.regCOMPUTE_PGM_RSRC1, prg.rsrc1, prg.rsrc2)
  self.wreg(self.gc.regCOMPUTE_PGM_RSRC3, prg.rsrc3)
  self.wreg(self.gc.regCOMPUTE_TMPRING_SIZE, prg.dev.tmpring_size)

  # this is what llvm refers to as "architected flat scratch"
  # each xcc gets its own equal slice of the scratch buffer, written under a predicate selecting only that xcc
  for xcc_id in range(self.dev.xccs):
    with self.pred_exec(xcc_mask=1<<xcc_id):
      xcc_slice = prg.dev.scratch.size // self.dev.xccs
      xcc_scratch_addr = prg.dev.scratch.va_addr + (xcc_slice * xcc_id)
      self.wreg(self.gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(xcc_scratch_addr >> 8))

  self.wreg(self.gc.regCOMPUTE_RESTART_X, 0, 0, 0)
  self.wreg(self.gc.regCOMPUTE_USER_DATA_0, *sgpr_words)
  self.wreg(self.gc.regCOMPUTE_RESOURCE_LIMITS, waves_per_sh=getenv("WAVES_PER_SH"))
  self.wreg(self.gc.regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0)

  # cs_w32_en is only passed when the first element of the device target is not 9
  initiator_fields = {'cs_w32_en': int(prg.wave32)} if prg.dev.target[0] != 9 else {}
  initiator_fields.update(force_start_at_000=1, compute_shader_en=1)
  dispatch_initiator = self.gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**initiator_fields)
  self.pkt3(self.pm4.PACKET3_DISPATCH_DIRECT, *global_size, dispatch_initiator)

  if prg.dev.sqtt_enabled:
    trace_marker = self.pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | self.pm4.EVENT_INDEX(0)
    self.pkt3(self.pm4.PACKET3_EVENT_WRITE, trace_marker)

  partial_flush = self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH)
  self.pkt3(self.pm4.PACKET3_EVENT_WRITE, partial_flush)
  return self
369
def wait(self, signal:AMDSignal, value:sint=0):
  """Delegate to `wait_reg_mem` on `signal.value_addr` with `value` and a 0xffffffff mask; returns its result."""
  full_mask = 0xffffffff
  return self.wait_reg_mem(mem=signal.value_addr, value=value, mask=full_mask)
371

Callers

nothing calls this directly

Calls 13

acquire_mem · Method · 0.95
sqtt_setup_exec · Method · 0.95
wreg · Method · 0.95
pred_exec · Method · 0.95
pkt3 · Method · 0.95
data64_le · Function · 0.90
getenv · Function · 0.90
cast · Function · 0.85
bind_args_state · Method · 0.80
bind_sints · Method · 0.80
offset · Method · 0.45
cpu_view · Method · 0.45

Tested by

no test coverage detected