# Three-element constant; presumably the per-channel (R, G, B) standard deviation
# of CIFAR-10 pixels scaled to [0, 1] -- TODO confirm against where it is used.
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
| 18 | |
def fetch_cifar():
    """Return the CIFAR-10 train/test tensors, extracting the archive on first use.

    Four tensors are created with ``disk:/tmp/cifar_*`` device strings. When the
    marker file ``/tmp/cifar_extracted`` is absent, the CIFAR-10 python archive is
    obtained through ``fetch``, its five training batches and single test batch
    are unpickled and assigned into those tensors, and the marker file is created
    so later calls skip the extraction.

    Returns:
        A tuple ``(X_train, Y_train, X_test, Y_test)`` with shapes
        ``(50000, 3072)`` uint8, ``(50000,)`` int64, ``(10000, 3072)`` uint8 and
        ``(10000,)`` int64 respectively.

    Raises:
        AssertionError: if a batch has mismatched data/label counts, or the
            batches do not exactly fill the destination tensors.
    """
    X_train = Tensor.empty(50000, 3*32*32, device='disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
    Y_train = Tensor.empty(50000, device='disk:/tmp/cifar_train_y', dtype=dtypes.int64)
    X_test = Tensor.empty(10000, 3*32*32, device='disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
    Y_test = Tensor.empty(10000, device='disk:/tmp/cifar_test_y', dtype=dtypes.int64)

    # The marker file records that a previous call already filled the tensors' backing files.
    if not os.path.isfile("/tmp/cifar_extracted"):
        def _load_disk_tensor(X, Y, db_list):
            """Assign each batch's data and labels into consecutive row ranges of X and Y."""
            idx = 0
            for db in db_list:
                x, y = db[b'data'], np.array(db[b'labels'])
                assert x.shape[0] == y.shape[0]
                X[idx:idx+x.shape[0]].assign(x)
                Y[idx:idx+x.shape[0]].assign(y)
                idx += x.shape[0]
            # Every row of the destination tensors must have been written exactly once.
            assert idx == X.shape[0] and X.shape[0] == Y.shape[0]

        print("downloading and extracting CIFAR...")
        fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
        # NOTE(security): pickle.load can execute arbitrary code; this is only
        # acceptable because the archive comes from the fixed URL above.
        # The context manager closes the archive even if unpickling or an assertion fails.
        with tarfile.open(fn, mode='r:gz') as tt:
            _load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
            _load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
        # Written only after both splits loaded, so a failed run is retried next time.
        open("/tmp/cifar_extracted", "wb").close()

    return X_train, Y_train, X_test, Y_test