(args)
| 1610 | |
| 1611 | |
| 1612 | def main(args): |
| 1613 | if args.fp16: |
| 1614 | dtype = torch.float16 |
| 1615 | elif args.bf16: |
| 1616 | dtype = torch.bfloat16 |
| 1617 | else: |
| 1618 | dtype = torch.float32 |
| 1619 | |
| 1620 | highres_fix = args.highres_fix_scale is not None |
| 1621 | # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" |
| 1622 | |
| 1623 | if args.v2 and args.clip_skip is not None: |
| 1624 | logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") |
| 1625 | |
| 1626 | # モデルを読み込む |
| 1627 | if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う |
| 1628 | files = glob.glob(args.ckpt) |
| 1629 | if len(files) == 1: |
| 1630 | args.ckpt = files[0] |
| 1631 | |
| 1632 | name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt |
| 1633 | use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers |
| 1634 | |
| 1635 | # SDXLかどうかを判定する |
| 1636 | is_sdxl = args.sdxl |
| 1637 | if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する |
| 1638 | if use_stable_diffusion_format: |
| 1639 | # if file size > 5.5GB, sdxl |
| 1640 | is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3 |
| 1641 | else: |
| 1642 | # if `text_encoder_2` subdirectory exists, sdxl |
| 1643 | is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) |
| 1644 | logger.info(f"SDXL: {is_sdxl}") |
| 1645 | |
| 1646 | if is_sdxl: |
| 1647 | if args.clip_skip is None: |
| 1648 | args.clip_skip = 2 |
| 1649 | |
| 1650 | (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( |
| 1651 | args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype |
| 1652 | ) |
| 1653 | unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) |
| 1654 | text_encoders = [text_encoder1, text_encoder2] |
| 1655 | else: |
| 1656 | if args.clip_skip is None: |
| 1657 | args.clip_skip = 2 if args.v2 else 1 |
| 1658 | |
| 1659 | if use_stable_diffusion_format: |
| 1660 | logger.info("load StableDiffusion checkpoint") |
| 1661 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) |
| 1662 | else: |
| 1663 | logger.info("load Diffusers pretrained models") |
| 1664 | loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) |
| 1665 | text_encoder = loading_pipe.text_encoder |
| 1666 | vae = loading_pipe.vae |
| 1667 | unet = loading_pipe.unet |
| 1668 | tokenizer = loading_pipe.tokenizer |
| 1669 | del loading_pipe |
no test coverage detected