hub / github.com/karpathy/nanochat / generate_batch

Method generate_batch

nanochat/engine.py:282–304

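Non-streaming batch generation: consumes the generate stream and returns the final token sequences instead of yielding tokens step by step. The docstring mentions only a list of token sequences, but the method returns a pair (results, masks). results is a list of num_samples token sequences (lists of ints), each starting with a copy of the prompt tokens. masks is a parallel list of per-token mask values, with 0 at every prompt position. Terminal tokens (assistant_end, bos) are not included in either list, and a row stops growing once it has produced one.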

(self, tokens, num_samples=1, **kwargs)
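A caller will usually want only the generated continuation, not the echoed prompt. The sketch below shows one way to unpack the return value and drop the prompt prefix. The helper name split_completions and the literal token ids and mask values are hypothetical stand-ins shaped like the method's return value; they are not taken from the repository.

```python
def split_completions(prompt_tokens, results):
    """Drop the prompt prefix from each row (hypothetical helper, not part of nanochat)."""
    n = len(prompt_tokens)
    return [row[n:] for row in results]


# Hypothetical values shaped like `results, masks = engine.generate_batch(prompt, num_samples=2)`.
prompt = [1, 10, 11]
results = [[1, 10, 11, 42, 43], [1, 10, 11, 44]]
masks = [[0, 0, 0, 1, 1], [0, 0, 0, 1]]  # assumed mask values for generated tokens

completions = split_completions(prompt, results)
# Each mask row has the same length as its token row.
lengths_match = all(len(r) == len(m) for r, m in zip(results, masks))
```

Because every row begins with a copy of the prompt, slicing by len(prompt) is enough; rows can still differ in length because each one stops at its own terminal token.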

Source from the content-addressed store, hash-verified

def generate_batch(self, tokens, num_samples=1, **kwargs):
    """
    Non-streaming batch generation that just returns the final token sequences.
    Returns a list of token sequences (list of lists of ints).
    Terminal tokens (assistant_end, bos) are not included in the results.
    """
    assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
    bos = self.tokenizer.get_bos_token_id()
    results = [tokens.copy() for _ in range(num_samples)]
    masks = [[0] * len(tokens) for _ in range(num_samples)]
    completed = [False] * num_samples
    for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
        for i, (token, mask) in enumerate(zip(token_column, token_masks)):
            if not completed[i]:
                if token == assistant_end or token == bos:
                    completed[i] = True
                else:
                    results[i].append(token)
                    masks[i].append(mask)
        # Stop if all rows are completed
        if all(completed):
            break
    return results, masks
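To make the accumulation loop easier to follow without loading a model, here is a minimal standalone sketch of the same column-to-row logic. The names collect_rows and fake_generate, the stop token id 99, and all token values are assumptions for illustration; the real method reads its columns from self.generate and its terminal tokens from the tokenizer.

```python
def collect_rows(prompt, num_samples, columns, stop_tokens):
    """Accumulate per-step token columns into per-sample rows (illustrative sketch)."""
    results = [list(prompt) for _ in range(num_samples)]
    masks = [[0] * len(prompt) for _ in range(num_samples)]
    completed = [False] * num_samples
    for token_column, token_masks in columns:
        for i, (token, mask) in enumerate(zip(token_column, token_masks)):
            if completed[i]:
                continue  # this row already hit a terminal token
            if token in stop_tokens:
                completed[i] = True  # the terminal token itself is not stored
            else:
                results[i].append(token)
                masks[i].append(mask)
        if all(completed):
            break  # stop pulling from the stream once every row is done
    return results, masks


def fake_generate():
    """Hypothetical stand-in for the streaming generator: yields (tokens, masks) per step."""
    yield [5, 6], [1, 1]
    yield [99, 7], [1, 1]   # row 0 finishes here
    yield [8, 99], [1, 1]   # row 0's token is ignored, row 1 finishes
    yield [9, 9], [1, 1]    # not consumed: the loop has already stopped


rows, row_masks = collect_rows([1, 2], 2, fake_generate(), {99})
```

Each row finishes independently: once a row sees a terminal token, later tokens in that row's column are discarded, and the stream is abandoned as soon as all rows are complete.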

Callers 10

main · Function · 0.95
test_num_samples_count · Function · 0.95
run_generative_eval · Function · 0.80
get_batch · Function · 0.80
run_gsm8k_eval · Function · 0.80
base_train.py · File · 0.80

Calls 3

generate · Method · 0.95
encode_special · Method · 0.45
get_bos_token_id · Method · 0.45