Load a checkpoint. Must be called by all processes, because when `world_size > 1` it synchronizes them with `dist.barrier()`.
(self, load_dir, step=0)
| 317 | master_params[i].data.copy_(param.data) |
| 318 | |
| 319 | def load(self, load_dir, step=0): |
| 320 | """ |
| 321 | Load a checkpoint. |
| 322 | Should be called by all processes. |
| 323 | """ |
| 324 | if self.is_master: |
| 325 | print(f'\nLoading checkpoint from step {step}...', end='') |
| 326 | |
| 327 | model_ckpts = {} |
| 328 | for name, model in self.models.items(): |
| 329 | model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) |
| 330 | model_ckpts[name] = model_ckpt |
| 331 | model.load_state_dict(model_ckpt) |
| 332 | self._state_dicts_to_master_params(self.master_params, model_ckpts) |
| 333 | del model_ckpts |
| 334 | |
| 335 | if self.is_master: |
| 336 | for i, ema_rate in enumerate(self.ema_rate): |
| 337 | ema_ckpts = {} |
| 338 | for name, model in self.models.items(): |
| 339 | ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True) |
| 340 | ema_ckpts[name] = ema_ckpt |
| 341 | self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) |
| 342 | del ema_ckpts |
| 343 | |
| 344 | misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) |
| 345 | self.optimizer.load_state_dict(misc_ckpt['optimizer']) |
| 346 | self.step = misc_ckpt['step'] |
| 347 | self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) |
| 348 | if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: |
| 349 | self.scaler.load_state_dict(misc_ckpt['scaler']) |
| 350 | elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: |
| 351 | self.log_scale = misc_ckpt['log_scale'] |
| 352 | if self.lr_scheduler_config is not None: |
| 353 | self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) |
| 354 | if self.elastic_controller_config is not None: |
| 355 | self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) |
| 356 | if self.grad_clip is not None and not isinstance(self.grad_clip, float): |
| 357 | self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) |
| 358 | del misc_ckpt |
| 359 | |
| 360 | if self.world_size > 1: |
| 361 | dist.barrier() |
| 362 | if self.is_master: |
| 363 | print(' Done.') |
| 364 | |
| 365 | if self.world_size > 1: |
| 366 | self.check_ddp() |
| 367 | |
| 368 | def save(self, non_blocking=True): |
| 369 | """ |
No test coverage detected.