Load weights: copy `(name, tensor)` pairs from a checkpoint into the model's named parameters.
(self, weights: Iterable[tuple[str, torch.Tensor]])
| 379 | ) |
| 380 | |
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Copy checkpoint tensors into this module's parameters.

    Adapted from vLLM. Checkpoint tensors for the separate q/k/v projections
    and gate/up projections are routed into the fused ``.qkv_proj`` /
    ``.gate_up_proj`` parameters, with the matching ``shard_id`` passed to
    ``load_weight``. Every other tensor is loaded into the parameter with the
    same name.

    Args:
        weights: iterable of ``(name, tensor)`` pairs.

    Raises:
        KeyError: if a (possibly renamed) checkpoint name has no matching
            parameter in ``self.named_parameters()``.
    """
    # (fused parameter suffix, checkpoint suffix, shard id) -- order matters:
    # the first entry whose checkpoint suffix occurs in the name wins.
    fused_shards = (
        ('.qkv_proj', '.q_proj', 'q'),
        ('.qkv_proj', '.k_proj', 'k'),
        ('.qkv_proj', '.v_proj', 'v'),
        ('.gate_up_proj', '.gate_proj', 0),
        ('.gate_up_proj', '.up_proj', 1),
    )
    # Rotary-embedding buffers are never loaded from the checkpoint
    # (presumably rebuilt by the model itself -- confirm).
    rotary_buffers = (
        'rotary_emb.inv_freq',
        'rotary_emb.cos_cached',
        'rotary_emb.sin_cached',
    )

    param_by_name = dict(self.named_parameters())
    for ckpt_name, tensor in weights:
        if any(marker in ckpt_name for marker in rotary_buffers):
            continue
        # With tied embeddings, checkpoint lm_head weights are ignored
        # (presumably lm_head shares the embedding parameter -- confirm).
        if self.config.tie_word_embeddings and 'lm_head.weight' in ckpt_name:
            continue

        shard = next(
            (entry for entry in fused_shards if entry[1] in ckpt_name),
            None,
        )
        if shard is None:
            target = param_by_name[ckpt_name]
            load_weight(target, tensor)
        else:
            fused_suffix, part_suffix, shard_id = shard
            target = param_by_name[ckpt_name.replace(part_suffix, fused_suffix)]
            load_weight(target, tensor, shard_id=shard_id)
no test coverage detected