(rate, params)
| 231 | |
| 232 | def save(self): |
| 233 | def save_checkpoint(rate, params): |
| 234 | state_dict = self.mp_trainer.master_params_to_state_dict(params) |
| 235 | if dist.get_rank() == 0: |
| 236 | logger.log(f"saving model {rate}...") |
| 237 | if not rate: |
| 238 | filename = f"model{(self.step+self.resume_step):06d}.pt" |
| 239 | else: |
| 240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" |
| 241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: |
| 242 | th.save(state_dict, f) |
| 243 | |
| 244 | save_checkpoint(0, self.mp_trainer.master_params) |
| 245 | for rate, params in zip(self.ema_rate, self.ema_params): |
Nothing calls this directly.
No test coverage detected.