hub / github.com/tinygrad/tinygrad / fetch_cifar

Function fetch_cifar

extra/datasets/__init__.py:19–43
()

Source from the content-addressed store, hash-verified

cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]

def fetch_cifar():
  X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
  Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64)
  X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
  Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64)

  if not os.path.isfile("/tmp/cifar_extracted"):
    def _load_disk_tensor(X, Y, db_list):
      idx = 0
      for db in db_list:
        x, y = db[b'data'], np.array(db[b'labels'])
        assert x.shape[0] == y.shape[0]
        X[idx:idx+x.shape[0]].assign(x)
        Y[idx:idx+x.shape[0]].assign(y)
        idx += x.shape[0]
      assert idx == X.shape[0] and X.shape[0] == Y.shape[0]

    print("downloading and extracting CIFAR...")
    fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
    tt = tarfile.open(fn, mode='r:gz')
    _load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
    _load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
    open("/tmp/cifar_extracted", "wb").close()

  return X_train, Y_train, X_test, Y_test
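
The loading step in the listing above reads pickled batches out of a gzip tar archive and copies each one into a preallocated buffer at a running offset. The following is a minimal standard-library sketch of that pattern, not the tinygrad implementation: it assumes a synthetic in-memory archive, and `make_fake_archive`, `load_batches`, the `batches/...` member names, and the `bytearray`/`list` buffers standing in for the disk-backed tensors are all hypothetical.

```python
import io
import pickle
import tarfile

ROW_BYTES = 3 * 32 * 32  # one flattened image, matching the 3*32*32 row width above


def make_fake_archive(batch_sizes):
    """Build an in-memory tar.gz of pickled batches (synthetic data, hypothetical names)."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        for i, n in enumerate(batch_sizes, start=1):
            batch = {
                b"data": [bytes([i]) * ROW_BYTES for _ in range(n)],
                b"labels": [i % 10] * n,
            }
            payload = pickle.dumps(batch)
            info = tarfile.TarInfo(name=f"batches/data_batch_{i}")
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    buf.seek(0)
    return buf


def load_batches(archive, names, total_rows):
    """Copy each batch into preallocated buffers at a running offset."""
    X = bytearray(total_rows * ROW_BYTES)  # stands in for the uint8 image tensor
    Y = [0] * total_rows                   # stands in for the int64 label tensor
    idx = 0
    with tarfile.open(fileobj=archive, mode="r:gz") as tar:
        for name in names:
            db = pickle.load(tar.extractfile(name), encoding="bytes")
            rows, labels = db[b"data"], db[b"labels"]
            assert len(rows) == len(labels)
            for k, row in enumerate(rows):
                start = (idx + k) * ROW_BYTES
                X[start:start + ROW_BYTES] = row
            Y[idx:idx + len(labels)] = labels
            idx += len(rows)
    # Every preallocated row must have been written exactly once.
    assert idx == total_rows
    return X, Y
```

For example, `load_batches(make_fake_archive([3, 2]), ["batches/data_batch_1", "batches/data_batch_2"], 5)` fills five rows in archive order. The final assertion plays the same role as the `idx == X.shape[0]` check in `_load_disk_tensor`: it catches a batch list that does not exactly fill the preallocated storage.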

Callers

nothing calls this directly

Calls 6

fetch · Function · 0.90
_load_disk_tensor · Function · 0.85
empty · Method · 0.45
open · Method · 0.45
load · Method · 0.45
close · Method · 0.45
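
The `open` and `close` calls listed above come from the empty marker file `/tmp/cifar_extracted`, which `fetch_cifar` checks with `os.path.isfile` and writes once extraction has finished. A minimal sketch of that run-once pattern follows, assuming a temporary directory instead of `/tmp`; `run_once` and `demo` are hypothetical names and are not part of tinygrad.

```python
import os
import tempfile


def run_once(marker_path, work):
    """Run work() only if marker_path does not exist, then create the marker."""
    if os.path.isfile(marker_path):
        return False  # an earlier run completed, so skip the expensive step
    work()
    # Create the empty marker last: if work() raises, no marker is left behind.
    open(marker_path, "wb").close()
    return True


def demo():
    """Call run_once twice against the same marker and report what happened."""
    calls = []
    with tempfile.TemporaryDirectory() as tmp:
        marker = os.path.join(tmp, "extracted")
        first = run_once(marker, lambda: calls.append("extract"))
        second = run_once(marker, lambda: calls.append("extract"))
    return first, second, calls
```

In this sketch the second call skips the work because the marker already exists. As in the original function, the marker records only that extraction finished once; neither version checks that the extracted data files are still present.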

Tested by

no test coverage detected
