Load a checkpoint. Should be called by all processes.
(self, load_dir, step=0)
| 340 | master_params[i].data.copy_(param.data) |
| 341 | |
def load(self, load_dir, step=0):
    """
    Restore the trainer state from the checkpoint written at ``step``.

    Should be called by all processes.

    Restores, in order: the weights of every model in ``self.models``
    (also copied into ``self.master_params``), the EMA parameters (on the
    master process only), and the miscellaneous training state (optimizer,
    step counter, data sampler and the optional scaler / log scale,
    LR scheduler, elastic controller and gradient clipper).

    Args:
        load_dir: Directory that contains the ``ckpts`` sub-directory.
        step: Training step of the checkpoint to load. It is zero-padded
            to seven digits when building the checkpoint file names.
    """
    def ckpt_path(stem):
        # Every checkpoint file is named '<stem>_step<7-digit step>.pt'.
        return os.path.join(load_dir, 'ckpts', f'{stem}_step{step:07d}.pt')

    if self.is_master:
        print(f'\nLoading checkpoint from step {step}...', end='')

    # Model weights: read on every process, loaded into the live modules,
    # then mirrored into the master parameters.
    loaded_states = {}
    for name, model in self.models.items():
        state = torch.load(
            read_file_dist(ckpt_path(name)),
            map_location=self.device,
            weights_only=True,
        )
        loaded_states[name] = state
        model.load_state_dict(state)
    self._state_dicts_to_master_params(self.master_params, loaded_states)
    del loaded_states

    # EMA weights are restored on the master process only, so the files are
    # read straight from disk instead of going through read_file_dist.
    # NOTE(review): presumably read_file_dist must be entered by all ranks,
    # which is why it is avoided inside this master-only branch -- confirm.
    if self.is_master:
        for idx, ema_rate in enumerate(self.ema_rate):
            ema_states = {
                name: torch.load(
                    ckpt_path(f'{name}_ema{ema_rate}'),
                    map_location=self.device,
                    weights_only=True,
                )
                for name, _ in self.models.items()
            }
            self._state_dicts_to_master_params(self.ema_params[idx], ema_states)
            del ema_states

    # Miscellaneous training state is always loaded onto the CPU.
    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
    # only load checkpoints that come from a trusted source.
    misc_state = torch.load(
        read_file_dist(ckpt_path('misc')),
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    self.optimizer.load_state_dict(misc_state['optimizer'])
    self.step = misc_state['step']
    self.data_sampler.load_state_dict(misc_state['data_sampler'])

    # Loss-scaling state is only restored for float16 mixed precision.
    if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
        self.scaler.load_state_dict(misc_state['scaler'])
    elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
        self.log_scale = misc_state['log_scale']

    # Optional components are restored only when they are configured.
    if self.lr_scheduler_config is not None:
        self.lr_scheduler.load_state_dict(misc_state['lr_scheduler'])
    if self.elastic_controller_config is not None:
        self.elastic_controller.load_state_dict(misc_state['elastic_controller'])
    # A plain float grad_clip carries no state to restore.
    if self.grad_clip is not None and not isinstance(self.grad_clip, float):
        self.grad_clip.load_state_dict(misc_state['grad_clip'])
    del misc_state

    # Wait for every process to finish loading before reporting completion.
    if self.world_size > 1:
        dist.barrier()
    if self.is_master:
        print(' Done.')

    if self.world_size > 1:
        self.check_ddp()
| 390 | |
| 391 | def save(self, non_blocking=True): |
| 392 | """ |
no test coverage detected