(self, args, weight_dtype, accelerator)
| 170 | val_dataset_group.verify_bucket_reso_steps(64) |
| 171 | |
def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]:
    """Load the text encoder, VAE and U-Net, then configure their attention implementation.

    The models come from ``model_io.load_target_model``, whose fourth return value is
    discarded. The U-Net's attention modules are then replaced according to
    ``args.mem_eff_attn`` / ``args.xformers`` / ``args.sdpa``. On PyTorch >= 2.0.0 the
    VAE's xformers attention is also toggled to match ``args.xformers``.

    Returns:
        ``(version_str, text_encoder, vae, unet)``. ``version_str`` comes from
        ``model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)``.
    """
    # Strict 4-way unpacking: only the first three models are kept.
    te_model, vae_model, unet_model, _unused = model_io.load_target_model(args, weight_dtype, accelerator)

    # Build xformers / memory-efficient / SDPA attention into the U-Net.
    model_io.replace_unet_modules(unet_model, args.mem_eff_attn, args.xformers, args.sdpa)

    # xformers builds that support PyTorch 2.0.0+ can also be used for the VAE.
    torch_is_2_or_newer = torch.__version__ >= "2.0.0"
    if torch_is_2_or_newer:
        vae_model.set_use_memory_efficient_attention_xformers(args.xformers)

    version_str = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
    return version_str, te_model, vae_model, unet_model
| 181 | |
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]:
    """Lazy U-Net loading hook; not implemented by this trainer.

    Raises:
        NotImplementedError: always, regardless of the arguments passed.
    """
    raise NotImplementedError
no outgoing calls