(self)
| 269 | logger.logkv("lg_loss_scale", self.lg_loss_scale) |
| 270 | |
def save(self):
    """Write model, EMA and optimizer checkpoints for the current step.

    Files are written into the directory returned by ``get_blob_logdir()``
    and are tagged with ``self.step + self.resume_step`` zero-padded to six
    digits:

    * ``model<step>.pt``      -- state dict built from ``self.master_params``
    * ``ema_<rate>_<step>.pt`` -- one per ``(rate, params)`` pair taken from
      ``self.ema_rate`` / ``self.ema_params``
    * ``opt<step>.pt``        -- ``self.opt.state_dict()``

    Every rank builds the state dicts, but only rank 0 writes files. All
    ranks then meet at ``dist.barrier()`` before returning.
    """

    def _blob_path(filename):
        # Resolve a checkpoint file name inside the blob log directory.
        return bf.join(get_blob_logdir(), filename)

    def _step_tag():
        # Six-digit, zero-padded global step used in every file name.
        return f"{(self.step+self.resume_step):06d}"

    def _dump_params(ema_rate, param_list):
        # The conversion runs on every rank, not only on the writer.
        weights = self._master_params_to_state_dict(param_list)
        if dist.get_rank() != 0:
            return
        logger.log(f"saving model {ema_rate}...")
        # A falsy rate (the 0 passed for the master params) selects the
        # plain model file name; any other rate produces an EMA file name.
        filename = (
            f"ema_{ema_rate}_{_step_tag()}.pt"
            if ema_rate
            else f"model{_step_tag()}.pt"
        )
        with bf.BlobFile(_blob_path(filename), "wb") as sink:
            th.save(weights, sink)

    # Master parameters first (rate 0), followed by each EMA copy.
    pending = [(0, self.master_params), *zip(self.ema_rate, self.ema_params)]
    for ema_rate, param_list in pending:
        _dump_params(ema_rate, param_list)

    if dist.get_rank() == 0:
        with bf.BlobFile(_blob_path(f"opt{_step_tag()}.pt"), "wb") as sink:
            th.save(self.opt.state_dict(), sink)

    # Keep all ranks in step until the writer has finished.
    dist.barrier()
| 295 | |
| 296 | def _master_params_to_state_dict(self, master_params): |
| 297 | if self.use_fp16: |
no test coverage detected