| 35 | data_dir = os.path.join('data', dataset) |
| 36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') |
def get_batch(split):
    """Sample a random batch of (input, target) token windows.

    ``split`` is accepted for interface compatibility but ignored: this
    benchmarking script always draws from ``train_data``.

    Returns:
        A tuple ``(x, y)`` of int64 tensors of shape
        ``(batch_size, block_size)`` moved to ``device``; ``y`` holds the
        same windows as ``x`` shifted forward by one token.
    """
    data = train_data  # note ignore split in benchmarking script
    # Random window starts; the exclusive upper bound leaves room for the
    # one extra token needed by the shifted targets.
    starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
    # Read each window of block_size + 1 tokens once and convert to int64,
    # then derive inputs and targets from the same buffer.
    windows = [
        torch.from_numpy(data[s:s + block_size + 1].astype(np.int64))
        for s in starts
    ]
    x = torch.stack([w[:-1] for w in windows])
    y = torch.stack([w[1:] for w in windows])
    if torch.device(device).type == 'cuda':
        # Pinned host memory allows the host-to-device copy to be asynchronous.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        # pin_memory() requires a CUDA runtime; skip it for other devices.
        x, y = x.to(device), y.to(device)
    return x, y
| 44 | else: |
| 45 | # alternatively, use fixed random data so that data loading is excluded from the benchmark |
| 46 | x = torch.randint(50304, (batch_size, block_size), device=device) |