(tarball: str, *, max_images: Optional[int])
| 157 | #---------------------------------------------------------------------------- |
| 158 | |
def open_cifar10(tarball: str, *, max_images: Optional[int]):
    """Load CIFAR-10 images and labels from a gzip-compressed tarball.

    Reads the five ``cifar-10-batches-py/data_batch_N`` pickles (N = 1..5),
    concatenates them, and converts the images from NCHW to NHWC layout.
    The asserts below require exactly 50000 uint8 images of 32x32x3 with
    labels in the range 0..9.

    Args:
        tarball: Path to the ``.tar.gz`` archive.
        max_images: Optional limit on the number of samples; it is combined
            with the dataset size via ``maybe_min`` (presumably the smaller
            of the two when a limit is given -- confirm against that helper).

    Returns:
        A ``(max_idx, iterator)`` tuple. The iterator yields dicts with the
        keys ``img`` (one 32x32x3 uint8 array) and ``label`` (a Python int),
        and stops after ``max_idx`` samples or at the end of the dataset,
        whichever comes first. It yields nothing when ``max_idx <= 0``.

    Raises:
        KeyError: If an expected batch member is missing from the archive.
        AssertionError: If the decoded data does not have the expected
            shape, dtype or value range.
    """
    images = []
    labels = []

    with tarfile.open(tarball, 'r:gz') as tar:
        for batch in range(1, 6):
            member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
            with tar.extractfile(member) as file:
                # NOTE(security): unpickling executes arbitrary code embedded
                # in the archive; only use tarballs from a trusted source.
                data = pickle.load(file, encoding='latin1')
            # Each batch stores flat rows; restore the (N, C, H, W) shape.
            images.append(data['data'].reshape(-1, 3, 32, 32))
            labels.append(data['labels'])

    images = np.concatenate(images)
    labels = np.concatenate(labels)
    images = images.transpose([0, 2, 3, 1])  # NCHW -> NHWC
    assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        # Bound the loop up front so that a limit of zero (or less) yields
        # nothing; checking the limit only after the first yield would always
        # emit one sample. The min() keeps the end-of-dataset stop intact.
        for idx in range(min(max_idx, len(images))):
            yield dict(img=images[idx], label=int(labels[idx]))

    return max_idx, iterate_images()
| 188 | |
| 189 | #---------------------------------------------------------------------------- |
| 190 |
no test coverage detected