(model_config, states_tp_pps)
| 15 | |
| 16 | |
| 17 | def convert2hf(model_config, states_tp_pps): |
| 18 | |
| 19 | with tempfile.TemporaryDirectory() as folder: |
| 20 | states = merge_pp(states_tp_pps)[0] |
| 21 | |
| 22 | if "embedding.word_embeddings.weight" in states: |
| 23 | embedding_key = "embedding.word_embeddings.weight" |
| 24 | elif "embedding.weight" in states: |
| 25 | embedding_key = "embedding.weight" |
| 26 | else: |
| 27 | print("Check embedding states'names in below:", flush=True) |
| 28 | print(list(states.keys()), flush=True) |
| 29 | |
| 30 | dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"] |
| 31 | base = 10000.0 |
| 32 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) |
| 33 | |
| 34 | current_states = {} |
| 35 | |
| 36 | current_states["model.embed_tokens.weight"] = states.pop(embedding_key) |
| 37 | current_states["model.norm.weight"] = states.pop("norm.weight") |
| 38 | current_states["lm_head.weight"] = states.pop("head.weight") |
| 39 | |
| 40 | for i in range(model_config["num_layers"]): |
| 41 | states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq") |
| 42 | |
| 43 | wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape( |
| 44 | 3, model_config["num_attention_heads"], -1, model_config["hidden_size"] |
| 45 | ) |
| 46 | bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1) |
| 47 | |
| 48 | current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape( |
| 49 | -1, model_config["hidden_size"] |
| 50 | ) |
| 51 | current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1) |
| 52 | current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape( |
| 53 | -1, model_config["hidden_size"] |
| 54 | ) |
| 55 | current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1) |
| 56 | current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape( |
| 57 | -1, model_config["hidden_size"] |
| 58 | ) |
| 59 | current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1) |
| 60 | |
| 61 | current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop( |
| 62 | f"blocks.{i}.mixer.out_proj.weight" |
| 63 | ) |
| 64 | current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias") |
| 65 | |
| 66 | current_states[f"model.layers.{i}.mlp.gate_proj.weight"] = states.pop(f"blocks.{i}.mlp.w1.weight") |
| 67 | current_states[f"model.layers.{i}.mlp.down_proj.weight"] = states.pop(f"blocks.{i}.mlp.w3.weight") |
| 68 | current_states[f"model.layers.{i}.mlp.up_proj.weight"] = states.pop(f"blocks.{i}.mlp.w2.weight") |
| 69 | |
| 70 | current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight") |
| 71 | current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight") |
| 72 | current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq |
| 73 | |
| 74 | config = InternLMConfig( |
no test coverage detected