    @classmethod
    def from_data_loader(
        cls,
        data_loader: DataLoader,
        *,
        distributed: bool = False,
        generator=None,
        batch_first: bool = True,
        rand_on_empty: bool = False,
    ):
        """
        Creates a new ``DPDataLoader`` based on the passed ``data_loader`` argument.

        Args:
            data_loader: Any DataLoader instance. Must not wrap an ``IterableDataset``.
            distributed: Set ``True`` if you'll be using DPDataLoader in a DDP environment.
            generator: Random number generator used to sample elements. Defaults to the
                generator from the original data loader.
            batch_first: Flag to indicate if the input tensor to the corresponding module
                has the first dimension representing the batch. If set to ``True``, dimensions
                of the input tensor are expected to be ``[batch_size, ...]``, otherwise
                ``[K, batch_size, ...]``.
            rand_on_empty: Set ``True`` to return a batch containing random numbers when
                encountering empty batches, rather than tensors with a zero-length batch
                dimension.

        Returns:
            New DPDataLoader instance, with all attributes and parameters inherited
            from the original data loader, except for the sampling mechanism. The
            sampling rate is set to ``1 / len(data_loader)``, so the expected batch
            size approximately matches the original data loader's batch size.

        Examples:
            >>> import torch
            >>> from torch.utils.data import DataLoader, TensorDataset
            >>> x, y = torch.randn(64, 5), torch.randint(0, 2, (64,))
            >>> dataset = TensorDataset(x, y)
            >>> data_loader = DataLoader(dataset, batch_size=4)
            >>> dp_data_loader = DPDataLoader.from_data_loader(data_loader)
        """

        if isinstance(data_loader.dataset, IterableDataset):
            raise ValueError("Uniform sampling is not supported for IterableDataset")

        return cls(
            dataset=data_loader.dataset,
            sample_rate=1 / len(data_loader),
            num_workers=data_loader.num_workers,
            collate_fn=data_loader.collate_fn,
            pin_memory=data_loader.pin_memory,
            drop_last=data_loader.drop_last,
            timeout=data_loader.timeout,
            worker_init_fn=data_loader.worker_init_fn,
            multiprocessing_context=data_loader.multiprocessing_context,
            generator=generator if generator else data_loader.generator,
            prefetch_factor=data_loader.prefetch_factor,
            persistent_workers=data_loader.persistent_workers,
            distributed=distributed,
            batch_first=batch_first,
            rand_on_empty=rand_on_empty,
        )
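The only sampling-related value `from_data_loader` computes itself is `sample_rate=1 / len(data_loader)`. The stdlib-only sketch below illustrates why that choice preserves the original batch size in expectation. It assumes that the sampler built from `sample_rate` includes each record independently with that probability (Poisson sampling, as commonly used for DP-SGD); the helpers `loader_length` and `poisson_batch` are hypothetical illustrations, not part of the class. `loader_length` mirrors how a map-style `DataLoader` with a fixed `batch_size` counts batches: ceiling division by default, floor division when `drop_last=True`.

```python
import random


def loader_length(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches a map-style DataLoader yields."""
    if drop_last:
        # The final partial batch is discarded.
        return num_samples // batch_size
    # Ceiling division: the final partial batch is kept.
    return (num_samples + batch_size - 1) // batch_size


def poisson_batch(num_samples: int, sample_rate: float, rng: random.Random) -> list[int]:
    """Hypothetical helper: include each index independently with probability sample_rate."""
    return [i for i in range(num_samples) if rng.random() < sample_rate]


if __name__ == "__main__":
    num_samples, batch_size = 64, 4  # same sizes as the docstring example
    # Mirrors sample_rate=1 / len(data_loader) in from_data_loader.
    sample_rate = 1 / loader_length(num_samples, batch_size)
    rng = random.Random(0)
    sizes = [len(poisson_batch(num_samples, sample_rate, rng)) for _ in range(10_000)]
    print(f"sample_rate={sample_rate}, expected batch size={num_samples * sample_rate}")
    print(f"mean simulated batch size={sum(sizes) / len(sizes):.2f}, empty batches={sizes.count(0)}")
```

With 64 samples and `batch_size=4`, the loader has 16 batches, so `sample_rate` is `1/16` and the expected batch size is `64 * 1/16 = 4.0`. Individual batch sizes still vary, however. The probability that a batch comes out empty is `(1 - 1/16) ** 64`, roughly 1.6% at these sizes, and this empty-batch case is what the `rand_on_empty` flag addresses.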