| 296 | |
| 297 | |
def quantize(weights, config, args):
    """Quantize a Whisper model and return its flat weights plus updated config.

    A ``Whisper`` model is constructed from ``config``, populated with
    ``weights``, and passed to ``nn.quantize`` using the group size and bit
    width found on ``args``.

    Args:
        weights: Dict-like object; its values are converted with ``mx.array``
            and its items are handed to ``tree_unflatten`` to rebuild the
            parameter tree.
        config: Keyword arguments for ``ModelDimensions``. This function only
            writes to a deep copy of it.
        args: Object exposing ``q_group_size`` and ``q_bits`` attributes.

    Returns:
        A ``(quantized_weights, quantized_config)`` tuple: the flattened model
        parameters as a dict, and the copied config with an added
        ``"quantization"`` entry holding ``group_size`` and ``bits``.
    """
    # Copy up front so the "quantization" entry is added to the copy,
    # not to the dict the caller passed in.
    new_config = copy.deepcopy(config)

    # Instantiate the model and load the supplied weights into it.
    dims = ModelDimensions(**config)
    model = Whisper(dims)
    mx_weights = tree_map(mx.array, weights)
    param_tree = tree_unflatten(list(mx_weights.items()))
    model.update(param_tree)

    # The return value of nn.quantize is not used; the same `model` object
    # is read back below.
    group_size = args.q_group_size
    bits = args.q_bits
    nn.quantize(model, group_size, bits)

    # Record the quantization settings alongside the rest of the config.
    new_config["quantization"] = dict(group_size=group_size, bits=bits)

    flat_params = dict(tree_flatten(model.parameters()))
    return flat_params, new_config
| 317 | |
| 318 | |
| 319 | if __name__ == "__main__": |