| 12 | |
| 13 | |
class QuantizedColumnParallelLinear(ColumnParallelLinear):
    """Column-parallel linear layer whose weight is stored quantized.

    The weight is quantized symmetrically with one scale per output row
    (``weight_scale``, one entry per row of the original weight). Codes are
    stored as ``torch.int8``; for ``weight_bit_width == 4`` they are further
    passed through ``compress_int4_weight`` (presumably packing two 4-bit
    codes per byte, matching the ``shape[1] * 4 // 8`` column count used when
    no weight is supplied — TODO confirm).
    """

    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        """Build the layer and replace its float weight with a quantized one.

        Args:
            weight_bit_width: Bits per quantized weight value. Only 4 and 8 are
                supported: the buffer sizing and int4 compression below handle
                exactly these two widths.
            weight: Optional full-precision weight to quantize, indexed as
                (output row, input column). If None, uninitialized ``weight``
                and ``weight_scale`` tensors are allocated (presumably filled
                later from a quantized checkpoint — verify against callers).
            *args, **kwargs: Forwarded to ``ColumnParallelLinear``. ``kwargs``
                must contain ``"device"``. It must also contain
                ``"params_dtype"`` when ``weight`` is None.

        Raises:
            ValueError: If ``weight_bit_width`` is not 4 or 8.
        """
        if weight_bit_width not in (4, 8):
            raise ValueError(f"weight_bit_width must be 4 or 8, got {weight_bit_width}")
        super(QuantizedColumnParallelLinear, self).__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        # Keep only the parent's weight shape; the float parameter itself is
        # replaced by the quantized tensor below.
        shape = self.weight.shape
        del self.weight

        if weight is None:
            self.weight = torch.empty(
                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            # Largest magnitude of a symmetric signed code: 127 (8-bit), 7 (4-bit).
            qmax = (2 ** (weight_bit_width - 1)) - 1
            # NOTE(review): the scale is always fp16 here, while the
            # weight=None branch uses params_dtype. Kept as-is because the
            # W8A16Linear kernel may rely on fp16 scales — confirm.
            weight_scale = (weight.abs().max(dim=-1).values / qmax).half()
            # An all-zero row, or a row whose scale underflows to 0 in fp16,
            # would otherwise divide by zero and produce NaN/inf, whose cast to
            # int8 is undefined. With a scale of 1 every value in such a row
            # rounds to code 0.
            self.weight_scale = weight_scale.masked_fill(weight_scale == 0, 1.0)
            # Clamp before the int8 cast. fp16 rounding of the scale (notably
            # subnormal scales) can push |weight / scale| past qmax, and
            # out-of-range values would not fit the signed code range.
            self.weight = (
                torch.round(weight / self.weight_scale[:, None]).clamp(-qmax, qmax).to(torch.int8)
            )
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        # Frozen parameters: quantized weights are not trained.
        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)

    def forward(self, input_):
        """Apply the quantized column-parallel projection.

        ``input_`` goes through ``copy_to_model_parallel_region`` (sets up the
        backward all-reduce). It is then multiplied via ``W8A16Linear`` with the
        quantized weight and per-row scale, and ``self.bias`` is added when
        present. The result is all-gathered across partitions when
        ``self.gather_output`` is set; otherwise the local partition is returned.
        """
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output_parallel = output_parallel + self.bias
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output
| 49 | |
| 50 | |
| 51 | class QuantizedRowParallelLinear(RowParallelLinear): |