| 145 | |
| 146 | def test_custom_kernel_three_output_backward(self): |
| 147 | def addmulsub_kernel(C:UOp, D:UOp, E:UOp, A:UOp, B:UOp) -> UOp:  # builds one kernel writing three outputs (C, D, E) from two inputs (A, B) |
| 148 | C, D, E, A, B = C.flatten(), D.flatten(), E.flatten(), A.flatten(), B.flatten()  # flatten every buffer so a single 1-D index addresses all of them |
| 149 | i = UOp.range(C.numel(), 0)  # loop index sized by C's element count; assumes A, B, D, E have the same numel as C — TODO confirm |
| 150 | store_c = C[i].store(A[i] + B[i])  # C = A + B elementwise |
| 151 | store_d = D[i].store(A[i] * B[i])  # D = A * B elementwise |
| 152 | store_e = E[i].store(A[i] - B[i])  # E = A - B elementwise |
| 153 | return UOp.group(store_c, store_d, store_e).end(i).sink(arg=KernelInfo(name="addmulsub")).simplify()  # group the three stores under the shared loop i, close the range, and sink as a kernel named "addmulsub" |
| 154 | def backward_addmulsub(grad_c, grad_d, grad_e, call): |
| 155 | _c, _d, _e, a, b = call.src[1:] |
| 156 | grad_a = (Tensor(grad_c) + Tensor(grad_d) * Tensor(b) + Tensor(grad_e)).uop |