Repository: github.com/apple/ml-mgie

Function find_all_linear_names(model)

Defined in mgie_train.py, lines 162–173.

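```python
import torch


def find_all_linear_names(model):
    """Return the leaf names of all torch.nn.Linear submodules, excluding lm_head.

    The result is collected as LoRA target module names (e.g. "q_proj", "v_proj").
    """
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            # Keep only the last component of the dotted path,
            # e.g. "model.layers.0.self_attn.q_proj" -> "q_proj".
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
```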
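The selection logic can be exercised without torch or a real checkpoint. The sketch below is a hypothetical, torch-free mirror of find_all_linear_names that takes (dotted_name, module) pairs like those yielded by model.named_modules(). The classes `Linear` and `LayerNorm`, the helpers `find_linear_leaf_names` and `matches`, and the module paths are illustrative stand-ins, not names from MGIE. `matches` is an assumption that approximates how PEFT resolves string `target_modules` (exact name or dotted-suffix match). Note that `names[0] if len(names) == 1 else names[-1]` always yields the last path component, so the sketch simply uses `split('.')[-1]`.

```python
from typing import Iterable, List, Tuple


# Hypothetical stand-ins for module classes so the sketch runs without torch.
class Linear:
    pass


class LayerNorm:
    pass


def find_linear_leaf_names(named_modules: Iterable[Tuple[str, object]],
                           linear_cls: type = Linear) -> List[str]:
    """Torch-free mirror of find_all_linear_names over (dotted_name, module) pairs."""
    names = set()
    for name, module in named_modules:
        if isinstance(module, linear_cls):
            # "model.layers.0.self_attn.q_proj" -> "q_proj"; "lm_head" stays "lm_head".
            names.add(name.split('.')[-1])
    names.discard('lm_head')  # the original also drops the output head
    return sorted(names)  # sorted for determinism; the original returns list(set)


def matches(module_name: str, targets: List[str]) -> bool:
    # Assumed PEFT-style rule: exact name match or match on a dotted suffix.
    return module_name in targets or any(module_name.endswith('.' + t) for t in targets)


# Hypothetical LLaMA-style module tree (illustrative names, not read from MGIE).
modules = [
    ("", object()),  # the root module, as named_modules() yields it first
    ("model.layers.0.self_attn.q_proj", Linear()),
    ("model.layers.0.self_attn.v_proj", Linear()),
    ("model.layers.0.mlp.down_proj", Linear()),
    ("model.layers.1.self_attn.q_proj", Linear()),
    ("model.layers.0.input_layernorm", LayerNorm()),
    ("lm_head", Linear()),
]

targets = find_linear_leaf_names(modules)
print(targets)  # ['down_proj', 'q_proj', 'v_proj']
print(matches("model.layers.1.self_attn.q_proj", targets))  # True
print(matches("lm_head", targets))  # False
```

Because only the leaf name is kept, a single entry such as "q_proj" selects that projection in every layer (under the suffix-matching assumption), which is why the names are deduplicated in a set. The same property means any other Linear module that happens to share a leaf name would also be targeted.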
174
175

Callers 1

train (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected