()
| 288 | epoch = 1 |
| 289 | |
| 290 | def refill_buffer(): |
| 291 | nonlocal epoch |
| 292 | doc_batch, epoch = next(batches) |
| 293 | token_lists = tokenizer.encode(doc_batch, prepend=bos_token) |
| 294 | doc_buffer.extend(token_lists) |
| 295 | |
| 296 | # Pre-allocate buffers: [inputs (B*T) \| targets (B*T)] |
| 297 | row_buffer = torch.empty((B, row_capacity), dtype=torch.long) |
no test coverage detected