hub / github.com/karpathy/nanoGPT / get_batch

Function get_batch

bench.py:37–43
(split)

Source from the content-addressed store, hash-verified

35     data_dir = os.path.join('data', dataset)
36     train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
37     def get_batch(split):
38         data = train_data # note ignore split in benchmarking script
39         ix = torch.randint(len(data) - block_size, (batch_size,))
40         x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
41         y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
42         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
43         return x, y
44 else:
45     # alternatively, if fixed data is desired to not care about data loading
46     x = torch.randint(50304, (batch_size, block_size), device=device)
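
The sampling logic in get_batch can be illustrated without NumPy or PyTorch. The following is a minimal sketch, not the repository's code: it assumes a plain Python list in place of the np.memmap token array and the standard random module in place of torch.randint, and the name get_batch_sketch and the toy data are hypothetical. It shows the two properties visible in the source: start offsets are drawn from [0, len(data) - block_size), and each target row y is the input row x shifted one token to the right.

```python
import random

def get_batch_sketch(data, block_size, batch_size, rng):
    """Sample batch_size random windows; y is x shifted right by one token.

    Hypothetical stand-in for get_batch: uses lists instead of tensors and
    omits the pin_memory / device transfer step.
    """
    # Draw start offsets from [0, len(data) - block_size), mirroring
    # torch.randint(len(data) - block_size, (batch_size,)) in the source.
    starts = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    # Inputs: block_size tokens beginning at each offset.
    x = [data[i:i + block_size] for i in starts]
    # Targets: the same window moved forward by one position.
    y = [data[i + 1:i + 1 + block_size] for i in starts]
    return x, y

# Toy token stream (assumption for illustration; the real data is train.bin).
data = list(range(100))
rng = random.Random(0)
x, y = get_batch_sketch(data, block_size=8, batch_size=4, rng=rng)

for xr, yr in zip(x, y):
    print(xr, "->", yr)
```

Because the largest start offset is len(data) - block_size - 1, the target slice never runs past the end of the data, so every row of x and y has exactly block_size tokens.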

Callers 1

bench.py · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected