Repository: github.com/deepseek-ai/DeepSeek-V3

Function fp8_gemm

inference/kernel.py:170–191

Perform a matrix multiplication using FP8 precision. Each FP8 input matrix is paired with a scaling-factor tensor, all four inputs must be contiguous, and the result is returned in PyTorch's default floating-point dtype.

(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """
    Perform a matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c

Callers 1

linear (function) · 0.90

Calls

fp8_gemm_kernel (Triton kernel launch), triton.cdiv

Tested by

no test coverage detected