(args)
| 2369 | |
| 2370 | |
| 2371 | def main(args): |
| 2372 | if args.fp16: |
| 2373 | dtype = torch.float16 |
| 2374 | elif args.bf16: |
| 2375 | dtype = torch.bfloat16 |
| 2376 | else: |
| 2377 | dtype = torch.float32 |
| 2378 | |
| 2379 | highres_fix = args.highres_fix_scale is not None |
| 2380 | # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" |
| 2381 | |
| 2382 | if args.v2 and args.clip_skip is not None: |
| 2383 | logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") |
| 2384 | |
| 2385 | # モデルを読み込む |
| 2386 | if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う |
| 2387 | files = glob.glob(args.ckpt) |
| 2388 | if len(files) == 1: |
| 2389 | args.ckpt = files[0] |
| 2390 | |
| 2391 | use_stable_diffusion_format = os.path.isfile(args.ckpt) |
| 2392 | if use_stable_diffusion_format: |
| 2393 | logger.info("load StableDiffusion checkpoint") |
| 2394 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) |
| 2395 | else: |
| 2396 | logger.info("load Diffusers pretrained models") |
| 2397 | loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) |
| 2398 | text_encoder = loading_pipe.text_encoder |
| 2399 | vae = loading_pipe.vae |
| 2400 | unet = loading_pipe.unet |
| 2401 | tokenizer = loading_pipe.tokenizer |
| 2402 | del loading_pipe |
| 2403 | |
| 2404 | # Diffusers U-Net to original U-Net |
| 2405 | original_unet = UNet2DConditionModel( |
| 2406 | unet.config.sample_size, |
| 2407 | unet.config.attention_head_dim, |
| 2408 | unet.config.cross_attention_dim, |
| 2409 | unet.config.use_linear_projection, |
| 2410 | unet.config.upcast_attention, |
| 2411 | ) |
| 2412 | original_unet.load_state_dict(unet.state_dict()) |
| 2413 | unet = original_unet |
| 2414 | unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) |
| 2415 | |
| 2416 | # VAEを読み込む |
| 2417 | if args.vae is not None: |
| 2418 | vae = model_util.load_vae(args.vae, dtype) |
| 2419 | logger.info("additional VAE loaded") |
| 2420 | |
| 2421 | # # 置換するCLIPを読み込む |
| 2422 | # if args.replace_clip_l14_336: |
| 2423 | # text_encoder = load_clip_l14_336(dtype) |
| 2424 | # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") |
| 2425 | |
| 2426 | if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: |
| 2427 | logger.info("prepare clip model") |
| 2428 | clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) |
no test coverage detected