MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / get_all_lora_weights

Function get_all_lora_weights

examples/hf_lora_convert.py:45–90  ·  view source on GitHub ↗
(lora_weights)

Source from the content-addressed store, hash-verified

43
44
def get_all_lora_weights(lora_weights):
    """Group Hugging Face LoRA / DoRA tensors by decoder layer and module.

    Args:
        lora_weights: Mapping from state-dict key to tensor. Keys are matched
            against two layouts:

            * ``<prefix>.layers.<i>.(self_attn|mlp).<module>.<suffix>``
            * ``<prefix>.layers.<i>.block_sparse_moe.[experts.<j>.]<module>.<suffix>``

            where ``<suffix>`` is ``lora_A.weight``, ``lora_B.weight``,
            ``lora_magnitude_vector`` or ``weight_m_wdecomp.weight``. Any
            trailing text after the suffix is ignored.

    Returns:
        A nested ``defaultdict``. ``all_weights[layer_idx][module]`` is a dict
        that may contain:

        * ``"in"``: the lora_A tensor.
        * ``"out"``: the lora_B tensor.
        * ``"magnitude"``: the DoRA magnitude, flattened via ``.view(-1)``.

        Keys matching neither layout are reported with ``print`` and skipped.
    """
    all_weights = defaultdict(lambda: defaultdict(dict))
    # Shared tail for both layouts. Groups, in order:
    #   A|B for lora_A/lora_B,
    #   "magnitude" for lora_magnitude_vector,
    #   "m_wdecomp" for the alternative DoRA magnitude naming.
    suffix = (r'\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)'
              r'|weight_(m_wdecomp)\.weight).*')
    # Dense-layer groups:
    #   2 = layer index, 4 = module name, 5/6/7 = suffix groups.
    pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(self_attn|mlp)\.([a-z_]+))' + suffix)
    # MoE-layer groups:
    #   2 = layer index, 6 = expert index (optional), 7 = module name,
    #   8/9/10 = suffix groups.
    moe_pattern = re.compile(
        r'(.*\.layers\.([0-9]+)\.(block_sparse_moe)\.((experts)\.([0-9]+)\.|)'
        r'([a-zA-Z0-9_]+))' + suffix)

    for key, weights in lora_weights.items():
        m = pattern.match(key)
        if m:
            layer_idx = int(m.group(2))
            hf_module = m.group(4)
            inout = m.group(5)
            dora_magnitude = m.group(6) or m.group(7)
        else:
            m_moe = moe_pattern.match(key)
            if not m_moe:
                print(f"no match {key}")
                continue
            layer_idx = int(m_moe.group(2))
            # NOTE(review): the expert index (group 6) is not part of the
            # output key. Entries for different experts of the same module
            # therefore share one slot and the last key wins. Confirm this
            # is intended.
            hf_module = m_moe.group(7)
            inout = m_moe.group(8)
            # Both groups must come from the MoE match. The dense match is
            # None on this path.
            dora_magnitude = m_moe.group(9) or m_moe.group(10)

        _store_lora_tensor(all_weights[layer_idx][hf_module], inout,
                           dora_magnitude, weights)
    return all_weights


def _store_lora_tensor(module_weights, inout, dora_magnitude, weights):
    """Store one matched tensor into a module's weight dict.

    Args:
        module_weights: Dict for a single (layer, module) pair.
        inout: "A" or "B" for lora_A/lora_B weights, else None.
        dora_magnitude: Truthy when the key names a DoRA magnitude vector.
        weights: The tensor taken from the state dict.
    """
    if inout:
        # lora_A projects into the low-rank space ("in").
        # lora_B projects back out of it ("out").
        module_weights["in" if inout == "A" else "out"] = weights
    elif dora_magnitude:
        LOGGER.warning(
            "Detected DoRA magnitude vector, make sure it was preprocessed and normalized using the proper base model weights"
        )
        module_weights["magnitude"] = weights.view(-1)
91
92
93def preprocess_lora_weights(lora_model):

Callers 1

convert_hf_model (Function) · 0.85

Calls 4

compile (Method) · 0.45
match (Method) · 0.45
warning (Method) · 0.45
view (Method) · 0.45

Tested by

no test coverage detected