A data sampler for distributed data parallelism. Args: dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. shuffle (bool, optional): Whether to shuffle data, defaults to False. seed (int, optional): The random seed used for sampling, defaults to 0.
| 19 | |
| 20 | |
| 21 | class DataParallelSampler(Sampler): |
| 22 | """A data sampler for distributed data parallelism. |
| 23 | |
| 24 | Args: |
| 25 | dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. |
| 26 | shuffle (bool, optional): Whether to shuffle data, defaults to False. |
| 27 | seed (int, optional): The random seed used for sampling, defaults to 0. |
| 28 | drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size |
| 29 | is not divisible by the batch size. If False and the size of dataset is not divisible by |
| 30 | the batch size, then the last batch will be smaller, defaults to False. |
| 31 | """ |
| 32 | |
def __init__(
    self,
    dataset: Dataset,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
) -> None:
    """Set up per-replica sampling state for the data-parallel group.

    The number of replicas and this process's rank are read from the global
    parallel context (``gpc``) for ``ParallelMode.DATA``.

    Args:
        dataset (:class:`torch.utils.data.Dataset`): The dataset to sample
            from. It must support ``len()``.
        shuffle (bool, optional): Whether to shuffle indices, defaults to False.
        seed (int, optional): Base random seed used when shuffling,
            defaults to 0.
        drop_last (bool, optional): Controls how a dataset whose length is not
            divisible by the number of data-parallel replicas is split. If
            True, each replica is assigned ``len(dataset) // num_replicas``
            samples, so the remainder is dropped. If False, each replica is
            assigned ``ceil(len(dataset) / num_replicas)`` samples, so
            ``total_size`` may exceed ``len(dataset)``. The split is across
            replicas; batch size plays no role here. Defaults to False.
    """
    self.dataset = dataset
    self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
    self.rank = gpc.get_local_rank(ParallelMode.DATA)
    self.epoch = 0
    self.drop_last = drop_last

    # `type: ignore` only takes effect on the line that triggers the error:
    # Dataset cannot provide a default __len__
    # (see NOTE in pytorch/torch/utils/data/sampler.py).
    dataset_len = len(self.dataset)  # type: ignore[arg-type]

    # Integer arithmetic keeps the split exact for any dataset size
    # (float division loses precision for very large lengths).
    if self.drop_last:
        # Round down so every replica receives the same number of samples;
        # when the length divides evenly, nothing is dropped.
        self.num_samples = dataset_len // self.num_replicas
    else:
        # Round up (ceiling division). total_size can then exceed
        # dataset_len; presumably __iter__ pads the indices up to
        # total_size -- TODO confirm against the rest of __iter__.
        self.num_samples = -(-dataset_len // self.num_replicas)
    self.total_size = self.num_samples * self.num_replicas
    self.shuffle = shuffle
    self.seed = seed
| 63 | |
| 64 | def __iter__(self) -> Iterator[T_co]: |
| 65 | if self.shuffle: |
| 66 | # deterministically shuffle based on epoch and seed |
| 67 | g = torch.Generator() |
| 68 | g.manual_seed(self.seed + self.epoch) |
| 69 | # type: ignore[arg-type] |
| 70 | indices = torch.randperm(len(self.dataset), generator=g).tolist() |
| 71 | |
| 72 | # update for next epoch so that there is no need to call |
| 73 | # set_epoch manually |
| 74 | self.epoch += 1 |
| 75 | else: |
| 76 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] |
| 77 | |
| 78 | if not self.drop_last: |
no outgoing calls
no test coverage detected