MCPcopy
hub / github.com/InternLM/InternLM / convert2hf

Function convert2hf

tools/transformers/convert2hf.py:17–92  ·  view source on GitHub ↗
(model_config, states_tp_pps)

Source from the content-addressed store, hash-verified

15
16
17def convert2hf(model_config, states_tp_pps):
18
19 with tempfile.TemporaryDirectory() as folder:
20 states = merge_pp(states_tp_pps)[0]
21
22 if "embedding.word_embeddings.weight" in states:
23 embedding_key = "embedding.word_embeddings.weight"
24 elif "embedding.weight" in states:
25 embedding_key = "embedding.weight"
26 else:
27 print("Check embedding states'names in below:", flush=True)
28 print(list(states.keys()), flush=True)
29
30 dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"]
31 base = 10000.0
32 inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
33
34 current_states = {}
35
36 current_states["model.embed_tokens.weight"] = states.pop(embedding_key)
37 current_states["model.norm.weight"] = states.pop("norm.weight")
38 current_states["lm_head.weight"] = states.pop("head.weight")
39
40 for i in range(model_config["num_layers"]):
41 states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
42
43 wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
44 3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
45 )
46 bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1)
47
48 current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape(
49 -1, model_config["hidden_size"]
50 )
51 current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1)
52 current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape(
53 -1, model_config["hidden_size"]
54 )
55 current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1)
56 current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape(
57 -1, model_config["hidden_size"]
58 )
59 current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1)
60
61 current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop(
62 f"blocks.{i}.mixer.out_proj.weight"
63 )
64 current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias")
65
66 current_states[f"model.layers.{i}.mlp.gate_proj.weight"] = states.pop(f"blocks.{i}.mlp.w1.weight")
67 current_states[f"model.layers.{i}.mlp.down_proj.weight"] = states.pop(f"blocks.{i}.mlp.w3.weight")
68 current_states[f"model.layers.{i}.mlp.up_proj.weight"] = states.pop(f"blocks.{i}.mlp.w2.weight")
69
70 current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight")
71 current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight")
72 current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq
73
74 config = InternLMConfig(

Callers 1

convert2hf.py — File · 0.85

Calls 4

merge_pp — Function · 0.85
InternLMConfig — Class · 0.85
save — Method · 0.80

Tested by

no test coverage detected