MCP · copy
hub / github.com/kohya-ss/sd-scripts / process_batch

Function process_batch

gen_img_diffusers.py:3043–3307  ·  view source on GitHub ↗
(batch: List[BatchData], highres_fix, highres_1st=False)

Source from the content-addressed store, hash-verified

3041
3042 # バッチ処理の関数
3043 def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
3044 batch_size = len(batch)
3045
3046 # highres_fixの処理
3047 if highres_fix and not highres_1st:
3048 # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
3049 is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
3050
3051 logger.info("process 1st stage")
3052 batch_1st = []
3053 for _, base, ext in batch:
3054 width_1st = int(ext.width * args.highres_fix_scale + 0.5)
3055 height_1st = int(ext.height * args.highres_fix_scale + 0.5)
3056 width_1st = width_1st - width_1st % 32
3057 height_1st = height_1st - height_1st % 32
3058
3059 strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength
3060
3061 ext_1st = BatchDataExt(
3062 width_1st,
3063 height_1st,
3064 args.highres_fix_steps,
3065 ext.scale,
3066 ext.negative_scale,
3067 strength_1st,
3068 ext.network_muls,
3069 ext.num_sub_prompts,
3070 )
3071 batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
3072
3073 pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする
3074 images_1st = process_batch(batch_1st, True, True)
3075
3076 # 2nd stageのバッチを作成して以下処理する
3077 logger.info("process 2nd stage")
3078 width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
3079
3080 if upscaler:
3081 # upscalerを使って画像を拡大する
3082 lowreso_imgs = None if is_1st_latent else images_1st
3083 lowreso_latents = None if not is_1st_latent else images_1st
3084
3085 # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
3086 batch_size = len(images_1st)
3087 vae_batch_size = (
3088 batch_size
3089 if args.vae_batch_size is None
3090 else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
3091 )
3092 vae_batch_size = int(vae_batch_size)
3093 images_1st = upscaler.upscale(
3094 vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
3095 )
3096
3097 elif args.highres_fix_latents_upscaling:
3098 # latentを拡大する
3099 org_dtype = images_1st.dtype
3100 if images_1st.dtype == torch.bfloat16:

Callers 1

main (Function) · 0.70

Calls 14

support_latents (Method) · 0.80
upscale (Method) · 0.80
to (Method) · 0.80
interpolate (Method) · 0.80
BatchDataExt (Class) · 0.70
BatchData (Class) · 0.70
BatchDataBase (Class) · 0.70
randn (Method) · 0.45
reset_sampler_noises (Method) · 0.45
set_multiplier (Method) · 0.45

Tested by

no test coverage detected