```python
# Relies on module-level names defined or imported earlier in this file:
# the settings `gemm_impl` and `block_size`, and the kernels `act_quant`,
# `weight_dequant`, and `fp8_gemm`.
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T + b, where A is `weight`.
    The computation path depends on whether `weight` is quantized and on the
    module-level `gemm_impl` setting.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor of shape (out_features, in_features).
            If it is FP8-quantized (one byte per element), it must carry a `scale`
            attribute holding its dequantization scales.
        bias (Optional[torch.Tensor]): The bias tensor to add. Defaults to None.

    Returns:
        torch.Tensor: The result of the linear transformation. On the FP8 path it is
            computed from quantized inputs, so it can differ slightly from the
            full-precision result.

    Notes:
        - If `weight.element_size() > 1` (the weight is not quantized), `F.linear`
          is applied directly.
        - If `weight` is quantized (`element_size() == 1`) and `gemm_impl == "bf16"`,
          the weight is dequantized with `weight_dequant` and `F.linear` is applied
          to the result.
        - Otherwise, `x` is quantized block-wise with `act_quant`, the product is
          computed by `fp8_gemm`, and `bias`, if given, is added afterwards.
    """
    if weight.element_size() > 1:
        # Unquantized weight: use the standard PyTorch linear op.
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
        # Quantized weight, high-precision GEMM: dequantize first.
        weight = weight_dequant(weight, weight.scale)
        return F.linear(x, weight, bias)
    else:
        # Quantized weight, FP8 GEMM: quantize activations per block of `block_size`.
        x, scale = act_quant(x, block_size)
        y = fp8_gemm(x, scale, weight, weight.scale)
        if bias is not None:
            y += bias
        return y
```
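The FP8 branch works because per-block scales factor out of a dot product: if each block of `x` and each tile of `weight` is divided by its own scale, the GEMM can sum the scaled values block by block and multiply each partial sum by the two matching scales. The pure-Python sketch below illustrates that algebra. It is not the module's kernel code: `act_quant_sketch`, `weight_quant_sketch`, `block_scaled_linear`, and `linear_ref` are hypothetical names, the scale rule (block max divided by 448, the largest finite float8 e4m3 value) is an assumption, per-tile weight scales are an assumed layout, and the actual cast to float8 is omitted, so results match the plain linear up to floating-point error.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]

# Assumption: scale each block so its largest |value| maps to 448.0,
# the largest finite float8 e4m3 value.
FP8_E4M3_MAX = 448.0


def linear_ref(x: Matrix, w: Matrix, b: Optional[List[float]] = None) -> Matrix:
    """Plain y = x @ w^T + b, with x of shape (M, K) and w of shape (N, K)."""
    out = []
    for row in x:
        y_row = [sum(xv * wv for xv, wv in zip(row, w_row)) for w_row in w]
        if b is not None:
            y_row = [yv + bv for yv, bv in zip(y_row, b)]
        out.append(y_row)
    return out


def act_quant_sketch(row: List[float], block_size: int) -> Tuple[List[float], List[float]]:
    """Rescale each run of `block_size` values; return scaled values and one scale per block.

    A real kernel would also cast the scaled values to float8; that rounding is omitted.
    """
    scaled, scales = [], []
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        amax = max(abs(v) for v in block)
        s = amax / FP8_E4M3_MAX if amax > 0 else 1.0  # all-zero block: any scale works
        scales.append(s)
        scaled.extend(v / s for v in block)
    return scaled, scales


def weight_quant_sketch(w: Matrix, block_size: int) -> Tuple[Matrix, Matrix]:
    """Rescale w tile by tile, keeping one scale per (block_size x block_size) tile."""
    n, k = len(w), len(w[0])
    scaled = [[0.0] * k for _ in range(n)]
    scales = []
    for r0 in range(0, n, block_size):
        rows = range(r0, min(r0 + block_size, n))
        tile_scales = []
        for c0 in range(0, k, block_size):
            cols = range(c0, min(c0 + block_size, k))
            amax = max(abs(w[r][c]) for r in rows for c in cols)
            s = amax / FP8_E4M3_MAX if amax > 0 else 1.0
            tile_scales.append(s)
            for r in rows:
                for c in cols:
                    scaled[r][c] = w[r][c] / s
        scales.append(tile_scales)
    return scaled, scales


def block_scaled_linear(x: Matrix, w: Matrix, b: Optional[List[float]] = None,
                        block_size: int = 2) -> Matrix:
    """Sum scaled values per K-block, then multiply each partial sum by both scales."""
    wq, ws = weight_quant_sketch(w, block_size)
    out = []
    for row in x:
        xq, xs = act_quant_sketch(row, block_size)
        y_row = []
        for j, w_row in enumerate(wq):
            acc = 0.0
            for kb, s_x in enumerate(xs):
                cols = range(kb * block_size, min((kb + 1) * block_size, len(row)))
                partial = sum(xq[c] * w_row[c] for c in cols)
                acc += partial * s_x * ws[j // block_size][kb]
            if b is not None:
                acc += b[j]  # bias is added after the GEMM, as in `linear`
            y_row.append(acc)
        out.append(y_row)
    return out
```

A `block_size` of 2 keeps the example readable; the value of the module-level `block_size` used by `linear` is not shown in this excerpt.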
class Linear(nn.Module):