(weights, config, args)
| 11 | |
| 12 | |
def quantize(weights, config, args):
    """Quantize a model's weights and record the settings used.

    A model is instantiated from ``config``, populated with ``weights``,
    and quantized in place via ``nn.quantize`` using the group size and
    bit width taken from ``args``.

    Args:
        weights: Mapping of parameter names to arrays; its ``items()`` are
            passed to ``model.load_weights``.
        config: Model configuration dict, used to build ``models.ModelArgs``.
            It is deep-copied up front and only the copy receives the
            ``"quantization"`` entry.
        args: Object exposing ``q_group_size`` and ``q_bits`` attributes.

    Returns:
        A ``(quantized_weights, quantized_config)`` tuple: the flattened
        parameters of the quantized model as a dict, and the copied config
        with a ``"quantization"`` entry holding ``group_size`` and ``bits``.
    """
    # Snapshot the config first so the extra key is written to the copy.
    new_config = copy.deepcopy(config)

    # Build the model from the config and populate it with the weights.
    model_args = models.ModelArgs.from_dict(config)
    model = models.Model(model_args)
    model.load_weights(list(weights.items()))

    # Quantize the model's layers in place.
    group_size, bits = args.q_group_size, args.q_bits
    nn.quantize(model, group_size, bits)

    # Record the quantization settings alongside the model configuration.
    new_config["quantization"] = {"group_size": group_size, "bits": bits}

    flat_params = tree_flatten(model.parameters())
    return dict(flat_params), new_config
| 35 | |
| 36 | |
| 37 | if __name__ == "__main__": |
no test coverage detected