Prepare dataloader.
(self, **kwargs)
| 287 | self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args']) |
| 288 | |
def prepare_dataloader(self, **kwargs):
    """
    Prepare dataloader.

    Builds a shuffling ``ResumableSampler`` over ``self.dataset``, wraps it in a
    ``DataLoader`` and stores an endless iterator over that loader.

    Reads:
        self.dataset: dataset to load; its ``collate_fn`` attribute is used as
            the batch collate function when present.
        self.num_workers: worker count; ``None`` or ``-1`` means auto-size.
        self.batch_size_per_gpu: per-process batch size.

    Sets:
        self.data_sampler: the ``ResumableSampler`` that drives sample order.
        self.dataloader: the underlying ``DataLoader``.
        self.data_iterator: ``cycle(self.dataloader)``.

    Args:
        **kwargs: accepted for interface compatibility; currently unused.
    """
    self.data_sampler = ResumableSampler(
        self.dataset,
        shuffle=True,
    )

    if self.num_workers is None or self.num_workers == -1:
        # Auto-size: reserve 16 CPU cores and split the remainder evenly across
        # visible GPUs, with at least one worker. os.cpu_count() may return None,
        # and a CPU-only host reports 0 devices, so guard both before dividing.
        cpu_count = os.cpu_count() or 1
        gpu_count = max(1, torch.cuda.device_count())
        num_workers = max(1, int(np.ceil((cpu_count - 16) / gpu_count)))
    else:
        num_workers = self.num_workers

    self.dataloader = DataLoader(
        self.dataset,
        batch_size=self.batch_size_per_gpu,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0 (in-process loading), so only request it with workers.
        persistent_workers=num_workers > 0,
        collate_fn=getattr(self.dataset, 'collate_fn', None),
        sampler=self.data_sampler,
    )
    # NOTE(review): if `cycle` is itertools.cycle, it caches every batch from the
    # first pass and replays those same batches forever. That means no reshuffling
    # across epochs and memory held for a full epoch. Confirm that `cycle` is a
    # re-iterating generator.
    self.data_iterator = cycle(self.dataloader)
| 313 | |
| 314 | def _master_params_to_state_dicts(self, master_params): |
| 315 | """ |
no test coverage detected