MCPcopy
hub / github.com/policy-gradient/GRPO-Zero / from_pretrained

Method from_pretrained

qwen2_model.py:307–346  ·  view source on GitHub ↗
(cls, ckpt_path, device: torch.device)

Source from the content-addressed store, hash-verified

305
@classmethod
def from_pretrained(cls, ckpt_path, device: torch.device):
    """Build a model from a Qwen2 checkpoint directory.

    Reads ``config.json`` from ``ckpt_path`` to construct a ``Qwen2Config``,
    instantiates the model under the ``meta`` device context, then loads
    every ``model*.safetensors`` shard and assigns the tensors into the model.

    Args:
        ckpt_path: Directory (str or path-like) containing ``config.json``
            and one or more ``model*.safetensors`` weight shards.
        device: Device the returned model is moved to.

    Returns:
        The model instance with the checkpoint weights loaded, on ``device``.

    Raises:
        FileNotFoundError: If ``config.json`` is missing or no
            ``model*.safetensors`` shard is found in ``ckpt_path``.
        KeyError: If ``config.json`` lacks one of the fields read below.
        RuntimeError: From ``load_state_dict(strict=True)`` when checkpoint
            keys do not exactly match the model's state-dict keys.
    """
    ckpt_dir = Path(ckpt_path)
    with open(ckpt_dir / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    args = Qwen2Config(
        attention_dropout=config["attention_dropout"],
        bos_token_id=config["bos_token_id"],
        eos_token_id=config["eos_token_id"],
        hidden_act=config["hidden_act"],
        hidden_size=config["hidden_size"],
        initializer_range=config["initializer_range"],
        intermediate_size=config["intermediate_size"],
        max_position_embeddings=config["max_position_embeddings"],
        max_window_layers=config["max_window_layers"],
        model_type=config["model_type"],
        num_hidden_layers=config["num_hidden_layers"],
        num_attention_heads=config["num_attention_heads"],
        num_key_value_heads=config["num_key_value_heads"],
        vocab_size=config["vocab_size"],
        rms_norm_eps=config["rms_norm_eps"],
        rope_theta=config["rope_theta"],
        sliding_window=config["sliding_window"],
        use_sliding_window=config["use_sliding_window"],
        use_cache=config["use_cache"],
        tie_word_embeddings=config["tie_word_embeddings"],
        torch_dtype=config["torch_dtype"],
    )
    # Build the module skeleton on the meta device so parameters get shapes
    # but no storage; the real tensors are assigned from the checkpoint below.
    # NOTE(review): this assumes cls.__init__ does not explicitly place
    # tensors on `device` (an explicit device= overrides the meta context).
    with torch.device("meta"):
        model = cls(params=args, device=device)

    import safetensors.torch

    # Sorted so the shard load order is deterministic. With dict.update, a
    # later shard wins if two shards contain the same key.
    model_weight_files = sorted(ckpt_dir.glob("model*.safetensors"))
    if not model_weight_files:
        raise FileNotFoundError(
            f"no model*.safetensors weight files found in {ckpt_dir}"
        )
    # All shards are held in CPU memory at once before loading.
    weights = {}
    for file in model_weight_files:
        weights.update(safetensors.torch.load_file(file, device="cpu"))
    # Strip only a *leading* "model." prefix. str.replace would also rewrite
    # any "model." that occurs later inside a key.
    prefix = "model."
    weights = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in weights.items()
    }
    # NOTE(review): when tie_word_embeddings is true, strict loading depends
    # on how the model class handles lm_head vs. the embedding weight; confirm.
    # assign=True swaps the meta tensors for the loaded tensors. Copying into
    # meta tensors (assign=False) would be a no-op.
    model.load_state_dict(weights, strict=True, assign=True)
    return model.to(device)

Callers 1

main · Function · 0.80

Calls 1

Qwen2Config · Class · 0.85

Tested by

no test coverage detected