github.com/kohya-ss/sd-scripts

Method load_target_model

train_network.py, lines 172–180
(self, args, weight_dtype, accelerator)
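The method receives the parsed command-line namespace as `args`. From its body, it reads five attributes: `args.mem_eff_attn`, `args.xformers`, `args.sdpa`, `args.v2` and `args.v_parameterization`.

The sketch below builds such a namespace with `argparse`. It is a minimal stand-in that assumes each option is a boolean `store_true` switch. It is not sd-scripts' actual parser, which defines many more options. `build_args` is a hypothetical helper name.

```python
import argparse


def build_args(argv=None):
    # Hypothetical minimal parser covering only the attributes load_target_model reads.
    # Assumption: every option is a plain boolean switch (store_true, default False).
    parser = argparse.ArgumentParser(description="stand-in for the training script parser")
    parser.add_argument("--v2", action="store_true", help="load an SD 2.x checkpoint")
    parser.add_argument("--v_parameterization", action="store_true", help="model uses v-prediction")
    parser.add_argument("--mem_eff_attn", action="store_true", help="memory-efficient attention")
    parser.add_argument("--xformers", action="store_true", help="xformers attention")
    parser.add_argument("--sdpa", action="store_true", help="PyTorch scaled_dot_product_attention")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = build_args(["--v2", "--sdpa"])
    print(args.v2, args.v_parameterization, args.mem_eff_attn, args.xformers, args.sdpa)
    # prints: True False False False True
```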


        val_dataset_group.verify_bucket_reso_steps(64)

    def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]:
        text_encoder, vae, unet, _ = model_io.load_target_model(args, weight_dtype, accelerator)

        # Incorporate xformers or memory-efficient attention into the model
        model_io.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
        if torch.__version__ >= "2.0.0":  # The following is usable with an xformers build that supports PyTorch 2.0.0 or later
            vae.set_use_memory_efficient_attention_xformers(args.xformers)

        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet

    def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]:
        raise NotImplementedError()
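The method is a thin wrapper that does four things:

1. Load the text encoder, VAE and U-Net.
2. Patch the U-Net's attention according to the flags.
3. Enable xformers on the VAE, but only when the PyTorch version check passes.
4. Return a model-version string followed by the three models.

The sketch below replays that flow with plain-Python stand-ins, so it runs without sd-scripts or PyTorch. Everything in it is hypothetical:

- `StubModel`, `select_attention`, `version_at_least`, `model_version_str` and `load_target_model_sketch` are illustrative names.
- The precedence `mem_eff_attn`, then `xformers`, then `sdpa` is an assumption about how `replace_unet_modules` resolves several enabled flags.
- The version labels are placeholders, not the strings `get_model_version_str_for_sd1_sd2` actually returns.

```python
import re
from argparse import Namespace
from dataclasses import dataclass


@dataclass
class StubModel:
    # Hypothetical stand-in for the U-Net, VAE or text encoder; records configuration only.
    name: str
    attention: str = "default"
    xformers_enabled: bool = False


def version_at_least(version: str, minimum: tuple) -> bool:
    # Numeric comparison, e.g. "2.1.0+cu118" -> (2, 1, 0). A plain str comparison is
    # lexicographic, so "10.0.0" >= "2.0.0" would be False.
    nums = [int(p) for p in re.findall(r"\d+", version.split("+")[0])[:3]]
    nums += [0] * (3 - len(nums))  # pad short versions such as "2.1"
    return tuple(nums) >= minimum


def select_attention(mem_eff_attn: bool, xformers: bool, sdpa: bool) -> str:
    # Assumed precedence: the first enabled flag wins.
    if mem_eff_attn:
        return "mem_eff_attn"
    if xformers:
        return "xformers"
    if sdpa:
        return "sdpa"
    return "default"


def model_version_str(v2: bool, v_parameterization: bool) -> str:
    # Placeholder labels, not the library's actual version strings.
    base = "sd2" if v2 else "sd1"
    return base + ("-v" if v_parameterization else "")


def load_target_model_sketch(args, torch_version: str):
    # 1. Load models (stubbed; the real method calls model_io.load_target_model).
    text_encoder, vae, unet = StubModel("text_encoder"), StubModel("vae"), StubModel("unet")
    # 2. Patch U-Net attention according to the flags.
    unet.attention = select_attention(args.mem_eff_attn, args.xformers, args.sdpa)
    # 3. Toggle xformers on the VAE only when the version gate passes.
    if version_at_least(torch_version, (2, 0, 0)):
        vae.xformers_enabled = args.xformers
    # 4. Return the version label followed by the models, as the real method does.
    return model_version_str(args.v2, args.v_parameterization), text_encoder, vae, unet


if __name__ == "__main__":
    args = Namespace(v2=True, v_parameterization=True, mem_eff_attn=False, xformers=True, sdpa=True)
    version, te, vae, unet = load_target_model_sketch(args, "2.1.0+cu118")
    print(version, unet.attention, vae.xformers_enabled)  # prints: sd2-v xformers True
```

As in the listing, the VAE toggle follows `args.xformers` directly, even when another flag took precedence for the U-Net.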

Callers (12)

train (method) · 0.95
train (function) · 0.45
main (function) · 0.45
main (function) · 0.45
train (function) · 0.45
train (function) · 0.45
train (function) · 0.45
train (function) · 0.45
train (function) · 0.45
cache_to_disk (function) · 0.45
cache_to_disk (function) · 0.45
test_load_target_model (function) · 0.45

Calls

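No outgoing calls are resolved in the index, although the method body calls model_io.load_target_model, model_io.replace_unet_modules, vae.set_use_memory_efficient_attention_xformers and model_util.get_model_version_str_for_sd1_sd2.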

Tested by (1)

test_load_target_model (function) · 0.36