Replace fp16 linear with quantized linear
(model, weight_bit_width)
| 5 | |
| 6 | |
def quantize(model, weight_bit_width):
    """Replace fp16 linear with quantized linear.

    For every layer in ``model.transformer.layers`` the four projections
    ``attention.query_key_value``, ``attention.dense``,
    ``mlp.dense_h_to_4h`` and ``mlp.dense_4h_to_h`` are swapped for
    quantized parallel linear modules built from the existing weights.

    Args:
        model: Object exposing ``transformer.layers``; each layer must
            provide the four projections listed above, and each projection
            must expose ``weight``, ``input_size`` and ``output_size``.
        weight_bit_width: Forwarded unchanged to every quantized layer as
            its ``weight_bit_width`` argument.

    Returns:
        The same ``model`` object, modified in place.

    Note:
        ``torch.distributed.get_rank()`` and ``torch.cuda.current_device()``
        are called unconditionally, so this presumably requires an
        initialized process group and an available CUDA device -- confirm
        against the caller.
    """

    # Only rank 0 logs, so multi-process runs do not print duplicates.
    if torch.distributed.get_rank() == 0:
        print(f"> Quantizing model weight to {weight_bit_width} bits")

    for layer in model.transformer.layers:
        attention = layer.attention
        mlp = layer.mlp

        # In every replacement below the source weight is moved to the
        # current CUDA device before being handed to the quantized layer,
        # while ``device`` is the device the original weight lives on.
        attention.query_key_value = QuantizedColumnParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=attention.query_key_value.weight.to(torch.cuda.current_device()),
            input_size=attention.query_key_value.input_size,
            output_size=attention.query_key_value.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="query_key_value",
            skip_init=True,
            device=attention.query_key_value.weight.device,
        )
        attention.dense = QuantizedRowParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=attention.dense.weight.to(torch.cuda.current_device()),
            input_size=attention.dense.input_size,
            output_size=attention.dense.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            name="dense",
            skip_init=True,
            device=attention.dense.weight.device,
        )
        mlp.dense_h_to_4h = QuantizedColumnParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            input_size=mlp.dense_h_to_4h.input_size,
            output_size=mlp.dense_h_to_4h.output_size,
            bias=True,
            gather_output=False,
            params_dtype=torch.half,
            name="dense_h_to_4h",
            skip_init=True,
            device=mlp.dense_h_to_4h.weight.device,
        )
        mlp.dense_4h_to_h = QuantizedRowParallelLinear(
            weight_bit_width=weight_bit_width,
            weight=mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            input_size=mlp.dense_4h_to_h.input_size,
            output_size=mlp.dense_4h_to_h.output_size,
            bias=True,
            input_is_parallel=True,
            params_dtype=torch.half,
            # Was "dense_h_to_4h" (copy-paste from the previous block); the
            # name now matches the attribute, as for the other three layers.
            name="dense_4h_to_h",
            skip_init=True,
            device=mlp.dense_4h_to_h.weight.device,
        )

    return model
no test coverage detected