(
model,
vae_model,
text_tokenizer,
config,
global_step,
device,
weight_type,
sampler,
showo_token_ids,
)
| 560 | |
| 561 | @torch.no_grad() |
| 562 | def generate_images( |
| 563 | model, |
| 564 | vae_model, |
| 565 | text_tokenizer, |
| 566 | config, |
| 567 | global_step, |
| 568 | device, |
| 569 | weight_type, |
| 570 | sampler, |
| 571 | showo_token_ids, |
| 572 | ): |
| 573 | logger.info("Generating images...") |
| 574 | model.eval() |
| 575 | |
| 576 | # read validation prompts from file |
| 577 | with open(config.dataset.params.validation_prompts_file, "r") as f: |
| 578 | prompts = f.read().splitlines()[:config.training.batch_size_t2i] |
| 579 | |
| 580 | num_t2i_image_tokens, num_mmu_image_tokens, num_video_tokens, max_seq_len, max_text_len, image_latent_dim, patch_size, latent_width, \ |
| 581 | latent_height, pad_id, bos_id, eos_id, boi_id, eoi_id, bov_id, eov_id, image_pad_id, video_pad_id, guidance_scale \ |
| 582 | = get_hyper_params(config, text_tokenizer, showo_token_ids) |
| 583 | |
| 584 | batch_text_tokens, batch_text_tokens_null, batch_modality_positions, batch_modality_positions_null = \ |
| 585 | prepare_gen_input( |
| 586 | prompts, text_tokenizer, num_t2i_image_tokens, bos_id, eos_id, boi_id, eoi_id, pad_id, image_pad_id, |
| 587 | max_text_len, device |
| 588 | ) |
| 589 | |
| 590 | z = torch.randn((len(prompts), |
| 591 | image_latent_dim, latent_height * patch_size, |
| 592 | latent_width * patch_size)).to(weight_type).to(device) |
| 593 | |
| 594 | if guidance_scale > 0: |
| 595 | z = torch.cat([z, z], dim=0) |
| 596 | text_tokens = torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0) |
| 597 | modality_positions = torch.cat([batch_modality_positions, batch_modality_positions_null], dim=0) |
| 598 | # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch |
| 599 | # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes) |
| 600 | # omni_mask_fn = omni_attn_mask(modality_positions) |
| 601 | # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len, |
| 602 | # KV_LEN=max_seq_len, device=device) |
| 603 | # or use naive omni attention mask, which is more stable |
| 604 | block_mask = omni_attn_mask_naive(text_tokens.size(0), |
| 605 | max_seq_len, |
| 606 | modality_positions, |
| 607 | device).to(weight_type) |
| 608 | else: |
| 609 | text_tokens = batch_text_tokens |
| 610 | modality_positions = batch_modality_positions |
| 611 | # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch |
| 612 | # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes) |
| 613 | # omni_mask_fn = omni_attn_mask(modality_positions) |
| 614 | # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len, |
| 615 | # KV_LEN=max_seq_len, device=device) |
| 616 | block_mask = omni_attn_mask_naive(text_tokens.size(0), |
| 617 | max_seq_len, |
| 618 | modality_positions, |
| 619 | device).to(weight_type) |
Nothing calls this function directly.
No test coverage detected.