Save a checkpoint. Should be called only by the rank 0 process.
(self, non_blocking=True)
| 389 | self.check_ddp() |
| 390 | |
def save(self, non_blocking=True):
    """
    Save a checkpoint.
    Should be called only by the rank 0 process.

    Files are written under ``<output_dir>/ckpts``:
      * ``{name}_step{step:07d}.pt`` for every entry returned by
        ``self._master_params_to_state_dicts(self.master_params)``;
      * ``{name}_ema{rate}_step{step:07d}.pt`` for every rate in ``self.ema_rate``;
      * ``misc_step{step:07d}.pt`` holding optimizer / step / data-sampler state,
        plus the scaler, log_scale, lr_scheduler, elastic_controller and grad_clip
        state when the corresponding configuration is active.

    Args:
        non_blocking: If True, each file is written by ``torch.save`` on its own
            background thread. Every tensor handed to those threads is first
            copied into an independent, detached CPU tensor, so in-place updates
            made by training after this call returns (optimizer step, parameter
            update) cannot leak into the checkpoint being written.

    Returns:
        list[threading.Thread]: the started writer threads (empty when
        ``non_blocking`` is False). Join them to be sure the checkpoint is on disk.

    Raises:
        AssertionError: if called on a process that is not the master rank.
    """
    if not self.is_master:
        # Explicit raise instead of `assert` so the guard survives `python -O`;
        # the exception type is unchanged for callers.
        raise AssertionError('save() should be called only by the rank 0 process.')
    print(f'\nSaving checkpoint at step {self.step}...', end='')

    ckpt_dir = os.path.join(self.output_dir, 'ckpts')
    step = self.step
    threads = []

    def snapshot(obj):
        """Recursively replace tensors inside dict/list/tuple containers with detached CPU copies."""
        if isinstance(obj, torch.Tensor):
            # copy=True forces a new tensor even when `obj` already lives on the CPU
            # (plain `.cpu()` would return the same, still-mutable tensor).
            return obj.detach().to('cpu', copy=True)
        if isinstance(obj, dict):
            # Fresh dicts as well: optimizer.state_dict() shares the live per-param state dicts.
            return {k: snapshot(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [snapshot(v) for v in obj]
        if isinstance(obj, tuple):
            items = [snapshot(v) for v in obj]
            return type(obj)(*items) if hasattr(obj, '_fields') else tuple(items)
        return obj

    def cpu_state_dict(state_dict):
        """Move a flat {key: tensor} state dict to CPU for saving."""
        if non_blocking:
            return snapshot(state_dict)
        return {k: v.cpu() for k, v in state_dict.items()}

    def write(obj, filename):
        """Save `obj` to ckpt_dir/filename, synchronously or on a tracked background thread."""
        path = os.path.join(ckpt_dir, filename)
        if non_blocking:
            thread = threading.Thread(target=torch.save, args=(obj, path))
            thread.start()
            threads.append(thread)
        else:
            torch.save(obj, path)

    # Model weights.
    for name, model_ckpt in self._master_params_to_state_dicts(self.master_params).items():
        write(cpu_state_dict(model_ckpt), f'{name}_step{step:07d}.pt')

    # One set of EMA weights per configured EMA rate.
    for i, ema_rate in enumerate(self.ema_rate):
        for name, ema_ckpt in self._master_params_to_state_dicts(self.ema_params[i]).items():
            write(cpu_state_dict(ema_ckpt), f'{name}_ema{ema_rate}_step{step:07d}.pt')

    # Everything else needed to resume training.
    misc_ckpt = {
        'optimizer': self.optimizer.state_dict(),
        'step': self.step,
        'data_sampler': self.data_sampler.state_dict(),
    }
    if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
        misc_ckpt['scaler'] = self.scaler.state_dict()
    elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
        misc_ckpt['log_scale'] = self.log_scale
    if self.lr_scheduler_config is not None:
        misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
    if self.elastic_controller_config is not None:
        misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
    if self.grad_clip is not None and not isinstance(self.grad_clip, float):
        misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
    # The state dicts above hold references to live tensors. Snapshot them before a
    # background thread serializes them; blocking mode saves them as-is, as before.
    write(snapshot(misc_ckpt) if non_blocking else misc_ckpt, f'misc_step{step:07d}.pt')

    # In non-blocking mode this only means all writes were dispatched; join the
    # returned threads to wait for them to finish.
    print(' Done.')
    return threads
| 445 | |
| 446 | def _remap_checkpoint_keys(self, model_ckpt, model_state_dict): |
| 447 | """ |
no test coverage detected