MCPcopy
hub / github.com/z-lab/dflash / load_draft

Function load_draft

dflash/model_mlx.py:206–240  ·  view source on GitHub ↗
(draft_id: str)

Source from the content-addressed store, hash-verified

204
205
def load_draft(draft_id: str) -> DFlashDraftModel:
    """Download, configure, and load a DFlash draft model.

    Only the ``*.safetensors`` and ``*.json`` files of ``draft_id`` are
    fetched with ``snapshot_download``. ``config.json`` is turned into a
    ``DFlashConfig``, and every safetensors shard is loaded into a fresh
    ``DFlashDraftModel``.

    Args:
        draft_id: Repository id passed to ``snapshot_download``.

    Returns:
        The draft model with its weights loaded via ``load_weights``.

    Raises:
        ValueError: If ``layer_types`` does not have one entry per hidden
            layer, if it contains a type other than ``full_attention`` /
            ``sliding_attention``, if sliding attention is used without a
            ``sliding_window``, or if the snapshot has no safetensors files.
        KeyError: If a required field is missing from ``config.json``.
    """
    path = Path(snapshot_download(draft_id, allow_patterns=["*.safetensors", "*.json"]))
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    cfg = json.loads((path / "config.json").read_text(encoding="utf-8"))

    num_layers = cfg["num_hidden_layers"]
    sliding_window = cfg.get("sliding_window")

    # A missing or empty ``layer_types`` means every layer uses full attention.
    layer_types = tuple(cfg.get("layer_types") or ["full_attention"] * num_layers)
    if len(layer_types) != num_layers:
        raise ValueError("Draft config layer_types length must match num_hidden_layers.")
    unknown_layer_types = set(layer_types) - {"full_attention", "sliding_attention"}
    if unknown_layer_types:
        raise ValueError(f"Unsupported draft layer_types: {sorted(unknown_layer_types)}.")
    if "sliding_attention" in layer_types and sliding_window is None:
        raise ValueError("Draft config must define sliding_window for sliding_attention layers.")

    dflash_cfg = cfg["dflash_config"]
    config = DFlashConfig(
        hidden_size=cfg["hidden_size"],
        num_hidden_layers=num_layers,
        num_attention_heads=cfg["num_attention_heads"],
        num_key_value_heads=cfg["num_key_value_heads"],
        head_dim=cfg["head_dim"],
        intermediate_size=cfg["intermediate_size"],
        vocab_size=cfg["vocab_size"],
        rms_norm_eps=cfg["rms_norm_eps"],
        rope_theta=cfg["rope_theta"],
        max_position_embeddings=cfg["max_position_embeddings"],
        block_size=cfg["block_size"],
        target_layer_ids=tuple(dflash_cfg["target_layer_ids"]),
        num_target_layers=cfg["num_target_layers"],
        mask_token_id=dflash_cfg["mask_token_id"],
        rope_scaling=cfg.get("rope_scaling"),
        layer_types=layer_types,
        sliding_window=sliding_window,
        final_logit_softcapping=cfg.get("final_logit_softcapping"),
    )

    # Sort shards so the merge order (and any key overlap) is deterministic.
    shard_files = sorted(path.glob("*.safetensors"))
    if not shard_files:
        # Fail with a clear message instead of a generic "missing parameters" error.
        raise ValueError(f"No *.safetensors weight files found for draft model {draft_id!r} in {path}.")
    weights = {}
    for shard in shard_files:
        weights.update(mx.load(str(shard)))

    model = DFlashDraftModel(config)
    model.load_weights(list(weights.items()))
    return model
241
242
243def _trim_recent_cache(cache: List[Any], num_tokens: int) -> None:

Callers 1

_run_mlx · Function · 0.85

Calls 2

DFlashConfig · Class · 0.85
DFlashDraftModel · Class · 0.70

Tested by

no test coverage detected