MCPcopy
hub / github.com/kohya-ss/sd-scripts / load_weights

Method load_weights

train_textual_inversion.py:168–191  ·  view source on GitHub ↗
(self, file)

Source from the content-addressed store, hash-verified

166 torch.save(state_dict, file) # can be loaded in Web UI
167
def load_weights(self, file):
    """Load a single textual-inversion embedding from a weight file.

    Two on-disk formats are accepted:

    * ``.safetensors`` -- read with ``safetensors.torch.load_file``.
    * anything else -- read with ``torch.load``; the file must contain a
      dict.  If that dict has a ``"string_to_param"`` entry (the Web UI
      textual-inversion layout) the entry is used instead, and if the
      entry exposes ``_parameters`` (a module-style container) that
      mapping is used.

    Only the first value of the resulting mapping is used.

    Args:
        file: Path of the weight file to read.

    Returns:
        A one-element list holding the embedding tensor.  A 1-D embedding
        is expanded to 2-D by adding a leading dimension of size 1.

    Raises:
        ValueError: If the file is not a dict, holds no entries, or its
            first entry is not a tensor.
    """
    if os.path.splitext(file)[1] == ".safetensors":
        # Imported lazily so safetensors is only needed for this format.
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible to Web UI's file format
        # NOTE(security): torch.load unpickles the file, so only load
        # embeddings from sources you trust.
        data = torch.load(file, map_location="cpu")
        # isinstance (not an exact type match) so dict subclasses such as
        # OrderedDict are accepted as well.
        if not isinstance(data, dict):
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
            if hasattr(data, "_parameters"):  # support old PyTorch?
                data = getattr(data, "_parameters")

    # Report an empty file with the same exception type as every other
    # malformed-file case instead of leaking a bare StopIteration.
    try:
        emb = next(iter(data.values()))
    except StopIteration:
        raise ValueError(f"weight file is empty / 重みファイルが空です: {file}") from None

    # torch.nn.Parameter is a Tensor subclass; an exact type comparison
    # would reject the values produced by the "_parameters" path above.
    if not isinstance(emb, torch.Tensor):
        raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")

    if emb.dim() == 1:
        emb = emb.unsqueeze(0)

    return [emb]
192
193 def train(self, args):
194 if args.output_name is None:

Callers 1

trainMethod · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected