(self)
| 339 | self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"] |
| 340 | |
def copy(self):
    """Create a separate ``StaticBatchSampler`` mirroring this one.

    The new sampler is constructed from this sampler's configuration
    attributes (``drop_last`` is always passed as ``True``) and is then
    loaded with the result of this sampler's ``state_dict()``.

    Returns:
        The newly constructed sampler with this sampler's state loaded.
    """
    # Positional constructor arguments, in the order the constructor is called.
    positional = (
        self.datasets,
        self.batch_size,
        self.raw_rampup_batch_size,
        self.micro_bsz,
        self.seed,
    )
    duplicate = StaticBatchSampler(
        *positional,
        drop_last=True,
        data_rank=self.data_rank,
        data_world_size=self.data_world_size,
    )

    # Carry this sampler's saved state over to the duplicate.
    snapshot = self.state_dict()
    duplicate.load_state_dict(snapshot)
    return duplicate
no test coverage detected