MCPcopy
hub / github.com/tinygrad/tinygrad / custom_handwritten

Function custom_handwritten

test/amd/test_custom_kernel.py:102–146  ·  view source on GitHub ↗
(A:UOp)

Source from the content-addressed store, hash-verified

100 return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
101
def custom_handwritten(A:UOp) -> UOp:
  """Assemble a hand-written AMD kernel and wrap it in a PROGRAM UOp.

  The kernel body is a counted loop whose contents depend on which pipes are
  selected. Setting the PIPE environment variable to a non-empty value selects
  only that pipe; otherwise SALU, VALU, TRANSCENDENTAL and WMMA sequences are
  all emitted.

  Args:
    A: input UOp; it is flattened and its base is attached to the sink.

  Returns:
    A UOp(Ops.PROGRAM) whose sources are the sink, an "AMD" DEVICE UOp and a
    LINEAR UOp holding one INS UOp per finalized instruction.
  """
  flat = A.flatten()
  local_idx = UOp.special(128, "lidx0")
  group_idx = UOp.special(1, "gidx0")
  # 128 * 4 bytes
  local_mem = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds')

  # a non-empty PIPE restricts emission to that single pipe
  requested = getenv("PIPE", "")
  enabled = {requested} if requested else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}

  kern = Kernel()
  emit = kern.emit

  # the body is wrapped in a loop to filter out icache misses
  loop_count, unroll = 8, 5
  emit(r4.s_mov_b32(s[1], loop_count))
  kern.label("loop")

  if "SALU" in enabled:
    for n in range(unroll):
      emit(r4.s_mov_b32(s[20+n], n))
      emit(r4.s_min_i32(s[30+n], n))
      emit(r4.s_mov_b32(s[40+n], n))
      emit(r4.s_mul_i32(s[14+n], s[12+n], 32))

  if "VALU" in enabled:
    for n in range(unroll):
      emit(r4.v_mov_b32_e32(v[20+n], n))
      emit(r4.v_lshlrev_b64_e32(v[30+2*n:31+2*n], 2, v[12+n:13+n]))
      emit(r4.v_mad_co_u64_u32(v[40+2*n:41+2*n], NULL, v[12+n], v[13+n], v[14+n:15+n]))

  if "TRANSCENDENTAL" in enabled:
    # transcendental VALU runs on the TFU, which can run regular VALU at the same time
    for n in range(unroll):
      emit(r4.v_mov_b32_e32(v[20+n], n))
      emit(r4.v_s_rcp_f32(s[20+n], s[12+n]))
      emit(r4.v_rcp_f32_e32(v[30+n], v[12+n]))
      emit(r4.v_s_exp_f32(s[30+n], s[12+n]))

  if "WMMA" in enabled:
    first_reg = 30
    for n in range(unroll):
      # each unrolled step gets its own 40-register window starting at first_reg
      window = first_reg + n*40
      # f16 WMMA: A at window, B at +4, accumulator/destination at +8
      src_a = window
      src_b, acc = src_a + 4, src_a + 8
      emit(r4.v_wmma_f32_16x16x16_f16(v[acc:acc+7], v[src_a:src_a+3], v[src_b:src_b+3], v[acc:acc+7]))
      # iu8 WMMA: A at window+16, B at +2, accumulator/destination at +4
      src_a = window + 16
      src_b, acc = src_a + 2, src_a + 4
      emit(r4.v_wmma_i32_16x16x16_iu8(v[acc:acc+7], v[src_a:src_a+1], v[src_b:src_b+1], v[acc:acc+7]))

  # loop back-edge: decrement s[1], compare it against 0, s_cbranch_scc0 targets "loop"
  emit(r4.s_add_co_i32(s[1], s[1], -1))
  emit(r4.s_cmp_eq_i32(s[1], 0))
  emit(r4.s_cbranch_scc0(), target="loop")
  emit(r4.s_endpgm())

  finalized = kern.finalize()
  sink = UOp.sink(flat.base, local_idx, group_idx, local_mem, arg=KernelInfo("custom_handwritten"))
  device = UOp(Ops.DEVICE, arg="AMD")
  linear = UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=inst) for inst in finalized))
  return UOp(Ops.PROGRAM, src=(sink, device, linear))
147
148def custom_data_deps(A:UOp) -> UOp:
149 A = A.flatten()

Callers

nothing calls this directly

Calls 11

emitMethod · 0.95
labelMethod · 0.95
finalizeMethod · 0.95
UOpClass · 0.90
getenvFunction · 0.90
KernelClass · 0.90
KernelInfoClass · 0.90
flattenMethod · 0.80
specialMethod · 0.80
ptrMethod · 0.45
sinkMethod · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…