(q, k, v, kvScale, headElems)
| 154 | |
| 155 | |
def compute(q, k, v, kvScale, headElems):
    """Compute a reference scaled-dot-product attention with a K/V scale factor.

    Evaluates ``softmax(q @ k.T * headElems**-0.5 * kvScale) @ v * kvScale``
    row by row. It subtracts the per-row maximum before exponentiating so the
    result is numerically stable.

    Args:
        q: 2-D query array of shape (n, d).
        k: 2-D key array of shape (m, d).
        v: 2-D value array of shape (m, dv).
        kvScale: Scalar applied to both K (folded into the logit scale) and
            V (folded into the output scale). NOTE(review): presumably a
            KV-cache quantization scale; confirm with callers.
        headElems: Used as ``d`` in the ``d**-0.5`` logit scaling (presumably
            the attention head dimension).

    Returns:
        Tuple ``(out, row_max, row_sum)``:
        - ``out``: the attention output, shape (n, dv).
        - ``row_max``: per-row maximum of the scaled logits, an (n, 1) column.
        - ``row_sum``: per-row sum of ``exp(logits - row_max)``, an (n, 1)
          column.
    """
    qkScale = (headElems**-0.5) * kvScale
    qk = q @ k.T * qkScale
    # Subtract the per-row max before exp() so large logits cannot overflow.
    row_max = np.max(qk, axis=1, keepdims=True)
    x = np.exp(qk - row_max)
    row_sum = np.sum(x, axis=1, keepdims=True)
    # Normalize after the matmul. Dividing the (n, dv) product by row_sum is
    # equivalent to normalizing the (n, m) weights first, and it is cheaper
    # when dv < m. The extra kvScale factor applies V's scale.
    out = x @ v * (kvScale / row_sum)
    return out, row_max, row_sum