MCPcopy
hub / github.com/microsoft/Swin-Transformer / save_checkpoint

Function save_checkpoint

utils_moe.py:175–219  ·  view source on GitHub ↗
(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger,
                    zero_redundancy=False)

Source from the content-addressed store, hash-verified

173
174
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger,
                    zero_redundancy=False):
    """Write the checkpoint file(s) for ``epoch`` into ``config.OUTPUT``.

    Which files get written depends on ``zero_redundancy`` and
    ``config.TRAIN.MOE.SAVE_MASTER``:

    * ``zero_redundancy=False`` -- the state holds model, optimizer,
      lr_scheduler, max_accuracy, scaler, epoch and config.
        - SAVE_MASTER set:   rank 0 alone writes ``ckpt_epoch_{epoch}.pth.global``.
        - SAVE_MASTER unset: every rank writes ``ckpt_epoch_{epoch}.pth.rank{rank}``.
    * ``zero_redundancy=True`` -- only model weights are stored.
        - SAVE_MASTER set:   rank 0 alone writes ``ckpt_epoch_{epoch}.pth.global``.
        - SAVE_MASTER unset: every rank writes its first half of the
          ``split_moe_model_state_dict`` result to ``.pth.rank{rank}``, and
          rank 0 additionally writes the second half to ``.pth.master``.

    Each write is bracketed by a "saving......" / "saved !!!" message on
    ``logger``. Nothing is returned.
    """
    rank = dist.get_rank()

    def _dump(state, filename):
        # Single place for the announce -> serialize -> confirm sequence.
        target = os.path.join(config.OUTPUT, filename)
        logger.info(f"{target} saving......")
        torch.save(state, target)
        logger.info(f"{target} saved !!!")

    if not zero_redundancy:
        # The full training state is assembled on every rank, including ranks
        # that end up not writing anything, so all state_dict() calls still run.
        full_state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'max_accuracy': max_accuracy,
            'scaler': loss_scaler.state_dict(),
            'epoch': epoch,
            'config': config,
        }
        if not config.TRAIN.MOE.SAVE_MASTER:
            _dump(full_state, f'ckpt_epoch_{epoch}.pth.rank{rank}')
        elif rank == 0:
            _dump(full_state, f'ckpt_epoch_{epoch}.pth.global')
        return

    if config.TRAIN.MOE.SAVE_MASTER:
        # state_dict() is invoked on all ranks; only rank 0 persists the result.
        weights_only = {'model': model.state_dict()}
        if rank == 0:
            _dump(weights_only, f'ckpt_epoch_{epoch}.pth.global')
        return

    # Presumably (moe_part, non_moe_part) are the MoE-specific entries and the
    # remaining entries respectively, going by the helper's name -- confirm
    # against split_moe_model_state_dict.
    moe_part, non_moe_part = split_moe_model_state_dict(
        model._ddp_params_and_buffers_to_ignore, model.state_dict())
    _dump({'model': moe_part}, f'ckpt_epoch_{epoch}.pth.rank{rank}')
    if rank == 0:
        _dump({'model': non_moe_part}, f'ckpt_epoch_{epoch}.pth.master')
220
221
222def auto_resume_helper(output_dir, save_master=False):

Callers 1

mainFunction · 0.90

Calls 2

state_dictMethod · 0.80

Tested by

no test coverage detected