r"""Set up a deterministic dataloader (also configures worker seeding, samplers, and whether to shuffle). Note: When pipeline parallelism is enabled, shuffle cannot be True, as it would result in a mismatch between the input data on the first stage and the labels on the last stage. Args:
(
dataset,
shuffle=False,
seed=1024,
add_sampler=True,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs,
)
| 108 | |
| 109 | |
| 110 | def get_dpsampler_dataloader( |
| 111 | dataset, |
| 112 | shuffle=False, |
| 113 | seed=1024, |
| 114 | add_sampler=True, |
| 115 | drop_last=False, |
| 116 | pin_memory=False, |
| 117 | num_workers=0, |
| 118 | **kwargs, |
| 119 | ): |
| 120 | r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) |
| 121 | |
| 122 | Note: |
| 123 | When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data |
| 124 | on the 1st stage and label on the last stage. |
| 125 | |
| 126 | Args: |
| 127 | dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded. |
| 128 | shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. |
| 129 | seed (int, optional): Random worker seed for sampling, defaults to 1024. |
| 130 | add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. |
| 131 | drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size |
| 132 | is not divisible by the batch size. If False and the size of dataset is not divisible by |
| 133 | the batch size, then the last batch will be smaller, defaults to False. |
| 134 | pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. |
| 135 | num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. |
| 136 | kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in |
| 137 | `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_. |
| 138 | |
| 139 | Returns: |
| 140 | :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. |
| 141 | """ |
| 142 | _kwargs = kwargs.copy() |
| 143 | |
| 144 | if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: |
| 145 | sampler = DataParallelSampler(dataset, shuffle=shuffle, drop_last=drop_last) |
| 146 | else: |
| 147 | sampler = None |
| 148 | |
| 149 | # Deterministic dataloader |
| 150 | def seed_worker(): |
| 151 | worker_seed = seed |
| 152 | np.random.seed(worker_seed) |
| 153 | torch.manual_seed(worker_seed) |
| 154 | random.seed(worker_seed) |
| 155 | |
| 156 | if sampler is None: |
| 157 | return DataLoader( |
| 158 | dataset, |
| 159 | worker_init_fn=seed_worker, |
| 160 | shuffle=shuffle, |
| 161 | drop_last=drop_last, |
| 162 | pin_memory=pin_memory, |
| 163 | num_workers=num_workers, |
| 164 | **_kwargs, |
| 165 | ) |
| 166 | else: |
| 167 | return DataLoader( |
no test coverage detected