| 305 | |
@classmethod
def from_pretrained(cls, ckpt_path, device: torch.device):
    """Build a model from a Qwen2 checkpoint directory.

    Reads ``config.json`` from *ckpt_path* to build a ``Qwen2Config``.
    It then instantiates the model under the ``meta`` device, so parameters
    are created without backing storage. Finally it loads every
    ``model*.safetensors`` shard in the directory on CPU and assigns those
    tensors into the model.

    Args:
        ckpt_path: Path (str or path-like) to the checkpoint directory.
        device: Passed to the model constructor; the loaded model is also
            moved to this device before being returned.

    Returns:
        The model instance with checkpoint weights loaded, moved to *device*.

    Raises:
        FileNotFoundError: If ``config.json`` is missing, or the directory
            contains no ``model*.safetensors`` files.
        KeyError: If ``config.json`` lacks one of the fields read below.
        RuntimeError: From ``load_state_dict`` when the checkpoint keys do
            not exactly match the model's state dict (``strict=True``).
    """
    ckpt_dir = Path(ckpt_path)
    with open(ckpt_dir / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    # Only these fields are forwarded; other keys in config.json are ignored.
    args = Qwen2Config(
        attention_dropout=config["attention_dropout"],
        bos_token_id=config["bos_token_id"],
        eos_token_id=config["eos_token_id"],
        hidden_act=config["hidden_act"],
        hidden_size=config["hidden_size"],
        initializer_range=config["initializer_range"],
        intermediate_size=config["intermediate_size"],
        max_position_embeddings=config["max_position_embeddings"],
        max_window_layers=config["max_window_layers"],
        model_type=config["model_type"],
        num_hidden_layers=config["num_hidden_layers"],
        num_attention_heads=config["num_attention_heads"],
        num_key_value_heads=config["num_key_value_heads"],
        vocab_size=config["vocab_size"],
        rms_norm_eps=config["rms_norm_eps"],
        rope_theta=config["rope_theta"],
        sliding_window=config["sliding_window"],
        use_sliding_window=config["use_sliding_window"],
        use_cache=config["use_cache"],
        tie_word_embeddings=config["tie_word_embeddings"],
        torch_dtype=config["torch_dtype"],
    )
    # Constructing under the meta device avoids allocating real parameter
    # memory; the real tensors are attached below via assign=True.
    with torch.device("meta"):
        model = cls(params=args, device=device)

    import safetensors.torch

    # Sorted so shards are merged in a deterministic order.
    model_weight_files = sorted(ckpt_dir.glob("model*.safetensors"))
    if not model_weight_files:
        raise FileNotFoundError(
            f"No model*.safetensors files found in {ckpt_dir}"
        )
    weights = {}
    for file in model_weight_files:
        weights.update(safetensors.torch.load_file(file, device="cpu"))
    # Strip only a *leading* "model." prefix; str.replace would also rewrite
    # any "model." substring occurring later inside a key.
    prefix = "model."
    weights = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in weights.items()
    }
    # assign=True replaces the meta-device parameters with the loaded tensors
    # instead of copying into them (meta tensors have no storage to copy into).
    model.load_state_dict(weights, strict=True, assign=True)
    return model.to(device)