MCPcopy
hub / github.com/SizheAn/PanoHead / open_cifar10

Function open_cifar10

dataset_tool.py:159–187  ·  view source on GitHub ↗
(tarball: str, *, max_images: Optional[int])

Source from the content-addressed store, hash-verified

157#----------------------------------------------------------------------------
158
def open_cifar10(tarball: str, *, max_images: Optional[int]):
    """Load the CIFAR-10 training batches from a gzipped tarball.

    Reads the five ``cifar-10-batches-py/data_batch_{1..5}`` members,
    concatenates them, and converts the images from NCHW to NHWC layout.

    Args:
        tarball: Path to the gzip-compressed CIFAR-10 archive.
        max_images: Limit passed to ``maybe_min`` together with the number
            of loaded images to obtain the reported image count
            (presumably ``None`` means "no limit" -- confirm against
            ``maybe_min``).

    Returns:
        A tuple ``(max_idx, iterator)`` where ``max_idx`` is the value
        returned by ``maybe_min`` and ``iterator`` yields dicts with keys
        ``img`` (a 32x32x3 uint8 array) and ``label`` (a Python int).

    Raises:
        AssertionError: If the loaded data does not have the expected
            shape, dtype, or value range.
        KeyError: If an expected batch member is missing from the archive.
    """
    images = []
    labels = []

    with tarfile.open(tarball, 'r:gz') as tar:
        for batch in range(1, 6):
            member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
            with tar.extractfile(member) as file:
                # NOTE(security): pickle.load executes arbitrary code embedded
                # in the stream; only use this on a trusted CIFAR-10 archive.
                data = pickle.load(file, encoding='latin1')
            images.append(data['data'].reshape(-1, 3, 32, 32))
            labels.append(data['labels'])

    images = np.concatenate(images)
    labels = np.concatenate(labels)
    images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
    assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    # Never index past the loaded data, whatever maybe_min returned.
    num_to_yield = min(max_idx, len(images))

    def iterate_images():
        # range() is empty for a non-positive count, so the number of items
        # produced matches the reported count even when max_images is 0
        # (a yield-then-check loop would always emit at least one image).
        for idx in range(num_to_yield):
            yield dict(img=images[idx], label=int(labels[idx]))

    return max_idx, iterate_images()
188
189#----------------------------------------------------------------------------
190

Callers 1

open_datasetFunction · 0.85

Calls 4

loadMethod · 0.80
appendMethod · 0.80
maybe_minFunction · 0.70
iterate_imagesFunction · 0.70

Tested by

no test coverage detected