MCPcopy
hub / github.com/InternLM/lmdeploy / load_weights

Method load_weights

lmdeploy/pytorch/models/sdar.py:381–411  ·  view source on GitHub ↗

Load weights.

(self, weights: Iterable[tuple[str, torch.Tensor]])

Source fetched from the content-addressed store (hash-verified)

379 )
380
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Copy checkpoint tensors into this model's parameters.

    Adapted from vLLM. Checkpoint tensors whose name contains one of the
    split projection suffixes (``.q_proj``/``.k_proj``/``.v_proj`` or
    ``.gate_proj``/``.up_proj``) are routed into the corresponding fused
    parameter (``.qkv_proj`` / ``.gate_up_proj``) together with a shard id.
    Every other tensor is loaded into the parameter with the identical name.

    Args:
        weights: iterable of ``(name, tensor)`` pairs read from the checkpoint.

    Raises:
        KeyError: if a (possibly remapped) name has no matching parameter in
            ``self.named_parameters()``.
    """
    # (fused parameter suffix, checkpoint suffix, shard id); the first entry
    # whose checkpoint suffix occurs in the name is the one that is used.
    packed_modules = (
        ('.qkv_proj', '.q_proj', 'q'),
        ('.qkv_proj', '.k_proj', 'k'),
        ('.qkv_proj', '.v_proj', 'v'),
        ('.gate_up_proj', '.gate_proj', 0),
        ('.gate_up_proj', '.up_proj', 1),
    )
    # Checkpoint entries for rotary-embedding buffers are ignored entirely.
    ignored_markers = (
        'rotary_emb.inv_freq',
        'rotary_emb.cos_cached',
        'rotary_emb.sin_cached',
    )

    named_params = dict(self.named_parameters())
    for ckpt_name, tensor in weights:
        if any(marker in ckpt_name for marker in ignored_markers):
            continue
        # With tied word embeddings, the checkpoint's lm_head tensor is not
        # loaded (presumably lm_head shares the embedding weight -- confirm).
        if self.config.tie_word_embeddings and 'lm_head.weight' in ckpt_name:
            continue

        packed = next(
            (entry for entry in packed_modules if entry[1] in ckpt_name),
            None,
        )
        if packed is None:
            # Plain parameter: names line up one-to-one.
            load_weight(named_params[ckpt_name], tensor)
            continue

        fused_suffix, shard_suffix, shard_id = packed
        target = named_params[ckpt_name.replace(shard_suffix, fused_suffix)]
        load_weight(target, tensor, shard_id=shard_id)

Callers (2)

load_model_weights (Method) · 0.45
update_params (Method) · 0.45

Calls (1)

load_weight (Function) · 0.90

Tested by

no test coverage detected