hub / github.com/showlab/Show-o / generate_images

Function generate_images

show-o2/train_mixed_modality_simple.py:562–657
(
        model,
        vae_model,
        text_tokenizer,
        config,
        global_step,
        device,
        weight_type,
        sampler,
        showo_token_ids,
)

Source from the content-addressed store, hash-verified

@torch.no_grad()
def generate_images(
        model,
        vae_model,
        text_tokenizer,
        config,
        global_step,
        device,
        weight_type,
        sampler,
        showo_token_ids,
):
    logger.info("Generating images...")
    model.eval()

    # read validation prompts from file
    with open(config.dataset.params.validation_prompts_file, "r") as f:
        prompts = f.read().splitlines()[:config.training.batch_size_t2i]

    num_t2i_image_tokens, num_mmu_image_tokens, num_video_tokens, max_seq_len, max_text_len, image_latent_dim, patch_size, latent_width, \
        latent_height, pad_id, bos_id, eos_id, boi_id, eoi_id, bov_id, eov_id, image_pad_id, video_pad_id, guidance_scale \
        = get_hyper_params(config, text_tokenizer, showo_token_ids)

    batch_text_tokens, batch_text_tokens_null, batch_modality_positions, batch_modality_positions_null = \
        prepare_gen_input(
            prompts, text_tokenizer, num_t2i_image_tokens, bos_id, eos_id, boi_id, eoi_id, pad_id, image_pad_id,
            max_text_len, device
        )

    z = torch.randn((len(prompts),
                     image_latent_dim, latent_height * patch_size,
                     latent_width * patch_size)).to(weight_type).to(device)

    if guidance_scale > 0:
        z = torch.cat([z, z], dim=0)
        text_tokens = torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0)
        modality_positions = torch.cat([batch_modality_positions, batch_modality_positions_null], dim=0)
        # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
        # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
        # omni_mask_fn = omni_attn_mask(modality_positions)
        # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
        #                                KV_LEN=max_seq_len, device=device)
        # or use naive omni attention mask, which is more stable
        block_mask = omni_attn_mask_naive(text_tokens.size(0),
                                          max_seq_len,
                                          modality_positions,
                                          device).to(weight_type)
    else:
        text_tokens = batch_text_tokens
        modality_positions = batch_modality_positions
        # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
        # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
        # omni_mask_fn = omni_attn_mask(modality_positions)
        # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
        #                                KV_LEN=max_seq_len, device=device)
        block_mask = omni_attn_mask_naive(text_tokens.size(0),
                                          max_seq_len,
                                          modality_positions,
                                          device).to(weight_type)
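
The input preparation in the listing above (reading the validation prompts, sizing the initial noise, and doubling the batch when guidance_scale > 0) can be sketched without PyTorch. This is a minimal illustration, not the repository's code: load_prompts, latent_shape, and build_cfg_batch are hypothetical helper names, the numeric sizes are made up, an empty string stands in for the null prompt that prepare_gen_input tokenizes, and plain Python lists stand in for the tensors that the original joins with torch.cat(..., dim=0).

```python
import os
import tempfile


def load_prompts(path, batch_size):
    """Read one prompt per line and keep at most `batch_size` of them."""
    with open(path, "r") as f:
        return f.read().splitlines()[:batch_size]


def latent_shape(num_prompts, latent_dim, latent_height, latent_width, patch_size):
    """Shape of the initial noise: (batch, channels, height, width)."""
    return (num_prompts, latent_dim, latent_height * patch_size, latent_width * patch_size)


def build_cfg_batch(cond_items, null_items, guidance_scale):
    """Stack conditional and null entries along the batch axis when guidance is on.

    Lists stand in for tensors here (an assumption made for illustration only).
    """
    if len(cond_items) != len(null_items):
        raise ValueError("conditional and null inputs must have the same length")
    if guidance_scale > 0:
        # First half: prompt-conditioned entries; second half: null-prompt entries.
        return list(cond_items) + list(null_items)
    return list(cond_items)


# Usage with a throwaway prompts file (hypothetical contents).
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("a red bicycle\na snowy mountain\na cup of tea\n")
    prompts_path = tmp.name

prompts = load_prompts(prompts_path, batch_size=2)
os.remove(prompts_path)

# Hypothetical sizes; the real values come from get_hyper_params(config, ...).
shape = latent_shape(len(prompts), latent_dim=16, latent_height=27, latent_width=27, patch_size=2)
batch = build_cfg_batch(prompts, [""] * len(prompts), guidance_scale=5.0)

print(prompts)
print(shape)
print(batch)
```

With guidance enabled the batch axis doubles, which is why the listing also duplicates z with torch.cat([z, z], dim=0): both halves start from the same noise and differ only in their text conditioning.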

Callers

nothing calls this directly

Calls 7

get_hyper_params · Function · 0.90
prepare_gen_input · Function · 0.90
omni_attn_mask_naive · Function · 0.90
denorm · Function · 0.90
to · Method · 0.80
sample_ode · Method · 0.80
batch_decode · Method · 0.80

Tested by

no test coverage detected