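BOS-aligned dataloader with best-fit packing. Every row starts with a BOS token. Documents are packed using best-fit to minimize cropping: each gap in a row is filled with the largest buffered document that fits entirely. When no document fits in the remaining space, the shortest buffered document is cropped to fill the row exactly (the rest of that document is dropped). 100% utilization (no padding).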
(tokenizer, B, T, split, buffer_size=1000)
| 274 | |
| 275 | |
| 276 | def make_dataloader(tokenizer, B, T, split, buffer_size=1000): |
| 277 | """ |
| 278 | BOS-aligned dataloader with best-fit packing. |
| 279 | Every row starts with BOS. Documents packed using best-fit to minimize cropping. |
| 280 | When no document fits remaining space, crops shortest doc to fill exactly. |
| 281 | 100% utilization (no padding). |
| 282 | """ |
| 283 | assert split in ["train", "val"] |
| 284 | row_capacity = T + 1 |
| 285 | batches = _document_batches(split) |
| 286 | bos_token = tokenizer.get_bos_token_id() |
| 287 | doc_buffer = [] |
| 288 | epoch = 1 |
| 289 | |
| 290 | def refill_buffer(): |
| 291 | nonlocal epoch |
| 292 | doc_batch, epoch = next(batches) |
| 293 | token_lists = tokenizer.encode(doc_batch, prepend=bos_token) |
| 294 | doc_buffer.extend(token_lists) |
| 295 | |
| 296 | # Pre-allocate buffers: [inputs (B*T) | targets (B*T)] |
| 297 | row_buffer = torch.empty((B, row_capacity), dtype=torch.long) |
| 298 | cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True) |
| 299 | gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda") |
| 300 | cpu_inputs = cpu_buffer[:B * T].view(B, T) |
| 301 | cpu_targets = cpu_buffer[B * T:].view(B, T) |
| 302 | inputs = gpu_buffer[:B * T].view(B, T) |
| 303 | targets = gpu_buffer[B * T:].view(B, T) |
| 304 | |
| 305 | while True: |
| 306 | for row_idx in range(B): |
| 307 | pos = 0 |
| 308 | while pos < row_capacity: |
| 309 | while len(doc_buffer) < buffer_size: |
| 310 | refill_buffer() |
| 311 | |
| 312 | remaining = row_capacity - pos |
| 313 | |
| 314 | # Find largest doc that fits entirely |
| 315 | best_idx = -1 |
| 316 | best_len = 0 |
| 317 | for i, doc in enumerate(doc_buffer): |
| 318 | doc_len = len(doc) |
| 319 | if doc_len <= remaining and doc_len > best_len: |
| 320 | best_idx = i |
| 321 | best_len = doc_len |
| 322 | |
| 323 | if best_idx >= 0: |
| 324 | doc = doc_buffer.pop(best_idx) |
| 325 | row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long) |
| 326 | pos += len(doc) |
| 327 | else: |
| 328 | # No doc fits — crop shortest to fill remaining |
| 329 | shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) |
| 330 | doc = doc_buffer.pop(shortest_idx) |
| 331 | row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long) |
| 332 | pos += remaining |
| 333 |
no test coverage detected