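Perform a matrix multiplication using FP8 precision. Args: a (torch.Tensor): The first input matrix, must be contiguous. a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. b (torch.Tensor): The second input matrix, must be contiguous. b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. Returns: torch.Tensor: The result of the matrix multiplication.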
(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor)
| 168 | |
| 169 | |
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
    """
    Perform a matrix multiplication using FP8 precision.

    The last dimension of ``a`` is the reduction dimension ``K``; every
    leading dimension of ``a`` is flattened into ``M`` rows. ``N`` is taken
    from the first dimension of ``b``.

    Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous. It must
            hold exactly ``N * K`` elements (presumably laid out as ``(N, K)``;
            the kernel is defined elsewhere, confirm against it).
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication, shaped like
        ``a`` with its last dimension replaced by ``N``, allocated on ``a``'s
        device with dtype ``torch.get_default_dtype()``.

    Raises:
        AssertionError: If any input or scaling-factor tensor is not contiguous.
        ValueError: If ``b`` does not hold ``N * K`` elements.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
    K = a.size(-1)
    # Collapse all leading dimensions of `a` into a single row count.
    M = a.numel() // K
    N = b.size(0)
    # The kernel only receives M, N, K and a raw pointer to `b`; a size mismatch
    # would not be caught downstream, so reject it here. This is an explicit
    # raise (not an assert) so the check is not stripped under `python -O`.
    if b.numel() != N * K:
        raise ValueError(
            f"b must hold N x K = {N} x {K} elements to match the last dimension of a, "
            f"got shape {tuple(b.shape)}"
        )
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())

    def grid(META):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile; the block
        # sizes are supplied by the kernel's launch metadata.
        return (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))

    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c