MCPcopy
hub / github.com/FoundationVision/LlamaGen / generate

Function generate

autoregressive/models/generate.py:127–176  ·  view source on GitHub ↗
(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs)

Source from the content-addressed store, hash-verified

125
@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
    """Autoregressively sample ``max_new_tokens`` token ids conditioned on ``cond``.

    The conditioning prefix is run through ``prefill`` to obtain the first
    token; the remaining ``max_new_tokens - 1`` tokens come from
    ``decode_n_tokens``. When ``cfg_scale > 1.0`` a "null" condition is
    concatenated after ``cond`` along dim 0, so the model sees a batch twice
    as large as ``cond``.

    Args:
        model: Model exposing ``model_type`` (``'c2i'`` or ``'t2i'``),
            ``setup_caches``, ``causal_mask`` and ``tok_embeddings``; plus
            ``num_classes`` (c2i) or ``cls_embedding.uncond_embedding`` (t2i)
            when ``cfg_scale > 1.0``.
        cond: Conditioning tensor. For ``'c2i'`` the prefix length is fixed
            at 1; for ``'t2i'`` it is ``cond.shape[1]``. ``cond.shape[0]`` is
            used as the batch size.
        max_new_tokens: Number of tokens to sample; must be at least 1.
        emb_masks: Optional mask with ``shape[0] == cond.shape[0]`` and
            ``shape[-1] ==`` prefix length. It is multiplied into the first
            prefix-length columns of ``model.causal_mask`` (presumably to hide
            padded prompt positions -- confirm against the caller).
        cfg_scale: Values above 1.0 enable the doubled (cond + null) batch.
            Forwarded to ``prefill`` and ``decode_n_tokens``.
        cfg_interval: Forwarded unchanged to ``decode_n_tokens``.
        **sampling_kwargs: Forwarded unchanged to ``prefill`` and
            ``decode_n_tokens``.

    Returns:
        ``torch.int`` tensor of shape ``(cond.shape[0], max_new_tokens)``
        holding the sampled token ids.

    Raises:
        ValueError: If ``model.model_type`` is neither ``'c2i'`` nor
            ``'t2i'``, or if ``max_new_tokens`` is smaller than 1.
    """
    use_cfg = cfg_scale > 1.0

    if model.model_type == 'c2i':
        if use_cfg:
            # The null condition is the index ``num_classes`` for every sample.
            cond_null = torch.ones_like(cond) * model.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        T = 1
    elif model.model_type == 't2i':
        if use_cfg:
            # Broadcast the model's unconditional embedding to cond's shape.
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        T = cond.shape[1]
    else:
        # ValueError is an Exception subclass, so existing handlers still match.
        raise ValueError("please check model type")

    if max_new_tokens < 1:
        # Previously this surfaced later as an opaque error from the tensor ops.
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")

    T_new = T + max_new_tokens
    max_seq_length = T_new
    max_batch_size = cond.shape[0]

    device = cond.device
    with torch.device(device):
        # Caches must hold both halves of the batch when CFG is active.
        max_batch_size_cfg = max_batch_size * 2 if use_cfg else max_batch_size
        model.setup_caches(
            max_batch_size=max_batch_size_cfg,
            max_seq_length=max_seq_length,
            dtype=model.tok_embeddings.weight.dtype,
        )

    if emb_masks is not None:
        assert emb_masks.shape[0] == max_batch_size, "emb_masks batch size must match cond"
        assert emb_masks.shape[-1] == T, "emb_masks last dim must equal the prefix length"
        # With CFG the mask is repeated so it lines up with [cond, cond_null].
        prefix_mask = torch.cat([emb_masks, emb_masks]) if use_cfg else emb_masks
        model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * prefix_mask.unsqueeze(1)

        # Force the diagonal back on so no position ends up with every key
        # masked after the multiplication above.
        # NOTE(review): this step's nesting under the emb_masks branch was
        # reconstructed from a flattened listing -- confirm against upstream.
        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix

    # Buffer spans prefix + generated positions; columns [0, T) are never
    # written and are sliced away before returning.
    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)

    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
    seq[:, T:T+1] = next_token

    if max_new_tokens > 1:
        # Only decode when tokens remain: with nothing left to decode the
        # original went on to call torch.cat on the (presumably empty) result,
        # and torch.cat raises on an empty list.
        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        generated_tokens, _ = decode_n_tokens(
            model, next_token, input_pos, max_new_tokens - 1, cfg_scale, cfg_interval, **sampling_kwargs
        )
        seq[:, T+1:] = torch.cat(generated_tokens, dim=1)

    return seq[:, T:]

Callers 4

mainFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90

Calls 4

prefillFunction · 0.85
decode_n_tokensFunction · 0.85
setup_cachesMethod · 0.80
emptyMethod · 0.45

Tested by

no test coverage detected