(C:UOp, D:UOp, A:UOp, B:UOp)
| 107 | class TestMultiOutputGradient(unittest.TestCase): |
| 108 | @staticmethod |
| 109 | def addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp: |
| 110 | C, D, A, B = C.flatten(), D.flatten(), A.flatten(), B.flatten() |
| 111 | i = UOp.range(C.numel(), 0) |
| 112 | store_c = C[i].store(A[i] + B[i]) |
| 113 | store_d = D[i].store(A[i] * B[i]) |
| 114 | return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name="addmul")).simplify() |
| 115 | @staticmethod |
| 116 | def backward_addmul(grad_c, grad_d, call): |
| 117 | _c, _d, a, b = call.src[1:] |