github.com/InternLM/InternLM

Function get_dpsampler_dataloader

internlm/data/batch_sampler.py:110–175

Set up a deterministic dataloader (also configures worker seeding, the sampler, and whether to shuffle). Note: When pipeline parallelism is enabled, shuffle cannot be True, as it would result in a mismatch between the input data on the first stage and the labels on the last stage.
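The pipeline-parallel note can be illustrated with a simplified, pure-Python model. This is not InternLM code: it assumes that the first and last pipeline stages each iterate their own dataloader, and `stage_order` and `mismatched_pairs` are hypothetical helpers. Without shuffling, both stages visit the same indices in the same order, so every input meets its own label. Independently shuffled orders generally disagree.

```python
import random


def stage_order(num_samples, shuffle, seed):
    """Hypothetical helper: order in which one pipeline stage visits dataset indices."""
    order = list(range(num_samples))
    if shuffle:
        # Each stage shuffles with its own RNG, as independent dataloaders would.
        random.Random(seed).shuffle(order)
    return order


def mismatched_pairs(first_stage_order, last_stage_order):
    """Hypothetical helper: count steps where the input index differs from the label index."""
    return sum(1 for x, y in zip(first_stage_order, last_stage_order) if x != y)


# shuffle=False: the seed is irrelevant and every input lines up with its label.
aligned = mismatched_pairs(stage_order(8, False, seed=0), stage_order(8, False, seed=1))

# shuffle=True with independent RNG state per stage: the orders generally differ.
misaligned = mismatched_pairs(stage_order(8, True, seed=0), stage_order(8, True, seed=1))

print(f"unshuffled mismatches: {aligned}, independently shuffled mismatches: {misaligned}")
```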

(
    dataset,
    shuffle=False,
    seed=1024,
    add_sampler=True,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    **kwargs,
)
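The `drop_last` argument, together with a `batch_size` forwarded through `**kwargs` to `torch.utils.data.DataLoader`, decides what happens to the final batch. The arithmetic can be sketched in plain Python; `batch_indices` is a hypothetical helper that mimics how consecutive indices are grouped into batches, not a function from InternLM or PyTorch.

```python
def batch_indices(indices, batch_size, drop_last=False):
    """Hypothetical helper: split indices into consecutive batches of batch_size."""
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    # With drop_last=True, a trailing batch smaller than batch_size is discarded.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


# 10 samples with batch_size=4: 10 // 4 = 2 full batches and a remainder of 2.
print(batch_indices(list(range(10)), 4))                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch_indices(list(range(10)), 4, drop_last=True))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```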

import random

import numpy as np
import torch
from torch.utils.data import DataLoader

# gpc (the global context), ParallelMode and DataParallelSampler come from InternLM
# and are not shown in this excerpt; they must be in scope for the function to run.


def get_dpsampler_dataloader(
    dataset,
    shuffle=False,
    seed=1024,
    add_sampler=True,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    **kwargs,
):
    r"""Set up a deterministic dataloader (also configures worker seeding, the sampler, and whether to shuffle).

    Note:
        When pipeline parallelism is enabled, shuffle cannot be True, as it would result in a mismatch
        between the input data on the first stage and the labels on the last stage.

    Args:
        dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
        seed (int, optional): Seed applied to the ``random``, NumPy and PyTorch RNGs of every worker.
            Defaults to 1024.
        add_sampler (bool, optional): Whether to add a ``DataParallelSampler`` to the dataset. The sampler
            is only created when the data-parallel group is initialized and has more than one rank.
            Defaults to True.
        drop_last (bool, optional): Set to True to drop the last incomplete batch if the dataset size
            is not divisible by the batch size. If False and the dataset size is not divisible by
            the batch size, the last batch will be smaller. Defaults to False.
        pin_memory (bool, optional): Whether to copy loaded tensors into pinned (page-locked) CPU memory.
            Defaults to False.
        num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
        kwargs (dict): Optional parameters for ``torch.utils.data.DataLoader``; more details can be found in
            `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

    Returns:
        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
    """
    _kwargs = kwargs.copy()

    if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
        sampler = DataParallelSampler(dataset, shuffle=shuffle, drop_last=drop_last)
    else:
        sampler = None

    # Deterministic dataloader. DataLoader calls worker_init_fn(worker_id) inside each worker process,
    # so the function must accept that argument (worker_id itself is unused: every worker gets the same seed).
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    if sampler is None:
        return DataLoader(
            dataset,
            worker_init_fn=seed_worker,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )
    else:
        # Shuffling is delegated to the sampler: DataLoader rejects sampler together with shuffle=True.
        return DataLoader(
            dataset,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )
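When the data-parallel world size is greater than 1, each rank gets a `DataParallelSampler` instead of relying on the DataLoader's own shuffling. This page does not show that class, so the sketch below assumes it shards indices the way PyTorch's `torch.utils.data.DistributedSampler` does with shuffling disabled: pad the index list (or truncate it, when `drop_last=True`) to a multiple of the number of ranks, then give rank `r` every `num_replicas`-th index starting at `r`. `shard_indices` is a hypothetical helper written for illustration.

```python
import math


def shard_indices(dataset_len, num_replicas, rank, drop_last=False):
    """Hypothetical helper: indices one data-parallel rank would see (no shuffling)."""
    indices = list(range(dataset_len))
    if drop_last:
        # Truncate so that every rank gets the same number of samples.
        num_samples = dataset_len // num_replicas
    else:
        # Round up; the shortfall is filled by repeating indices from the start.
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    if drop_last:
        indices = indices[:total_size]
    else:
        padding = total_size - len(indices)
        if padding > 0:
            indices += (indices * math.ceil(padding / len(indices)))[:padding]

    # Interleaved split: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


# 10 samples across 4 data-parallel ranks.
for r in range(4):
    print(r, shard_indices(10, 4, r), shard_indices(10, 4, r, drop_last=True))
```

Under this assumption, with `drop_last=False` ranks 2 and 3 re-read samples 0 and 1 so that all ranks take the same number of steps, while `drop_last=True` skips samples 8 and 9 instead.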

Callers: 1

Calls: 4
DataParallelSampler (class)
copy (method)
is_initialized (method)
get_world_size (method)

Tested by: no test coverage detected