(model)
| 160 | |
| 161 | |
def find_all_linear_names(
    model,
    *,
    linear_cls=torch.nn.Linear,
    exclude=("lm_head",),
):
    """Return the distinct leaf names of all linear submodules of ``model``.

    Each entry is the last dotted component of a matching submodule's
    qualified name (e.g. ``"layers.0.self_attn.q_proj"`` -> ``"q_proj"``),
    so the result can be used as a list of target module names.

    Args:
        model: The ``torch.nn.Module`` to scan; nested submodules are
            visited via ``named_modules()``.
        linear_cls: Class (or tuple of classes) matched with ``isinstance``.
            Defaults to ``torch.nn.Linear``, so subclasses also match.
        exclude: Iterable of leaf names to drop from the result. Defaults
            to ``("lm_head",)``; the original code removed ``lm_head`` with
            the note "needed for 16-bit".

    Returns:
        A sorted list of unique leaf names. Sorting makes the output
        deterministic across runs (plain set order depends on string
        hash randomization).
    """
    leaf_names = set()
    for name, module in model.named_modules():
        if not isinstance(module, linear_cls):
            continue
        # Keep only the last dotted component; for a top-level child this
        # is the whole name.
        leaf = name.rsplit(".", 1)[-1]
        # named_modules() reports the root module with an empty name; an
        # empty target name is meaningless, so skip it.
        if leaf:
            leaf_names.add(leaf)
    leaf_names.difference_update(exclude)
    return sorted(leaf_names)
| 174 | |
| 175 | |
| 176 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, |