MCPcopy
hub / github.com/kohya-ss/sd-scripts / main

Function main

gen_img.py:1612–3115  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

1610
1611
1612def main(args):
1613 if args.fp16:
1614 dtype = torch.float16
1615 elif args.bf16:
1616 dtype = torch.bfloat16
1617 else:
1618 dtype = torch.float32
1619
1620 highres_fix = args.highres_fix_scale is not None
1621 # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
1622
1623 if args.v2 and args.clip_skip is not None:
1624 logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
1625
1626 # モデルを読み込む
1627 if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
1628 files = glob.glob(args.ckpt)
1629 if len(files) == 1:
1630 args.ckpt = files[0]
1631
1632 name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt
1633 use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
1634
1635 # SDXLかどうかを判定する
1636 is_sdxl = args.sdxl
1637 if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する
1638 if use_stable_diffusion_format:
1639 # if file size > 5.5GB, sdxl
1640 is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3
1641 else:
1642 # if `text_encoder_2` subdirectory exists, sdxl
1643 is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
1644 logger.info(f"SDXL: {is_sdxl}")
1645
1646 if is_sdxl:
1647 if args.clip_skip is None:
1648 args.clip_skip = 2
1649
1650 (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
1651 args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
1652 )
1653 unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
1654 text_encoders = [text_encoder1, text_encoder2]
1655 else:
1656 if args.clip_skip is None:
1657 args.clip_skip = 2 if args.v2 else 1
1658
1659 if use_stable_diffusion_format:
1660 logger.info("load StableDiffusion checkpoint")
1661 text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
1662 else:
1663 logger.info("load Diffusers pretrained models")
1664 loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
1665 text_encoder = loading_pipe.text_encoder
1666 vae = loading_pipe.vae
1667 unet = loading_pipe.unet
1668 tokenizer = loading_pipe.tokenizer
1669 del loading_pipe

Callers 1

gen_img.py · File · 0.70

Calls 15

load_state_dict · Method · 0.95
apply_to · Method · 0.95
set_batch_cond_only · Method · 0.95
set_control_nets · Method · 0.95
set_deep_shrink · Method · 0.95
set_gradual_latent · Method · 0.95
add_token_replacement · Method · 0.95
shuffle · Method · 0.95

Tested by

no test coverage detected