| 314 | self.get_indices() # get a new round |
| 315 | |
def state_dict(self):
    """Return a snapshot of this sampler's resumable state.

    The snapshot is a plain dict with one entry per attribute listed below,
    keyed by the attribute's own name and in that order. Values are the live
    attribute objects, not copies: e.g. the returned "indices" entry is the
    very same object as ``self.indices``.

    Notes on specific entries:
      * "batch_count": with multiple processes, batches may be sent ahead of
        actual consumption, so the stored value can overshoot; the caller is
        expected to overwrite it with the externally tracked batch count.
      * "indices": stored so that retraining resumed from a checkpoint reuses
        the same sample sequence as before.

    Returns:
        dict: mapping of attribute name -> current attribute value.
    """
    # Every key in the snapshot is exactly the name of the attribute it holds,
    # so the dict is built from this ordered name list.
    saved_fields = (
        "batch_size",
        "raw_rampup_batch_size",
        "rng_state",
        "epoch",
        "seed",
        "data_world_size",
        "num_consumed_samples_in_epoch",
        "batch_count",
        "indices",
    )
    return {field: getattr(self, field) for field in saved_fields}
| 331 | |
| 332 | def load_state_dict(self, states): |
| 333 | for name in ("data_world_size", "raw_rampup_batch_size", "seed"): # 'batch_size' |