(batch: List[BatchData], highres_fix, highres_1st=False)
| 3041 | |
| 3042 | # バッチ処理の関数 |
| 3043 | def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): |
| 3044 | batch_size = len(batch) |
| 3045 | |
| 3046 | # highres_fixの処理 |
| 3047 | if highres_fix and not highres_1st: |
| 3048 | # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す |
| 3049 | is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling |
| 3050 | |
| 3051 | logger.info("process 1st stage") |
| 3052 | batch_1st = [] |
| 3053 | for _, base, ext in batch: |
| 3054 | width_1st = int(ext.width * args.highres_fix_scale + 0.5) |
| 3055 | height_1st = int(ext.height * args.highres_fix_scale + 0.5) |
| 3056 | width_1st = width_1st - width_1st % 32 |
| 3057 | height_1st = height_1st - height_1st % 32 |
| 3058 | |
| 3059 | strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength |
| 3060 | |
| 3061 | ext_1st = BatchDataExt( |
| 3062 | width_1st, |
| 3063 | height_1st, |
| 3064 | args.highres_fix_steps, |
| 3065 | ext.scale, |
| 3066 | ext.negative_scale, |
| 3067 | strength_1st, |
| 3068 | ext.network_muls, |
| 3069 | ext.num_sub_prompts, |
| 3070 | ) |
| 3071 | batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) |
| 3072 | |
| 3073 | pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする |
| 3074 | images_1st = process_batch(batch_1st, True, True) |
| 3075 | |
| 3076 | # 2nd stageのバッチを作成して以下処理する |
| 3077 | logger.info("process 2nd stage") |
| 3078 | width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height |
| 3079 | |
| 3080 | if upscaler: |
| 3081 | # upscalerを使って画像を拡大する |
| 3082 | lowreso_imgs = None if is_1st_latent else images_1st |
| 3083 | lowreso_latents = None if not is_1st_latent else images_1st |
| 3084 | |
| 3085 | # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents |
| 3086 | batch_size = len(images_1st) |
| 3087 | vae_batch_size = ( |
| 3088 | batch_size |
| 3089 | if args.vae_batch_size is None |
| 3090 | else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) |
| 3091 | ) |
| 3092 | vae_batch_size = int(vae_batch_size) |
| 3093 | images_1st = upscaler.upscale( |
| 3094 | vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size |
| 3095 | ) |
| 3096 | |
| 3097 | elif args.highres_fix_latents_upscaling: |
| 3098 | # latentを拡大する |
| 3099 | org_dtype = images_1st.dtype |
| 3100 | if images_1st.dtype == torch.bfloat16: |
no test coverage detected