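A static batch sampler that generates batches with a fixed micro-batch size. Args: num_samples (int): The total number of samples in the dataset. batch_size (int): The batch size for the current rank. Defaults to 192. rampup_batch_size (str): A string with three space-separated integers representing the starting batch size, the increment, and the number of steps between each increment (for example, "192 24 8" means the batch size starts at 192 and increases by 24 every 8 steps).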
| 176 | |
| 177 | |
| 178 | class StaticBatchSampler: |
| 179 | """ |
| 180 | A static batch sampler that generates batches with a fixed micro-batch size. |
| 181 | |
| 182 | Args: |
| 183 | num_samples (int): The total number of samples in the dataset. |
| 184 | batch_size (int): The batch size for the current rank. Defaults to 192. |
| 185 | rampup_batch_size (str): A string with three space-separated integers representing the |
| 186 | starting batch size, the increment, and the number of steps between |
| 187 | each increment. For example, "192 24 8" means that the batch size |
| 188 | starts at 192 and increases by 24 every 8 steps. Defaults to |
| 189 | "6 2 8", which corresponds to a batch size of 2 for the first 6 steps. |
| 190 | micro_bsz (int): The micro-batch size. Defaults to 2. |
| 191 | seed (int): The random seed for shuffling the indices. Defaults to 0. |
| 192 | drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True. |
| 193 | data_rank (int): The rank of the current process in the data parallel group. Defaults to 0. |
| 194 | data_world_size (int): The number of processes in the data parallel group. Defaults to 1. |
| 195 | """ |
| 196 | |
| 197 | def __init__( |
| 198 | self, |
| 199 | datasets, |
| 200 | batch_size=192, |
| 201 | rampup_batch_size="6 2 8", |
| 202 | micro_bsz=2, |
| 203 | seed=0, |
| 204 | drop_last=True, |
| 205 | data_rank=0, |
| 206 | data_world_size=1, |
| 207 | ): |
| 208 | assert drop_last is True, "Currently only support drop last" |
| 209 | if rampup_batch_size: |
| 210 | # In the process increase to batch_size |
| 211 | start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split()) |
| 212 | else: |
| 213 | start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1 |
| 214 | self.raw_rampup_batch_size = rampup_batch_size |
| 215 | self.start_bsz = start_bsz |
| 216 | self.bsz_incre = bsz_incre |
| 217 | self.incre_every = incre_every |
| 218 | if gpc.is_initialized(ParallelMode.PIPELINE): |
| 219 | assert ( |
| 220 | batch_size - self.start_bsz |
| 221 | ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}" |
| 222 | assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})" |
| 223 | assert ( |
| 224 | self.start_bsz % micro_bsz == 0 |
| 225 | ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})" |
| 226 | assert ( |
| 227 | self.bsz_incre % micro_bsz == 0 |
| 228 | ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})" |
| 229 | |
| 230 | self.batch_size = batch_size |
| 231 | self.epoch = 0 |
| 232 | self.seed = seed |
| 233 | self.rng = np.random.RandomState(seed) |
| 234 | self.batch_count = 0 |
| 235 | self.micro_bsz = micro_bsz |
no outgoing calls
no test coverage detected