(self,
path, # Path to directory or zip.
resolution = None, # Ensure specific resolution, None = highest available.
**super_kwargs, # Additional arguments for the Dataset base class.
)
| 173 | |
| 174 | class ImageFolderDataset(Dataset): |
def __init__(self,
    path,               # Path to directory or zip.
    resolution = None,  # Required image resolution; None = accept whatever resolution the images have (no check).
    **super_kwargs,     # Additional arguments for the Dataset base class.
):
    """Index the image files stored in a directory tree or a zip archive.

    Args:
        path: A directory, which is scanned recursively with ``os.walk``, or
            a file whose extension (per ``self._file_ext``) is ``.zip``.
        resolution: If not None, ``raw_shape[2]`` and ``raw_shape[3]`` of the
            first image must both equal this value. These are presumably height
            and width; that depends on the layout ``_load_raw_image`` returns,
            so confirm it there. None disables the check. No resolution is
            selected and no image is resized either way.
        **super_kwargs: Forwarded unchanged to the base class ``__init__``
            together with the computed ``name`` and ``raw_shape``.

    Raises:
        IOError: If ``path`` is neither a directory nor a ``.zip`` file; if no
            file has an extension listed in ``PIL.Image.EXTENSION``; or if the
            first image does not match ``resolution``.
    """
    self._path = path
    self._zipfile = None

    try:
        if os.path.isdir(self._path):
            self._type = 'dir'
            # Paths are stored relative to the root so they are comparable to zip member names.
            self._all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self._path)
                for root, _dirs, files in os.walk(self._path)
                for fname in files
            }
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        # Load every Pillow format driver so PIL.Image.EXTENSION lists all supported extensions.
        PIL.Image.init()
        self._image_fnames = sorted(
            fname for fname in self._all_fnames
            if self._file_ext(fname) in PIL.Image.EXTENSION
        )
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        # normpath drops a trailing separator: basename('data/ffhq/') is '', basename('data/ffhq') is 'ffhq'.
        name = os.path.splitext(os.path.basename(os.path.normpath(self._path)))[0]
        # Only image 0 is loaded; its shape is taken as the shape of every image.
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError(f'Image files do not match the specified resolution. Resolution is {resolution}, shape is {raw_shape}')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
    except BaseException:
        # If _get_zipfile() cached an open archive in self._zipfile, release it so a
        # failed construction does not leak the handle; then propagate the original error.
        if self._zipfile is not None:
            self._zipfile.close()
            self._zipfile = None
        raise
| 202 | |
| 203 | def _get_zipfile(self): |
| 204 | assert self._type == 'zip' |
nothing calls this directly
no test coverage detected