| 166 | torch.save(state_dict, file) # can be loaded in Web UI |
| 167 | |
def load_weights(self, file):
    """Load an embedding weight file and return its first tensor.

    Supports ``.safetensors`` files (matched case-insensitively by extension)
    and PyTorch files read with ``torch.load``, including the Web UI layout
    where the tensors live under the ``"string_to_param"`` key.

    Args:
        file: Path to the weight file.

    Returns:
        A one-element list holding the first tensor value found in the file.
        A 1-D tensor is promoted to 2-D by adding a leading dimension.

    Raises:
        ValueError: If the loaded object is not a dict, or if it does not
            contain a tensor value (this includes an empty dict).
    """
    if os.path.splitext(file)[1].lower() == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible with Web UI's file format
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        data = torch.load(file, map_location="cpu")
        # isinstance rather than an exact type check: saved state dicts are
        # commonly OrderedDict, which is a dict subclass.
        if not isinstance(data, dict):
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
        if hasattr(data, "_parameters"):  # support old PyTorch?
            data = getattr(data, "_parameters")

    # Default to None for an empty mapping so the Tensor check below reports
    # it as a ValueError instead of leaking a bare StopIteration.
    emb = next(iter(data.values()), None)
    # isinstance so torch.nn.Parameter (a Tensor subclass, as found in
    # `_parameters`) is accepted.
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")

    if emb.dim() == 1:
        emb = emb.unsqueeze(0)

    return [emb]
| 192 | |
| 193 | def train(self, args): |
| 194 | if args.output_name is None: |