(lora_weights)
| 43 | |
| 44 | |
def get_all_lora_weights(lora_weights):
    """Group LoRA / DoRA entries by layer index and module name.

    Every key of ``lora_weights`` is matched against two patterns: one for
    ``self_attn`` / ``mlp`` sub-modules and one for ``block_sparse_moe``
    sub-modules (optionally nested under ``experts.<n>.``). Keys that match
    neither pattern are reported on stdout and skipped.

    Args:
        lora_weights: Mapping from parameter name to its weight object.
            Magnitude entries must support ``.view(-1)``.

    Returns:
        A nested ``defaultdict`` where ``result[layer_idx][module_name]`` is a
        dict holding any of the keys ``"in"`` (from ``lora_A``), ``"out"``
        (from ``lora_B``) and ``"magnitude"`` (the magnitude entry flattened
        with ``.view(-1)``).
    """
    all_weights = defaultdict(lambda: defaultdict(dict))
    pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(self_attn|mlp)\.([a-z_]+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight).*'
    )
    moe_pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(block_sparse_moe)\.((experts)\.([0-9]+)\.|)([a-zA-Z0-9_]+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight).*'
    )
    for key, weights in lora_weights.items():
        m = pattern.match(key)
        if m:
            layer_idx = int(m.group(2))
            hf_module = m.group(4)
            inout = m.group(5)
            dora_magnitude = m.group(6) or m.group(7)
        else:
            # Only try the MoE pattern when the attention/MLP one failed.
            m_moe = moe_pattern.match(key)
            if not m_moe:
                print(f"no match {key}")
                continue
            layer_idx = int(m_moe.group(2))
            # NOTE(review): the expert index (group 6) is not part of the
            # storage key, so the same module name under different experts of
            # one layer lands in the same slot (last one wins) - confirm this
            # is intended.
            hf_module = m_moe.group(7)
            inout = m_moe.group(8)
            # Both magnitude groups must be read from the MoE match; the
            # non-MoE match object is None on this path.
            dora_magnitude = m_moe.group(9) or m_moe.group(10)

        if inout:
            # lora_A is stored as "in", lora_B as "out".
            inout = "in" if inout == "A" else "out"
            all_weights[layer_idx][hf_module][inout] = weights
        elif dora_magnitude:
            LOGGER.warning(
                "Detected DoRA magnitude vector, make sure it was preprocessed and normalized using the proper base model weights"
            )
            all_weights[layer_idx][hf_module]["magnitude"] = weights.view(-1)
    return all_weights
| 91 | |
| 92 | |
| 93 | def preprocess_lora_weights(lora_model): |
no test coverage detected