(self, states)
| 330 | return states |
| 331 | |
def load_state_dict(self, states):
    """Restore sampler progress from a checkpoint dict ``states``.

    First checks that ``data_world_size``, ``raw_rampup_batch_size`` and
    ``seed`` in ``states`` equal this sampler's current values, because a
    resume must not change them. Then restores the RNG via
    ``self.rng.set_state``, regenerates the index order from that state
    through ``self.get_indices(old_indices=None)``, and finally restores the
    ``epoch``, ``batch_count`` and ``num_consumed_samples_in_epoch`` counters.

    Args:
        states: mapping with the keys ``"data_world_size"``,
            ``"raw_rampup_batch_size"``, ``"seed"``, ``"rng_state"``,
            ``"epoch"``, ``"batch_count"`` and
            ``"num_consumed_samples_in_epoch"``.

    Raises:
        ValueError: if one of the checked configuration values differs from
            the sampler's current value. This was previously an ``assert``,
            which ``python -O`` strips, letting a mismatched checkpoint load
            silently. No state is modified when this is raised.
        KeyError: if ``states`` (a dict) is missing a required key.
    """
    # 'batch_size' is intentionally excluded from this check (the original
    # code listed it only in a comment); presumably it may legitimately
    # differ on resume -- confirm before adding it.
    for name in ("data_world_size", "raw_rampup_batch_size", "seed"):
        saved, current = states[name], getattr(self, name)
        if saved != current:
            raise ValueError(
                f"load_state_dict: {name!r} must not change on resume "
                f"(checkpoint={saved!r}, current={current!r})"
            )

    self.rng.set_state(states["rng_state"])
    # Regenerate indices from the just-restored random state.
    # NOTE(review): this runs before epoch/batch_count are restored; keep the
    # order unless get_indices is confirmed not to read those attributes.
    self.get_indices(old_indices=None)
    self.epoch = states["epoch"]
    self.batch_count = states["batch_count"]
    self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"]
| 340 | |
| 341 | def copy(self): |
| 342 | copy_sampler = StaticBatchSampler( |
no test coverage detected