

class MyDistributedSampler(Sampler[T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`MyDistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        The dataset is assumed to be of constant size, and any instance of it
        is assumed to always return the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), the sampler will shuffle the
            indices.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): If ``True``, the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
        padding (bool, optional): Whether to pad the dataset. Default: ``True``.
        pair (bool, optional): Whether to pair the output for Mixup. Default: ``False``.
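
    The ``drop_last`` trade-off can be illustrated with a minimal sketch,
    assuming this class splits indices the same way as the upstream
    ``torch.utils.data.DistributedSampler`` (shuffling omitted). The helper
    ``partition_indices`` is hypothetical, and ``padding`` and ``pair`` are
    ignored because their exact behavior is not shown here.

```python
import math
from typing import List


def partition_indices(dataset_len: int, num_replicas: int, rank: int,
                      drop_last: bool = False) -> List[int]:
    # Hypothetical helper: assumes the upstream DistributedSampler split
    # logic; shuffling, ``padding`` and ``pair`` are not modeled.
    indices = list(range(dataset_len))
    if drop_last and dataset_len % num_replicas != 0:
        # Round down so every replica gets the same number of samples.
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    if drop_last:
        # Discard the tail that does not divide evenly.
        indices = indices[:total_size]
    else:
        # Pad by repeating indices from the start until evenly divisible.
        padding_size = total_size - len(indices)
        repeats = math.ceil(padding_size / len(indices)) if padding_size else 0
        indices += (indices * repeats)[:padding_size]
    # Interleaved split: rank r takes every num_replicas-th index from r.
    return indices[rank:total_size:num_replicas]
```

    With 10 samples and 3 replicas, ``drop_last=False`` gives every rank 4
    indices by repeating samples ``0`` and ``1``, while ``drop_last=True``
    gives every rank 3 indices and never loads sample ``9``.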

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will always be used.
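
    The reason for this warning can be sketched as follows, assuming the class
    seeds its shuffle the way the upstream ``torch.utils.data.DistributedSampler``
    does, with ``seed + epoch``. Python's ``random.Random`` stands in for the
    PyTorch generator, and ``epoch_permutation`` is a hypothetical helper.

```python
import random
from typing import List


def epoch_permutation(dataset_len: int, seed: int, epoch: int) -> List[int]:
    # Hypothetical helper: assumes the upstream scheme of seeding the
    # shuffle with ``seed + epoch``; random.Random stands in for
    # torch.Generator and torch.randperm.
    rng = random.Random(seed + epoch)
    indices = list(range(dataset_len))
    rng.shuffle(indices)
    return indices


# All ranks share ``seed`` and ``epoch``, so they compute the same
# permutation and their interleaved slices stay disjoint.
perm = epoch_permutation(8, seed=0, epoch=1)
shards = [perm[rank::2] for rank in range(2)]
assert set(shards[0]).isdisjoint(shards[1])

# If set_epoch is never called, ``epoch`` keeps its initial value and
# every epoch replays the same ordering.
assert epoch_permutation(8, seed=0, epoch=0) == epoch_permutation(8, seed=0, epoch=0)
```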

    Example::

        >>> sampler = MyDistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False,
                 padding: bool = True,
                 pair: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                num_replicas = 1