(dest:UOp, src:UOp)
| 55 | return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.numel()}", opts_to_apply=())) |
| 56 | |
def slice_sum_kernel(dest:UOp, src:UOp):
  """Build a custom-kernel AST that stores the sum of each row of `src` into `dest`.

  An outer range (axis id 0) walks the src.shape[0] rows. For each row, an inner
  range (axis id 1, AxisType.REDUCE) walks the src.shape[1] columns and accumulates
  into a one-element placeholder allocated with AddrSpace.REG. The accumulated value
  is then written to dest[row].

  NOTE(review): assumes `src` is 2-D and `dest` has at least src.shape[0] elements;
  neither is checked here -- confirm against callers.

  Returns the sink UOp carrying a KernelInfo named "slice_sum_{rows}_{cols}" with an
  empty opts_to_apply tuple.
  """
  rows, cols = src.shape[0], src.shape[1]
  # outer loop over the rows of src
  row = UOp.range(rows, 0)
  src_row = src[row, :]
  # single-slot accumulator in the REG address space, typed like dest's base dtype;
  # presumably `.after(row)` sequences this zero-init inside the row loop so it resets per row
  acc = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG)
  acc = acc.after(row)[0].set(0)
  # inner reduction over the columns; `end=col` presumably closes that range
  col = UOp.range(cols, 1, AxisType.REDUCE)
  acc = acc[0].set(acc.after(col)[0] + src_row[col], end=col)
  # store the row total and close the outer range
  store = dest[row].set(acc[0], end=row)
  info = KernelInfo(name=f"slice_sum_{rows}_{cols}", opts_to_apply=())
  return store.sink(arg=info)
| 66 | |
| 67 | def simple_qkv_kernel(O:UOp, Q:UOp, K:UOp, V:UOp) -> UOp: |
| 68 | # attention without softmax |
nothing calls this directly
no test coverage detected
searching dependent graphs…