(batch, pad_id, args)
| 538 | |
| 539 | |
def pad_batch(batch, pad_id, args):
    """Right-pad each token list in ``batch`` to ``args.seq_length`` in place.

    Lists shorter than ``args.seq_length`` are extended with ``pad_id`` until
    they reach that length. Lists that are already that long or longer are
    left unchanged; nothing is truncated.

    Args:
        batch: Iterable of mutable lists (presumably token ids -- confirm at
            call sites). Each list is mutated in place via ``extend``.
        pad_id: Value appended to short lists.
        args: Any object exposing a ``seq_length`` attribute (the target
            length).

    Returns:
        A ``(batch, lengths)`` tuple. ``batch`` is the same object that was
        passed in (now padded). ``lengths`` lists each entry's length
        measured *before* padding, in iteration order.
    """
    original_lengths = []
    for seq in batch:
        n_tokens = len(seq)  # measured before padding; this is what we report
        original_lengths.append(n_tokens)
        # Read per iteration, as the original does; empty batches never touch it.
        shortfall = args.seq_length - n_tokens
        if shortfall > 0:
            seq.extend([pad_id] * shortfall)
    return batch, original_lengths
| 548 | |
| 549 | |
| 550 | def topk_sampling(logits: torch.FloatTensor, num_samples: int): |
no outgoing calls
no test coverage detected