Generates sample images. Args: accelerator (Accelerator): Accelerator object; args (argparse.Namespace): Arguments object; nextdit (lumina_models.NextDiT): NextDiT model; gemma2_model (list[Gemma2Model]): Gemma2 model; vae (AutoEncoder): VAE model.
(
accelerator: Accelerator,
args: argparse.Namespace,
nextdit: lumina_models.NextDiT,
gemma2_model: list[Gemma2Model],
vae: AutoEncoder,
save_dir: str,
prompt_dicts: list[Dict[str, str]],
epoch: int,
global_step: int,
sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None,
)
| 248 | |
| 249 | @torch.no_grad() |
| 250 | def sample_image_inference( |
| 251 | accelerator: Accelerator, |
| 252 | args: argparse.Namespace, |
| 253 | nextdit: lumina_models.NextDiT, |
| 254 | gemma2_model: list[Gemma2Model], |
| 255 | vae: AutoEncoder, |
| 256 | save_dir: str, |
| 257 | prompt_dicts: list[Dict[str, str]], |
| 258 | epoch: int, |
| 259 | global_step: int, |
| 260 | sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], |
| 261 | prompt_replacement: Optional[Tuple[str, str]] = None, |
| 262 | controlnet=None, |
| 263 | ): |
| 264 | """ |
| 265 | Generates sample images |
| 266 | |
| 267 | Args: |
| 268 | accelerator (Accelerator): Accelerator object |
| 269 | args (argparse.Namespace): Arguments object |
| 270 | nextdit (lumina_models.NextDiT): NextDiT model |
| 271 | gemma2_model (list[Gemma2Model]): Gemma2 model |
| 272 | vae (AutoEncoder): VAE model |
| 273 | save_dir (str): Directory to save images |
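| 274 | prompt_dicts (list[Dict[str, str]]): Prompt dictionaries; shared settings (width, height, scale, cfg_trunc_ratio, renorm_cfg, sample_steps, seed) are read from the first entry |
| 275 | epoch (int): Epoch number |
| 276 | global_step (int): Global training step |
| 277 | sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): Precomputed Gemma 2 outputs (three tensors per entry), keyed by string |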
| 278 | prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None. |
| 279 | |
| 280 | Returns: |
| 281 | None |
| 282 | """ |
| 283 | |
| 284 | # encode prompts |
| 285 | tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() |
| 286 | encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() |
| 287 | |
| 288 | assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) |
| 289 | assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) |
| 290 | |
| 291 | text_conds = [] |
| 292 | |
| 293 | # assuming seed, width, height, sample steps, guidance are the same |
| 294 | width = int(prompt_dicts[0].get("width", 1024)) |
| 295 | height = int(prompt_dicts[0].get("height", 1024)) |
| 296 | height = max(64, height - height % 8) # round to divisible by 8 |
| 297 | width = max(64, width - width % 8) # round to divisible by 8 |
| 298 | |
| 299 | guidance_scale = float(prompt_dicts[0].get("scale", 3.5)) |
| 300 | cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25)) |
| 301 | renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0)) |
| 302 | sample_steps = int(prompt_dicts[0].get("sample_steps", 36)) |
| 303 | seed = prompt_dicts[0].get("seed", None) |
| 304 | seed = int(seed) if seed is not None else None |
| 305 | assert seed is None or seed > 0, f"Invalid seed {seed}" |
| 306 | generator = torch.Generator(device=accelerator.device) |
| 307 | if seed is not None: |
no test coverage detected