| 49 | |
| 50 | |
class QuantizedRowParallelLinear(RowParallelLinear):
    """Row-parallel linear layer that stores an integer-quantized weight.

    The ``weight`` created by ``RowParallelLinear.__init__`` is only used for
    its shape (and, as a fallback, its device/dtype); it is then deleted and
    replaced by an ``int8`` tensor together with a per-row ``weight_scale``.
    Both are registered as parameters with ``requires_grad=False``.
    """

    def __init__(self, weight_bit_width: int, weight=None, *args, **kwargs):
        """Create the layer and quantize (or allocate) its weight.

        Args:
            weight_bit_width: Bits per quantized weight value. The value 4
                additionally routes the quantized weight through
                ``compress_int4_weight``.
            weight: Optional floating-point weight to quantize. When ``None``,
                uninitialized storage with the quantized shape is allocated
                instead (presumably to be filled later, e.g. from a
                checkpoint -- TODO confirm).
            *args: Forwarded unchanged to ``RowParallelLinear.__init__``.
            **kwargs: Forwarded unchanged to ``RowParallelLinear.__init__``.
                ``device`` and ``params_dtype`` are also consumed here when
                given as keywords; if absent, the device and dtype of the
                weight built by the parent class are used.
        """
        super().__init__(*args, **kwargs)
        self.weight_bit_width = weight_bit_width

        shape = self.weight.shape
        # Fall back to what the parent actually built instead of raising
        # KeyError when the caller did not pass these as keywords.
        device = kwargs.get("device", self.weight.device)
        params_dtype = kwargs.get("params_dtype", self.weight.dtype)
        del self.weight

        if weight is None:
            # Each row needs weight_bit_width / 8 bytes per original column.
            stored_cols = shape[1] * weight_bit_width // 8
            quantized = torch.empty(shape[0], stored_cols, dtype=torch.int8, device=device)
            scale = torch.empty(shape[0], dtype=params_dtype, device=device)
        else:
            # Symmetric per-row scale: the row's largest magnitude maps to the
            # largest positive signed integer of the requested width.
            # NOTE(review): the scale is always cast to half here, regardless
            # of params_dtype -- confirm this is what W8A16Linear expects.
            # NOTE(review): an all-zero row yields a zero scale and a 0/0
            # division below -- confirm inputs never contain such rows.
            max_level = 2 ** (weight_bit_width - 1) - 1
            scale = (weight.abs().max(dim=-1).values / max_level).half()
            quantized = torch.round(weight / scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                quantized = compress_int4_weight(quantized)

        self.weight = Parameter(quantized.to(device), requires_grad=False)
        self.weight_scale = Parameter(scale.to(device), requires_grad=False)

    def forward(self, input_):
        """Apply the quantized linear transform and reduce across partitions.

        Args:
            input_: Input tensor. Used as-is when ``self.input_is_parallel``
                is true; otherwise it is first passed through
                ``scatter_to_model_parallel_region``.

        Returns:
            The result of ``reduce_from_model_parallel_region`` applied to the
            local matmul output, plus ``self.bias`` when a bias is present.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)

        # Local matrix multiply against the quantized weight.
        output_parallel = W8A16Linear.apply(
            input_parallel, self.weight, self.weight_scale, self.weight_bit_width
        )

        # All-reduce across all the partitions; bias is added after the reduce.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is None:
            return output_
        return output_ + self.bias