(draft_id: str)
| 204 | |
| 205 | |
def load_draft(draft_id: str) -> DFlashDraftModel:
    """Fetch a draft snapshot and return a ``DFlashDraftModel`` with its weights loaded.

    Only ``*.safetensors`` and ``*.json`` files are requested from
    ``snapshot_download``. The snapshot's ``config.json`` is validated and
    turned into a ``DFlashConfig``; every ``*.safetensors`` file found in the
    snapshot directory is then merged into a single weight mapping and handed
    to ``model.load_weights``.

    Args:
        draft_id: Identifier passed straight through to ``snapshot_download``.

    Returns:
        The constructed ``DFlashDraftModel`` after ``load_weights`` has run.

    Raises:
        ValueError: If ``layer_types`` does not have ``num_hidden_layers``
            entries, contains a value other than ``"full_attention"`` or
            ``"sliding_attention"``, or uses ``"sliding_attention"`` while
            ``sliding_window`` is absent or null.
        KeyError: If a required key is missing from ``config.json``.
    """
    snapshot_dir = Path(
        snapshot_download(draft_id, allow_patterns=["*.safetensors", "*.json"])
    )
    cfg = json.loads((snapshot_dir / "config.json").read_text())

    # A missing or empty ``layer_types`` means every layer uses full attention.
    declared_types = cfg.get("layer_types")
    if not declared_types:
        declared_types = ["full_attention"] * cfg["num_hidden_layers"]
    layer_types = tuple(declared_types)

    if len(layer_types) != cfg["num_hidden_layers"]:
        raise ValueError("Draft config layer_types length must match num_hidden_layers.")

    unknown_layer_types = set(layer_types).difference(
        ("full_attention", "sliding_attention")
    )
    if unknown_layer_types:
        raise ValueError(f"Unsupported draft layer_types: {sorted(unknown_layer_types)}.")

    sliding_window = cfg.get("sliding_window")
    if sliding_window is None and "sliding_attention" in layer_types:
        raise ValueError("Draft config must define sliding_window for sliding_attention layers.")

    # Required top-level keys that map one-to-one onto DFlashConfig fields.
    # They are read in this order so a missing key surfaces as the same
    # KeyError it always did.
    required_keys = (
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
        "intermediate_size",
        "vocab_size",
        "rms_norm_eps",
        "rope_theta",
        "max_position_embeddings",
        "block_size",
    )
    settings = {key: cfg[key] for key in required_keys}

    # DFlash-specific values live in the nested ``dflash_config`` section,
    # except ``num_target_layers`` which is read from the top level.
    dflash_section = cfg["dflash_config"]
    settings["target_layer_ids"] = tuple(dflash_section["target_layer_ids"])
    settings["num_target_layers"] = cfg["num_target_layers"]
    settings["mask_token_id"] = dflash_section["mask_token_id"]

    # Optional values: passed as None when the config omits them.
    settings["rope_scaling"] = cfg.get("rope_scaling")
    settings["layer_types"] = layer_types
    settings["sliding_window"] = sliding_window
    settings["final_logit_softcapping"] = cfg.get("final_logit_softcapping")

    config = DFlashConfig(**settings)

    # Merge all shards into one mapping; a key repeated in a later shard
    # replaces the earlier entry.
    weights = {}
    for shard in snapshot_dir.glob("*.safetensors"):
        weights.update(mx.load(str(shard)).items())

    model = DFlashDraftModel(config)
    model.load_weights(list(weights.items()))
    return model
| 241 | |
| 242 | |
| 243 | def _trim_recent_cache(cache: List[Any], num_tokens: int) -> None: |
no test coverage detected