(self, rate)
| 123 | dist_util.sync_params(self.model.parameters()) |
| 124 | |
def _load_ema_parameters(self, rate):
    """Return EMA parameters for decay ``rate``, restoring them from disk if possible.

    A deep copy of ``self.mp_trainer.master_params`` is the starting point.
    The EMA checkpoint path is resolved with ``find_ema_checkpoint``. That call
    receives the resume checkpoint (``find_resume_checkpoint()``, falling back
    to ``self.resume_checkpoint``), ``self.resume_step`` and ``rate``.

    When a checkpoint path is returned and this process is rank 0, the copy is
    replaced by ``self.mp_trainer.state_dict_to_master_params`` applied to the
    loaded state dict. In every case the parameters are passed through
    ``dist_util.sync_params`` before being returned.

    Args:
        rate: EMA decay rate; forwarded to ``find_ema_checkpoint`` to pick the file.

    Returns:
        The EMA parameter list (checkpoint-loaded on rank 0 when available).
    """
    params = copy.deepcopy(self.mp_trainer.master_params)

    resume_path = find_resume_checkpoint() or self.resume_checkpoint
    ema_path = find_ema_checkpoint(resume_path, self.resume_step, rate)

    # Only rank 0 reads the checkpoint; other ranks keep the fresh copy and
    # presumably receive rank 0's values via sync_params below — confirm.
    # NOTE(review): if dist_util.load_state_dict is a collective across ranks
    # (e.g. an MPI broadcast), calling it on rank 0 only could hang — verify.
    if ema_path and dist.get_rank() == 0:
        logger.log(f"loading EMA from checkpoint: {ema_path}...")
        loaded = dist_util.load_state_dict(ema_path, map_location=dist_util.dev())
        params = self.mp_trainer.state_dict_to_master_params(loaded)

    dist_util.sync_params(params)
    return params
| 140 | |
| 141 | def _load_optimizer_state(self): |
| 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint |
no test coverage detected