| 205 | |
| 206 | # Provide functionality for creating data loaders based on provided dataset configurations |
| 207 | class DataModuleFromConfig(pl.LightningDataModule): |
def __init__(
    self,
    batch_size,
    train=None,
    validation=None,
    test=None,
    predict=None,
    wrap=False,
    num_workers=None,
    shuffle_test_loader=False,
    use_worker_init_fn=False,
    shuffle_val_dataloader=False,
):
    """Record per-split dataset configs and bind dataloader hooks for the supplied splits.

    Args:
        batch_size: Stored as ``self.batch_size``; also the basis of the default
            worker count.
        train, validation, test, predict: Optional dataset configs. Every one that
            is not ``None`` is stored in ``self.dataset_configs`` under
            ``"train"``, ``"validation"``, ``"test"`` or ``"predict"`` and gets the
            matching ``*_dataloader`` attribute bound on this instance.
        wrap: Stored as ``self.wrap``.
        num_workers: Stored as ``self.num_workers``; ``None`` means
            ``batch_size * 2``.
        shuffle_test_loader: Pre-bound as ``shuffle=`` on ``_test_dataloader``.
        use_worker_init_fn: Stored as ``self.use_worker_init_fn``.
        shuffle_val_dataloader: Pre-bound as ``shuffle=`` on ``_val_dataloader``.
    """
    super().__init__()
    self.batch_size = batch_size

    # Keep only the splits that were actually given, preserving the fixed
    # train/validation/test/predict insertion order of the dict.
    candidates = (
        ("train", train),
        ("validation", validation),
        ("test", test),
        ("predict", predict),
    )
    self.dataset_configs = {split: cfg for split, cfg in candidates if cfg is not None}

    self.num_workers = batch_size * 2 if num_workers is None else num_workers
    self.use_worker_init_fn = use_worker_init_fn

    # Expose a dataloader hook only for splits that have a config; the
    # validation/test shuffle flags are fixed now via functools.partial.
    configured = self.dataset_configs
    if "train" in configured:
        self.train_dataloader = self._train_dataloader
    if "validation" in configured:
        self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
    if "test" in configured:
        self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
    if "predict" in configured:
        self.predict_dataloader = self._predict_dataloader

    self.wrap = wrap
| 241 | |
def prepare_data(self):
    """Instantiate each configured dataset once, discarding the instances.

    Every value in ``self.dataset_configs`` is passed to
    ``instantiate_from_config`` in insertion order and the returned object is
    thrown away. Presumably this is so any one-time work done by a dataset
    constructor (e.g. downloading/caching files) runs in Lightning's
    ``prepare_data`` hook — TODO confirm against the dataset classes.
    """
    configs = self.dataset_configs
    for split_name in configs:
        instantiate_from_config(configs[split_name])
| 246 | |
def setup(self, stage=None):
    """Build one dataset per configured split and store them in ``self.datasets``.

    Args:
        stage: Accepted to match the Lightning ``setup`` hook signature but not
            consulted, so every configured split is (re)instantiated on each call.

    Side effects:
        Replaces ``self.datasets`` with a new dict mapping each key of
        ``self.dataset_configs`` (same order) to the object returned by
        ``instantiate_from_config``; when ``self.wrap`` is true each object is
        wrapped in ``WrappedDataset`` first.
    """
    # Instantiate everything first, then (optionally) wrap, keeping the original
    # two-phase order. The result is assigned only once it is complete, so a
    # failure while wrapping no longer leaves a half-wrapped dict on the instance.
    datasets = {split: instantiate_from_config(cfg) for split, cfg in self.dataset_configs.items()}
    if self.wrap:
        datasets = {split: WrappedDataset(ds) for split, ds in datasets.items()}
    self.datasets = datasets
| 255 | |
| 256 | def _train_dataloader(self): |
| 257 | # Check if the train dataset is iterable |
| 258 | is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) |
| 259 | # Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True |
| 260 | if is_iterable_dataset or self.use_worker_init_fn: |
| 261 | init_fn = worker_init_fn |
| 262 | else: |
| 263 | init_fn = None |
| 264 | # Return a DataLoaderX object for the train dataset |
no outgoing calls
no test coverage detected
searching dependent graphs…