MCPcopy
hub / github.com/kohya-ss/sd-scripts / sample_image_inference

Function sample_image_inference

library/lumina_train_util.py:250–478  ·  view source on GitHub ↗

Generates sample images. Args: accelerator (Accelerator): Accelerator object; args (argparse.Namespace): Arguments object; nextdit (lumina_models.NextDiT): NextDiT model; gemma2_model (list[Gemma2Model]): Gemma2 model; vae (AutoEncoder): VAE model.

(
    accelerator: Accelerator,
    args: argparse.Namespace,
    nextdit: lumina_models.NextDiT,
    gemma2_model: list[Gemma2Model],
    vae: AutoEncoder,
    save_dir: str,
    prompt_dicts: list[Dict[str, str]],
    epoch: int,
    global_step: int,
    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
    prompt_replacement: Optional[Tuple[str, str]] = None,
    controlnet=None,
)

Source from the content-addressed store, hash-verified

248
249@torch.no_grad()
250def sample_image_inference(
251 accelerator: Accelerator,
252 args: argparse.Namespace,
253 nextdit: lumina_models.NextDiT,
254 gemma2_model: list[Gemma2Model],
255 vae: AutoEncoder,
256 save_dir: str,
257 prompt_dicts: list[Dict[str, str]],
258 epoch: int,
259 global_step: int,
260 sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
261 prompt_replacement: Optional[Tuple[str, str]] = None,
262 controlnet=None,
263):
264 """
265 Generates sample images
266
267 Args:
268 accelerator (Accelerator): Accelerator object
269 args (argparse.Namespace): Arguments object
270 nextdit (lumina_models.NextDiT): NextDiT model
271 gemma2_model (list[Gemma2Model]): Gemma2 model
272 vae (AutoEncoder): VAE model
273 save_dir (str): Directory to save images
274 prompt_dict (Dict[str, str]): Prompt dictionary
275 epoch (int): Epoch number
276 steps (int): Number of steps to run
277 sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing Gemma 2 outputs
278 prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
279
280 Returns:
281 None
282 """
283
284 # encode prompts
285 tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
286 encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
287
288 assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
289 assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
290
291 text_conds = []
292
293 # assuming seed, width, height, sample steps, guidance are the same
294 width = int(prompt_dicts[0].get("width", 1024))
295 height = int(prompt_dicts[0].get("height", 1024))
296 height = max(64, height - height % 8) # round to divisible by 8
297 width = max(64, width - width % 8) # round to divisible by 8
298
299 guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
300 cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25))
301 renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0))
302 sample_steps = int(prompt_dicts[0].get("sample_steps", 36))
303 seed = prompt_dicts[0].get("seed", None)
304 seed = int(seed) if seed is not None else None
305 assert seed is None or seed > 0, f"Invalid seed {seed}"
306 generator = torch.Generator(device=accelerator.device)
307 if seed is not None:

Callers 1

sample_images · Function · 0.70

Calls 11

clean_memory_on_device · Function · 0.90
retrieve_timesteps · Function · 0.85
get · Method · 0.80
to · Method · 0.80
denoise · Function · 0.70
get_strategy · Method · 0.45
tokenize · Method · 0.45
encode_tokens · Method · 0.45
randn · Method · 0.45
decode · Method · 0.45

Tested by

no test coverage detected