(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs)
| 125 | |
@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
    """Sample ``max_new_tokens`` tokens from ``model`` conditioned on ``cond``.

    The conditioning positions are processed once by ``prefill``, which yields
    the first new token; the remaining tokens come from ``decode_n_tokens``.
    KV caches are (re)allocated via ``model.setup_caches`` on every call.

    Args:
        model: Model exposing ``model_type`` (``'c2i'`` or ``'t2i'``),
            ``setup_caches``, ``causal_mask`` and ``tok_embeddings``; ``'c2i'``
            models also need ``num_classes``, ``'t2i'`` models need
            ``cls_embedding.uncond_embedding``.
        cond: Conditioning tensor whose first dimension is the batch size.
            For ``'t2i'``, ``cond.shape[1]`` is the number of conditioning
            positions; for ``'c2i'`` a single position is assumed
            (presumably class labels — TODO confirm shape with callers).
        max_new_tokens: Number of tokens to generate; must be >= 1.
        emb_masks: Optional mask over the conditioning positions. Its first
            dimension must equal the batch size and its last dimension the
            number of conditioning positions.
        cfg_scale: Classifier-free guidance scale. When > 1.0 the batch is
            doubled with an unconditional copy of ``cond``.
        cfg_interval: Forwarded unchanged to ``decode_n_tokens``.
        **sampling_kwargs: Forwarded to ``prefill`` and ``decode_n_tokens``.

    Returns:
        A ``torch.int`` tensor of shape ``(batch, max_new_tokens)`` holding
        only the generated tokens.

    Raises:
        ValueError: If ``model.model_type`` is unsupported or
            ``max_new_tokens < 1``.
        AssertionError: If ``emb_masks`` does not match the batch size or the
            number of conditioning positions.
    """
    model_type = model.model_type
    use_cfg = cfg_scale > 1.0

    if model_type == 'c2i':
        if use_cfg:
            # Unconditional half: every label replaced by ``model.num_classes``.
            cond_null = torch.ones_like(cond) * model.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        T = 1
    elif model_type == 't2i':
        if use_cfg:
            # Unconditional half: every position replaced by the model's
            # learned unconditional embedding.
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        T = cond.shape[1]
    else:
        raise ValueError(
            f"please check model type: unsupported model_type {model_type!r} "
            "(expected 'c2i' or 't2i')"
        )

    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")

    T_new = T + max_new_tokens
    max_batch_size = cond.shape[0]
    device = cond.device

    with torch.device(device):
        # With CFG the conditional and unconditional halves share one batch.
        max_batch_size_cfg = max_batch_size * 2 if use_cfg else max_batch_size
        model.setup_caches(
            max_batch_size=max_batch_size_cfg,
            max_seq_length=T_new,
            dtype=model.tok_embeddings.weight.dtype,
        )

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size, (
            f"emb_masks batch size {emb_masks.shape[0]} != cond batch size {max_batch_size}"
        )
        assert emb_masks.shape[-1] == T, (
            f"emb_masks length {emb_masks.shape[-1]} != number of condition positions {T}"
        )
        # The unconditional half reuses the same mask as the conditional half.
        masks = torch.cat([emb_masks, emb_masks]) if use_cfg else emb_masks
        # Zero out attention to masked condition positions (columns :T).
        model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * masks.unsqueeze(1)

        # Restore the diagonal to 1 so no row of the mask is entirely zero
        # (presumably to avoid a fully-masked attention row for padded positions).
        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # Only the generated positions are returned, so only they are stored.
    seq = torch.empty((max_batch_size, max_new_tokens), dtype=torch.int, device=device)

    # Prefill: run all T condition positions at once and sample the first token.
    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
    seq[:, :1] = next_token

    # Nothing is left to decode when a single token was requested; skipping
    # also avoids torch.cat on an empty token list, which raises.
    if max_new_tokens > 1:
        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        generated_tokens, _ = decode_n_tokens(
            model, next_token, input_pos, max_new_tokens - 1, cfg_scale, cfg_interval, **sampling_kwargs
        )
        seq[:, 1:] = torch.cat(generated_tokens, dim=1)

    return seq
no test coverage detected