(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger,
zero_redundancy=False)
| 173 | |
| 174 | |
def _save_checkpoint_file(state, save_path, logger):
    """Serialize ``state`` to ``save_path`` with ``torch.save``, logging before and after."""
    logger.info(f"{save_path} saving......")
    torch.save(state, save_path)
    logger.info(f"{save_path} saved !!!")


def _current_global_rank():
    """Return this process's global rank, or 0 when torch.distributed is not initialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    # Single-process run: this process is the only (and therefore the master) rank.
    return 0


def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger,
                    zero_redundancy=False):
    """Write the checkpoint(s) for ``epoch`` into ``config.OUTPUT``.

    Files are named ``ckpt_epoch_{epoch}.pth.<suffix>``. Which files are written
    depends on ``zero_redundancy`` and ``config.TRAIN.MOE.SAVE_MASTER``:

    * ``zero_redundancy=True`` (only model weights are saved; no optimizer,
      scheduler, scaler, epoch or accuracy):

      - ``SAVE_MASTER``: rank 0 writes ``.pth.global`` = ``{'model': model.state_dict()}``.
      - otherwise: the model state dict is split with ``split_moe_model_state_dict``
        using ``model._ddp_params_and_buffers_to_ignore``. Every rank writes its
        first part to ``.pth.rank{rank}``, and rank 0 additionally writes the
        second part to ``.pth.master``. (Presumably the first part holds the
        rank-local MoE/expert parameters and the second the shared ones —
        confirm against ``split_moe_model_state_dict``.)

    * ``zero_redundancy=False``: the full training state (``model``,
      ``optimizer``, ``lr_scheduler``, ``max_accuracy``, ``scaler``, ``epoch``,
      ``config``) is saved:

      - ``SAVE_MASTER``: rank 0 only, to ``.pth.global``.
      - otherwise: every rank, to ``.pth.rank{rank}``.

    Args:
        config: Config object; reads ``config.OUTPUT`` and ``config.TRAIN.MOE.SAVE_MASTER``.
        epoch: Epoch number, embedded in the file name (and stored when not zero_redundancy).
        model: Module providing ``state_dict()``; in the zero_redundancy/non-master
            branch it must also expose ``_ddp_params_and_buffers_to_ignore``.
        max_accuracy: Stored as-is in the non-zero_redundancy checkpoint.
        optimizer, lr_scheduler, loss_scaler: Objects providing ``state_dict()``;
            only used when ``zero_redundancy`` is False.
        logger: Logger receiving a message before and after each file is written.
        zero_redundancy: Selects the model-only layout described above.

    The rank is taken from ``torch.distributed`` when a process group is
    initialized, and is 0 otherwise.
    """
    global_rank = _current_global_rank()
    save_master = config.TRAIN.MOE.SAVE_MASTER

    def _checkpoint_path(suffix):
        return os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.{suffix}')

    if zero_redundancy:
        if save_master:
            # state_dict() is taken on every rank (as before); only rank 0 writes it.
            save_state = {'model': model.state_dict()}
            if global_rank == 0:
                _save_checkpoint_file(save_state, _checkpoint_path('global'), logger)
        else:
            moe_model_state_dict, non_moe_model_state_dict = \
                split_moe_model_state_dict(model._ddp_params_and_buffers_to_ignore, model.state_dict())
            _save_checkpoint_file({'model': moe_model_state_dict},
                                  _checkpoint_path(f'rank{global_rank}'), logger)
            if global_rank == 0:
                _save_checkpoint_file({'model': non_moe_model_state_dict},
                                      _checkpoint_path('master'), logger)
        return

    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'scaler': loss_scaler.state_dict(),
                  'epoch': epoch,
                  'config': config}
    if save_master:
        if global_rank == 0:
            _save_checkpoint_file(save_state, _checkpoint_path('global'), logger)
    else:
        _save_checkpoint_file(save_state, _checkpoint_path(f'rank{global_rank}'), logger)
| 220 | |
| 221 | |
| 222 | def auto_resume_helper(output_dir, save_master=False): |
no test coverage detected