Repository: github.com/meta-pytorch/opacus

Method from_data_loader

opacus/data_loader.py:307–363

Creates a new ``DPDataLoader`` from the passed ``data_loader``. The new loader inherits all attributes and parameters of the original except its sampling mechanism.

(
        cls,
        data_loader: DataLoader,
        *,
        distributed: bool = False,
        generator=None,
        batch_first: bool = True,
        rand_on_empty: bool = False,
    )
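The signature has no ``sample_rate`` parameter. The source listing derives it as ``1 / len(data_loader)``, so each batch holds ``num_samples * sample_rate`` examples on average. That average equals the original ``batch_size`` only when ``batch_size`` divides the dataset size evenly. The helper below is hypothetical (``implied_sample_rate`` is not an Opacus function). It only reproduces that arithmetic and assumes a map-style dataset batched by PyTorch's default batch sampler, where ``len(data_loader)`` rounds up, or rounds down when ``drop_last=True``.

```python
def implied_sample_rate(num_samples: int, batch_size: int, drop_last: bool = False) -> float:
    """Hypothetical helper: reproduce sample_rate = 1 / len(data_loader) for a
    map-style DataLoader that uses PyTorch's default batch sampler (assumption)."""
    if drop_last:
        num_batches = num_samples // batch_size  # incomplete final batch is dropped
    else:
        num_batches = -(-num_samples // batch_size)  # ceiling division keeps the partial batch
    return 1 / num_batches


# Docstring example: 64 samples, batch_size=4 -> 16 batches
q = implied_sample_rate(64, 4)
print(q)       # 0.0625
print(64 * q)  # 4.0 examples per batch on average

# When batch_size does not divide the dataset, the inherited drop_last changes the rate
print(implied_sample_rate(10, 4))                  # 0.3333333333333333 (3 batches)
print(implied_sample_rate(10, 4, drop_last=True))  # 0.5 (2 batches, so 5 examples on average)
```

Because ``drop_last`` is copied from the original loader, the same dataset and ``batch_size`` can yield different sampling rates depending on that flag.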


@classmethod
def from_data_loader(
    cls,
    data_loader: DataLoader,
    *,
    distributed: bool = False,
    generator=None,
    batch_first: bool = True,
    rand_on_empty: bool = False,
):
    """
    Creates new ``DPDataLoader`` based on passed ``data_loader`` argument.

    Args:
        data_loader: Any DataLoader instance. Must not be over an ``IterableDataset``
        distributed: set ``True`` if you'll be using DPDataLoader in a DDP environment
        generator: Random number generator used to sample elements. Defaults to
            the generator from the original data loader.
        batch_first: Flag to indicate if the input tensor to the corresponding module
            has the first dimension representing the batch. If set to True, dimensions on
            input tensor are expected to be ``[batch_size, ...]``, otherwise
            ``[K, batch_size, ...]``
        rand_on_empty: set ``True`` to return a batch containing random numbers when encountering
            empty batches rather than tensors with zero-length batch dimensions

    Returns:
        New DPDataLoader instance, with all attributes and parameters inherited
        from the original data loader, except for sampling mechanism.

    Examples:
        >>> x, y = torch.randn(64, 5), torch.randint(0, 2, (64,))
        >>> dataset = TensorDataset(x, y)
        >>> data_loader = DataLoader(dataset, batch_size=4)
        >>> dp_data_loader = DPDataLoader.from_data_loader(data_loader)
    """

    if isinstance(data_loader.dataset, IterableDataset):
        raise ValueError("Uniform sampling is not supported for IterableDataset")

    return cls(
        dataset=data_loader.dataset,
        sample_rate=1 / len(data_loader),
        num_workers=data_loader.num_workers,
        collate_fn=data_loader.collate_fn,
        pin_memory=data_loader.pin_memory,
        drop_last=data_loader.drop_last,
        timeout=data_loader.timeout,
        worker_init_fn=data_loader.worker_init_fn,
        multiprocessing_context=data_loader.multiprocessing_context,
        generator=generator if generator else data_loader.generator,
        prefetch_factor=data_loader.prefetch_factor,
        persistent_workers=data_loader.persistent_workers,
        distributed=distributed,
        batch_first=batch_first,
        rand_on_empty=rand_on_empty,
    )
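The source copies the dataset and worker settings but replaces the sampling mechanism. As we understand it, ``DPDataLoader`` builds Opacus's ``UniformWithReplacementSampler`` from ``sample_rate``. That sampler performs Poisson sampling: every example joins a batch independently with probability ``sample_rate``, so batch sizes vary and a batch can be empty, which is the case ``rand_on_empty`` addresses. Each example needs its own independent draw, so the sampler must know the dataset size. This explains the ``IterableDataset`` check. The stdlib sketch below shows the idea only. ``poisson_batches`` is a hypothetical name, it uses ``random.Random`` in place of a ``torch.Generator``, and the batches-per-epoch count of roughly ``1 / sample_rate`` is an assumption.

```python
import random
from typing import Iterator, List, Optional


def poisson_batches(num_samples: int, sample_rate: float, seed: Optional[int] = None) -> Iterator[List[int]]:
    """Hypothetical sketch of Poisson sampling: each index joins each batch
    independently with probability sample_rate. Not the Opacus implementation."""
    rng = random.Random(seed)  # stands in for the torch.Generator Opacus would use
    num_batches = round(1 / sample_rate)  # assumption: about 1/sample_rate batches per epoch
    for _ in range(num_batches):
        # One Bernoulli(sample_rate) draw per example; this needs the dataset length,
        # which an IterableDataset does not provide
        yield [i for i in range(num_samples) if rng.random() < sample_rate]


# Same numbers as the docstring example: 64 examples, sample_rate = 1/16
sizes = [len(batch) for batch in poisson_batches(64, 1 / 16, seed=0)]
print(len(sizes))  # 16
print(sizes)       # sizes vary around 4; an individual batch can be empty
```

With 64 examples and ``sample_rate = 1/16``, batches average 4 examples, matching ``batch_size=4`` in the docstring example, but no single batch is guaranteed to contain exactly 4.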

Callers (4)

train_dataloader (method) · 0.80
_prepare_data_loader (method) · 0.80
_read_all_dp (method) · 0.80
test_drop_last_true (method) · 0.80

Calls

no outgoing calls

Tested by (2)

_read_all_dp (method) · 0.64
test_drop_last_true (method) · 0.64