MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX / sample_sequence_batch

Function sample_sequence_batch

codegeex/torch/inference.py:217–326  ·  view source on GitHub ↗
(
        model,
        tokenizer,
        context_tokens,
        context_lengths,
        attention_mask,
        position_ids,
        seq_length,
        out_seq_length,
        maxlen=None,
        return_scores: bool = False,
        prompt_length: int = None,
        bad_ids: List = None,
        temperature: float = 1.0,
        topp: float = 1.0,
        topk: int = 0.0,
        recompute: bool = False,
        greedy: bool = False,
)

Source from the content-addressed store, hash-verified

215
216
217def sample_sequence_batch(
218 model,
219 tokenizer,
220 context_tokens,
221 context_lengths,
222 attention_mask,
223 position_ids,
224 seq_length,
225 out_seq_length,
226 maxlen=None,
227 return_scores: bool = False,
228 prompt_length: int = None,
229 bad_ids: List = None,
230 temperature: float = 1.0,
231 topp: float = 1.0,
232 topk: int = 0.0,
233 recompute: bool = False,
234 greedy: bool = False,
235):
236 model.eval()
237 with torch.no_grad():
238 context_length = context_lengths.min().item()
239 eos_id = tokenizer.eos_token_id
240
241 counter = 0
242 org_context_length = context_length
243
244 layer_past = None
245 batch_size = context_tokens.size(0)
246 is_done = torch.zeros([batch_size]).byte().cuda()
247 tokens = context_tokens
248 if maxlen is None:
249 maxlen = seq_length - 1
250 if maxlen > (org_context_length + out_seq_length):
251 maxlen = org_context_length + out_seq_length
252
253 lengths = torch.ones([batch_size]).long().cuda() * maxlen
254 if return_scores:
255 scores = torch.zeros([batch_size]).float().cuda()
256
257 while context_length <= (maxlen):
258
259 if recompute:
260 logits = model(tokens,
261 position_ids,
262 attention_mask,
263 prompt_length=prompt_length,
264 context_length=context_length,
265 )
266 logits = logits[:, context_length - 1, :]
267 else:
268 if counter == 0:
269 tokens2use = tokens[:, :context_length]
270 positions2use = position_ids[:, :context_length]
271 else:
272 tokens2use = tokens[:, context_length - 1].view(
273 batch_size, -1)
274 positions2use = position_ids[:, context_length - 1].view(

Callers 1

get_token_stream (Function) · 0.70

Calls 4

size (Method) · 0.80
top_k_logits (Function) · 0.70
switch (Function) · 0.70
eval (Method) · 0.45

Tested by

no test coverage detected