MCPcopy Index your code
hub / github.com/tinygrad/tinygrad / program

Method program

extra/hcq2/ops_amd2.py:101–133  ·  view source on GitHub ↗
(self, x)

Source from the content-addressed store, hash-verified

99 self.pm4.int_sel__mec_release_mem__none)
100
def program(self, x):
    """Emit the register writes and packets for one compute dispatch.

    Everything is issued through ``self.acquire_mem``, ``self.wreg`` and
    ``self.pkt3`` in a fixed order; nothing is returned explicitly.

    Args:
        x: object whose ``arg`` unpacks to ``(data, info)`` and whose ``src``
            unpacks to ``(lib_gpu, args)``. ``data`` provides the entry-point
            offset, the ``rsrc1``/``rsrc2``/``rsrc3`` words and the
            ``enable_*`` / ``wave32`` flags read below; ``info`` provides
            ``global_size`` and an optional ``local_size``.
    """
    data, info = x.arg
    lib_gpu, args = x.src

    # Entry address = device address of the loaded library + entry offset.
    entry_addr = self.get_dev_addr(lib_gpu) + data.entry_point_offset

    self.acquire_mem(gli=0, gl2=0)

    kernargs_addr = self.get_dev_addr(args)
    dev = self.dev

    # Values written starting at COMPUTE_USER_DATA_0, in this order:
    # optional 4 scratch words, optional dispatch pointer, kernargs pointer.
    sgprs = []
    if data.enable_private_segment_sgpr:
        scratch_words = data64_le(dev.scratch.va_addr)
        # Bit 31 of the high word is forced on.
        # 0x20c14000 == 2 << 28 | 1 << 23 | 2 << 21 | 0x14 << 12.
        # NOTE(review): presumably a scratch buffer descriptor; the field
        # meanings are not visible here - confirm against the ISA docs.
        sgprs.extend((scratch_words[0], scratch_words[1] | (1 << 31), 0xffffffff, 0x20c14000))
    if data.enable_dispatch_ptr:
        # The dispatch pointer is placed kernargs_segment_size past the kernargs base.
        sgprs.extend(data64_le(kernargs_addr + data.kernargs_segment_size))
    sgprs.extend(data64_le(kernargs_addr))

    gc = self.gc
    # The program address is written shifted right by 8 (256-byte units).
    self.wreg(gc.regCOMPUTE_PGM_LO, *data64_le(entry_addr >> 8))
    self.wreg(gc.regCOMPUTE_PGM_RSRC1, data.rsrc1, data.rsrc2)
    self.wreg(gc.regCOMPUTE_PGM_RSRC3, data.rsrc3)
    self.wreg(gc.regCOMPUTE_TMPRING_SIZE, dev.tmpring_size)

    # The scratch allocation is split into equal slices, one per XCC.
    # NOTE(review): every iteration writes the same register and no per-XCC
    # selection is visible in this method - confirm the target is chosen elsewhere.
    for xcc_id in range(dev.xccs):
        xcc_scratch = dev.scratch.va_addr + (dev.scratch.size // dev.xccs * xcc_id)
        self.wreg(gc.regCOMPUTE_DISPATCH_SCRATCH_BASE_LO, *data64_le(xcc_scratch >> 8))

    self.wreg(gc.regCOMPUTE_RESTART_X, 0, 0, 0)
    self.wreg(gc.regCOMPUTE_USER_DATA_0, *sgprs)

    limits_reg = gc.regCOMPUTE_RESOURCE_LIMITS
    self.wreg(limits_reg, limits_reg.encode(waves_per_sh=getenv("WAVES_PER_SH")))

    # A missing/empty local size falls back to a 1x1x1 workgroup.
    workgroup = info.local_size or (1, 1, 1)
    self.wreg(gc.regCOMPUTE_START_X, 0, 0, 0, *workgroup, 0, 0)

    # cs_w32_en is only passed when the target's major version is not 9.
    initiator_fields = {}
    if dev.target[0] != 9:
        initiator_fields['cs_w32_en'] = int(data.wave32)
    initiator_fields.update(force_start_at_000=1, compute_shader_en=1)
    dispatch_init = gc.regCOMPUTE_DISPATCH_INITIATOR.encode(**initiator_fields)

    pm4 = self.pm4
    self.pkt3(pm4.PACKET3_DISPATCH_DIRECT, *info.global_size, dispatch_init)

    # Follow the dispatch with a CS_PARTIAL_FLUSH event write.
    flush_event = pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH)
    self.pkt3(pm4.PACKET3_EVENT_WRITE, flush_event)
134
135amd_inner_pm = PatternMatcher([
136 (UPat(Ops.LINEAR, src=(UPat(Ops.WAIT, name="x"),)), lambda ctx, x: ctx.wait(x)),

Callers 1

ops_amd2.pyFile · 0.80

Calls 7

acquire_memMethod · 0.95
wregMethod · 0.95
pkt3Method · 0.95
data64_leFunction · 0.90
getenvFunction · 0.90
get_dev_addrMethod · 0.80
encodeMethod · 0.45

Tested by

no test coverage detected