github.com/InternLM/InternLM / StaticBatchSampler

Class StaticBatchSampler

internlm/data/batch_sampler.py:178–354

A static batch sampler that generates batches with a fixed micro-batch size.

# Module-level names used below (np, gpc, ParallelMode). These imports are assumed
# to come from the top of internlm/data/batch_sampler.py; they are not part of the
# lines 178–354 shown here.
import numpy as np

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


class StaticBatchSampler:
    """
    A static batch sampler that generates batches with a fixed micro-batch size.

    Args:
        num_samples (int): The total number of samples in the dataset.
        batch_size (int): The batch size for the current rank. Defaults to 192.
        rampup_batch_size (str): A string with three space-separated integers representing the
                                 starting batch size, the increment, and the number of steps between
                                 each increment. For example, "192 24 8" means that the batch size
                                 starts at 192 and increases by 24 every 8 steps. Defaults to
                                 "6 2 8", i.e. a batch size that starts at 6 and increases by 2
                                 every 8 steps. If empty, batch_size is used from the first step.
        micro_bsz (int): The micro-batch size. Defaults to 2.
        seed (int): The random seed for shuffling the indices. Defaults to 0.
        drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
        data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
        data_world_size (int): The number of processes in the data parallel group. Defaults to 1.
    """

    def __init__(
        self,
        datasets,
        batch_size=192,
        rampup_batch_size="6 2 8",
        micro_bsz=2,
        seed=0,
        drop_last=True,
        data_rank=0,
        data_world_size=1,
    ):
        assert drop_last is True, "Currently only support drop last"
        if rampup_batch_size:
            # Ramp the batch size up from start_bsz toward batch_size
            start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
        else:
            # No ramp-up: use the full batch_size from the first step
            start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
        self.raw_rampup_batch_size = rampup_batch_size
        self.start_bsz = start_bsz
        self.bsz_incre = bsz_incre
        self.incre_every = incre_every
        if gpc.is_initialized(ParallelMode.PIPELINE):
            # Under pipeline parallelism, the ramp-up must land exactly on batch_size and
            # every batch size it passes through must split into whole micro-batches.
            assert (
                batch_size - self.start_bsz
            ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
            assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
            assert (
                self.start_bsz % micro_bsz == 0
            ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
            assert (
                self.bsz_incre % micro_bsz == 0
            ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"

        self.batch_size = batch_size
        self.epoch = 0
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        self.batch_count = 0
        self.micro_bsz = micro_bsz
        # ... (remainder of the class, lines 236–354, not shown)
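The ramp-up rule in the docstring can be checked with a small standalone sketch. It reuses the constructor's parsing of rampup_batch_size. Because lines 236–354 are not shown, it also makes three assumptions: the step counter starts at 0 (presumably batch_count), the batch size grows by bsz_incre after every incre_every steps, and it is capped at batch_size. The helpers parse_rampup, batch_size_at and schedule are hypothetical names for illustration, not InternLM API.

```python
def parse_rampup(rampup_batch_size, batch_size):
    """Mirror the constructor: "start incre every", or fall back to no ramp-up."""
    if rampup_batch_size:
        start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
    else:
        start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
    return start_bsz, bsz_incre, incre_every


def batch_size_at(step, rampup_batch_size, batch_size):
    """Batch size at a 0-based step.

    Assumption: grows by bsz_incre every incre_every steps, capped at batch_size.
    """
    start_bsz, bsz_incre, incre_every = parse_rampup(rampup_batch_size, batch_size)
    num_incre = step // incre_every
    return min(start_bsz + num_incre * bsz_incre, batch_size)


def schedule(rampup_batch_size, batch_size, micro_bsz, num_steps):
    """List of (batch_size, num_micro_batches) for the first num_steps steps."""
    out = []
    for step in range(num_steps):
        bsz = batch_size_at(step, rampup_batch_size, batch_size)
        out.append((bsz, bsz // micro_bsz))
    return out


if __name__ == "__main__":
    # "6 2 8": start at 6, add 2 every 8 steps, capped here at batch_size=10
    for step, (bsz, n_micro) in enumerate(schedule("6 2 8", 10, 2, 20)):
        print(step, bsz, n_micro)
```

With the default "6 2 8" and batch_size=10, steps 0–7 use 6 samples (3 micro-batches of 2), steps 8–15 use 8, and from step 16 on the size stays at 10. The constructor's pipeline-parallel assertions are what make every value in such a sequence divisible by micro_bsz: start_bsz, bsz_incre and batch_size are all required to be multiples of it.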

Callers (2)

get_train_data_loader · Function · 0.90
copy · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected
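The divisibility constraints the constructor enforces under pipeline parallelism can be exercised without gpc. A minimal sketch copies those checks into a standalone function. check_rampup_config is a hypothetical helper, not part of InternLM. It returns the same assertion messages as a list instead of raising, so each rule can be asserted on directly.

```python
def check_rampup_config(batch_size, rampup_batch_size, micro_bsz):
    """Return the messages of the constructor's pipeline-parallel checks that would fail.

    Hypothetical helper mirroring StaticBatchSampler.__init__; an empty list
    means the configuration passes all four checks.
    """
    # Same parsing as the constructor: "start incre every", or no ramp-up at all.
    if rampup_batch_size:
        start_bsz, bsz_incre, _incre_every = map(int, rampup_batch_size.split())
    else:
        start_bsz, bsz_incre = batch_size, batch_size

    errors = []
    # The ramp-up must land exactly on batch_size.
    if (batch_size - start_bsz) % bsz_incre != 0:
        errors.append(f"{batch_size} - {start_bsz} should be multiple of {bsz_incre}")
    # Every batch size reached must split into whole micro-batches.
    if batch_size % micro_bsz != 0:
        errors.append(f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})")
    if start_bsz % micro_bsz != 0:
        errors.append(f"start_bsz({start_bsz}) should be multiple of micro_bsz({micro_bsz})")
    if bsz_incre % micro_bsz != 0:
        errors.append(f"bsz_incre({bsz_incre}) should be multiple of micro_bsz({micro_bsz})")
    return errors


if __name__ == "__main__":
    print(check_rampup_config(192, "6 2 8", 2))  # the defaults → []
    print(check_rampup_config(192, "6 3 8", 4))  # start and increment not multiples of 4
```

Returning messages rather than asserting keeps the sketch usable as a pre-flight check on a config before any process group exists.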