Output plan results as JSON.
display_plan_json(model: ModelInfo, context_length: int, target_quant: str) -> None
| 75 | |
| 76 | |
| 77 | def display_plan_json( |
| 78 | model: ModelInfo, |
| 79 | context_length: int, |
| 80 | target_quant: str, |
| 81 | ) -> None: |
| 82 | """Output plan results as JSON.""" |
| 83 | from whichllm.constants import ( |
| 84 | GPU_BANDWIDTH, |
| 85 | QUANT_BYTES_PER_WEIGHT, |
| 86 | QUANT_QUALITY_PENALTY, |
| 87 | ) |
| 88 | from whichllm.engine.performance import estimate_tok_per_sec |
| 89 | from whichllm.engine.vram import estimate_vram |
| 90 | from whichllm.hardware.types import GPUInfo |
| 91 | |
| 92 | _GiB = 1024**3 |
| 93 | |
| 94 | quant_levels = ["Q2_K", "Q3_K_M", "Q4_K_M", "Q5_K_M", "Q6_K", "Q8_0", "F16"] |
| 95 | vram_by_quant = {} |
| 96 | for qt in quant_levels: |
| 97 | bpw = QUANT_BYTES_PER_WEIGHT.get(qt) |
| 98 | if bpw is None: |
| 99 | continue |
| 100 | fake_size = int(model.parameter_count * bpw) |
| 101 | fake_variant = GGUFVariant( |
| 102 | filename="", quant_type=qt, file_size_bytes=fake_size |
| 103 | ) |
| 104 | vram_bytes = estimate_vram(model, fake_variant, context_length) |
| 105 | vram_by_quant[qt] = { |
| 106 | "vram_bytes": vram_bytes, |
| 107 | "quality_loss": QUANT_QUALITY_PENALTY.get(qt, 0.0), |
| 108 | } |
| 109 | |
| 110 | target_vram = vram_by_quant.get(target_quant.upper(), {}).get("vram_bytes", 0) |
| 111 | if target_vram == 0: |
| 112 | bpw = QUANT_BYTES_PER_WEIGHT.get(target_quant.upper(), 0.5625) |
| 113 | fake_size = int(model.parameter_count * bpw) |
| 114 | fake_variant = GGUFVariant( |
| 115 | filename="", quant_type=target_quant, file_size_bytes=fake_size |
| 116 | ) |
| 117 | target_vram = estimate_vram(model, fake_variant, context_length) |
| 118 | |
| 119 | _PLAN_GPUS: list[tuple[str, int]] = [ |
| 120 | ("RTX 4060", 8), |
| 121 | ("RTX 3060", 12), |
| 122 | ("RTX 4070", 12), |
| 123 | ("RTX 4080", 16), |
| 124 | ("RTX 4090", 24), |
| 125 | ("RX 7900 XTX", 24), |
| 126 | ("RTX 5090", 32), |
| 127 | ("A100 40GB", 40), |
| 128 | ("L40S", 48), |
| 129 | ("A100 80GB", 80), |
| 130 | ("H100", 80), |
| 131 | ("H200", 141), |
| 132 | ] |
| 133 | |
| 134 | bpw = QUANT_BYTES_PER_WEIGHT.get(target_quant.upper(), 0.5625) |