(model, quant_config: QuantConfig)
| 537 | |
| 538 | |
def fp4_quantize(model, quant_config: QuantConfig):
    """Swap the model's linear layers for their FP4 counterparts.

    The layer-class mapping built here is handed to ``quantize_layers``
    together with ``model`` and ``quant_config``; whatever that helper
    returns is returned unchanged.

    Args:
        model: The model passed through to ``quantize_layers``.
        quant_config: Quantization settings. Its ``quant_mode`` must
            report ``has_nvfp4()``.

    Returns:
        The result of ``quantize_layers`` for the FP4 layer mapping.

    Raises:
        AssertionError: If ``quant_config.quant_mode.has_nvfp4()`` is falsy.
    """
    assert quant_config.quant_mode.has_nvfp4()

    # Source layer class -> replacement class. MixtureOfExperts is mapped
    # to itself (presumably so the helper still visits it and the layer
    # handles its own quantization -- confirm against quantize_layers).
    layer_replacements = dict((
        (ColumnLinear, FP4Linear),
        (RowLinear, FP4RowLinear),
        (MixtureOfExperts, MixtureOfExperts),
    ))

    return quantize_layers(model, quant_config, layer_replacements)
| 553 | |
| 554 | |
| 555 | # For now, assume the KV cache is enabled for all layers. |