(self, file, updated_embs, save_dtype, metadata)
| 150 | ) |
| 151 | |
def save_weights(self, file, updated_embs, save_dtype, metadata):
    """Write the first learned embedding tensor to ``file``.

    The tensor is stored in a one-entry dict under the key ``"emb_params"``.
    The output format is chosen from the file extension: exactly
    ``".safetensors"`` selects safetensors, anything else uses ``torch.save``.

    Args:
        file: Destination path.
        updated_embs: Indexable collection of embedding tensors; only
            element 0 is written (presumably torch tensors — confirm with
            caller).
        save_dtype: Optional dtype. When given, the tensor is detached,
            copied, moved to CPU and cast before saving. When ``None``, the
            tensor is saved as-is.
        metadata: Passed to safetensors' ``save_file``; unused for the
            ``torch.save`` path.
    """
    embedding = updated_embs[0]
    if save_dtype is not None:
        # Copy off the autograd graph and onto the CPU before casting, so the
        # tensor held by the caller is not modified.
        embedding = embedding.detach().clone().to("cpu").to(save_dtype)
    state_dict = {"emb_params": embedding}

    _, extension = os.path.splitext(file)
    if extension != ".safetensors":
        torch.save(state_dict, file)  # can be loaded in Web UI
        return

    # Imported lazily so safetensors is only needed when that format is requested.
    from safetensors.torch import save_file

    save_file(state_dict, file, metadata)
| 167 | |
| 168 | def load_weights(self, file): |
| 169 | if os.path.splitext(file)[1] == ".safetensors": |
no test coverage detected