MCPcopy
hub / github.com/kohya-ss/sd-scripts / main

Function main

gen_img_diffusers.py:2371–3649  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

2369
2370
2371def main(args):
2372 if args.fp16:
2373 dtype = torch.float16
2374 elif args.bf16:
2375 dtype = torch.bfloat16
2376 else:
2377 dtype = torch.float32
2378
2379 highres_fix = args.highres_fix_scale is not None
2380 # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
2381
2382 if args.v2 and args.clip_skip is not None:
2383 logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
2384
2385 # モデルを読み込む
2386 if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
2387 files = glob.glob(args.ckpt)
2388 if len(files) == 1:
2389 args.ckpt = files[0]
2390
2391 use_stable_diffusion_format = os.path.isfile(args.ckpt)
2392 if use_stable_diffusion_format:
2393 logger.info("load StableDiffusion checkpoint")
2394 text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
2395 else:
2396 logger.info("load Diffusers pretrained models")
2397 loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
2398 text_encoder = loading_pipe.text_encoder
2399 vae = loading_pipe.vae
2400 unet = loading_pipe.unet
2401 tokenizer = loading_pipe.tokenizer
2402 del loading_pipe
2403
2404 # Diffusers U-Net to original U-Net
2405 original_unet = UNet2DConditionModel(
2406 unet.config.sample_size,
2407 unet.config.attention_head_dim,
2408 unet.config.cross_attention_dim,
2409 unet.config.use_linear_projection,
2410 unet.config.upcast_attention,
2411 )
2412 original_unet.load_state_dict(unet.state_dict())
2413 unet = original_unet
2414 unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
2415
2416 # VAEを読み込む
2417 if args.vae is not None:
2418 vae = model_util.load_vae(args.vae, dtype)
2419 logger.info("additional VAE loaded")
2420
2421 # # 置換するCLIPを読み込む
2422 # if args.replace_clip_l14_336:
2423 # text_encoder = load_clip_l14_336(dtype)
2424 # logger.info(f"large clip {CLIP_ID_L14_336} is loaded")
2425
2426 if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale:
2427 logger.info("prepare clip model")
2428 clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype)

Callers 1

Calls 15

set_control_netsMethod · 0.95
set_deep_shrinkMethod · 0.95
set_gradual_latentMethod · 0.95
add_token_replacementMethod · 0.95
get_preferred_deviceFunction · 0.90
ControlNetInfoClass · 0.90
GradualLatentClass · 0.90

Tested by

no test coverage detected