MCPcopy Index your code
hub / github.com/hpcaitech/ColossalAI / DataModuleFromConfig

Class DataModuleFromConfig

examples/images/diffusion/main.py:207–315  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

205
206# Provide functionality for creating data loaders based on provided dataset configurations
207class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
    self,
    batch_size,
    train=None,
    validation=None,
    test=None,
    predict=None,
    wrap=False,
    num_workers=None,
    shuffle_test_loader=False,
    use_worker_init_fn=False,
    shuffle_val_dataloader=False,
):
    """Store loader settings and register a dataloader hook per configured split.

    A dataloader attribute (``train_dataloader``, ``val_dataloader``,
    ``test_dataloader``, ``predict_dataloader``) is only attached to the
    instance when the matching split config is supplied.

    Args:
        batch_size: Saved as ``self.batch_size``; also used to derive the
            default worker count.
        train: Dataset config stored under ``"train"``, or None to skip.
        validation: Dataset config stored under ``"validation"``, or None.
        test: Dataset config stored under ``"test"``, or None.
        predict: Dataset config stored under ``"predict"``, or None.
        wrap: Saved as ``self.wrap``; when truthy, ``setup`` wraps every
            dataset in ``WrappedDataset``.
        num_workers: Saved as ``self.num_workers``; defaults to
            ``batch_size * 2`` when None.
        shuffle_test_loader: Bound as the ``shuffle`` keyword of
            ``_test_dataloader``.
        use_worker_init_fn: Saved as ``self.use_worker_init_fn``.
        shuffle_val_dataloader: Bound as the ``shuffle`` keyword of
            ``_val_dataloader``.
    """
    super().__init__()

    self.batch_size = batch_size
    self.dataset_configs = {}
    # No explicit worker count: fall back to twice the batch size.
    if num_workers is None:
        num_workers = batch_size * 2
    self.num_workers = num_workers
    self.use_worker_init_fn = use_worker_init_fn

    # (config key, config, hook attribute, hook factory). Factories are
    # invoked only for splits that were actually configured, so the
    # private loader methods are never touched for absent splits.
    split_table = (
        ("train", train, "train_dataloader", lambda: self._train_dataloader),
        (
            "validation",
            validation,
            "val_dataloader",
            lambda: partial(self._val_dataloader, shuffle=shuffle_val_dataloader),
        ),
        (
            "test",
            test,
            "test_dataloader",
            lambda: partial(self._test_dataloader, shuffle=shuffle_test_loader),
        ),
        ("predict", predict, "predict_dataloader", lambda: self._predict_dataloader),
    )
    for split_name, split_cfg, hook_name, build_hook in split_table:
        if split_cfg is None:
            continue
        self.dataset_configs[split_name] = split_cfg
        setattr(self, hook_name, build_hook())

    self.wrap = wrap
241
def prepare_data(self):
    """Instantiate each configured dataset once, discarding the instances."""
    # Nothing is kept here; `setup` builds the datasets that are actually used.
    # NOTE(review): presumably run for one-time construction side effects - confirm.
    for split_cfg in self.dataset_configs.values():
        instantiate_from_config(split_cfg)
246
def setup(self, stage=None):
    """Build ``self.datasets`` from the stored configs, optionally wrapping each.

    Args:
        stage: Accepted for interface compatibility; not read by this method.
    """
    # One dataset per configured split, keyed by the split name.
    self.datasets = {
        split_name: instantiate_from_config(split_cfg)
        for split_name, split_cfg in self.dataset_configs.items()
    }

    if not self.wrap:
        return

    # Replace each entry in place with its WrappedDataset counterpart.
    for split_name, dataset in self.datasets.items():
        self.datasets[split_name] = WrappedDataset(dataset)
255
256 def _train_dataloader(self):
257 # Check if the train dataset is iterable
258 is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
259 # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True
260 if is_iterable_dataset or self.use_worker_init_fn:
261 init_fn = worker_init_fn
262 else:
263 init_fn = None
264 # Return a DataLoaderX object for the train dataset

Callers 1

main.pyFile · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…