StorageManagerContext
| 268 | |
| 269 | |
| 270 | class CheckpointManager: |
| 271 | """StorageManagerContext""" |
| 272 | |
def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
    """
    CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
    upload mode, you must call wait_async_upload_finish at the end of the program to wait
    for the asynchronous ckpt upload to complete.

    Besides copying the checkpoint settings onto the instance, construction also
    resets the stop file (on global rank 0 only) and resolves the folder that is
    stored in ``self.load_ckpt_folder``.

    Args:
        ckpt_config: model checkpoint config. It is read through attribute access
            (``enable_save_ckpt``, ``checkpoint_every``, ``save_ckpt_folder``,
            ``snapshot_ckpt_folder``, ``oss_snapshot_freq``, ``stop_file_path``,
            ``load_model_only_folder``, ``load_given_ckpt``, ``load_ckpt_folder``).
        model (nn.module): model obj, stored as ``self.model``.
        model_config (dict): model config, stored as-is. Defaults to None.
        model_config_file: stored as-is in ``self.model_config_file``; not used in
            this method. Defaults to None.
        feishu_address: stored as-is in ``self.feishu_address``; not used in this
            method (presumably an alert/notification address -- confirm with callers).
            Defaults to None.
    """
    self.enable_save_ckpt = ckpt_config.enable_save_ckpt
    self.checkpoint_every = ckpt_config.checkpoint_every
    self.save_ckpt_folder = ckpt_config.save_ckpt_folder
    self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
    self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
    self.stop_file_path = ckpt_config.stop_file_path
    self.load_model_only_folder = ckpt_config.load_model_only_folder
    self.feishu_address = feishu_address
    self.storage_manager = get_storage_manager()
    self.snapshot_counter = 0
    # NOTE(review): this is read from the global config (gpc.config.ckpt), not from
    # the ckpt_config argument like the settings above -- confirm that is intended.
    self.load_optimizer = gpc.config.ckpt.load_optimizer

    self.model = model
    self.model_config = model_config
    self.model_config_file = model_config_file

    # Only global rank 0 touches the stop file; it is (re)written with "0".
    if self.stop_file_path and gpc.get_global_rank() == 0:
        dir_path = os.path.dirname(self.stop_file_path)
        if dir_path:
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(dir_path, exist_ok=True)
        with open(self.stop_file_path, "w", encoding="utf-8") as f:
            f.write("0")

    if ckpt_config.load_given_ckpt is False:
        # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
        latest_ckpt_path = self.query_lastest_ckpt()
        if latest_ckpt_path:
            self.load_ckpt_folder = latest_ckpt_path
        else:
            # At this time, we have to load model init weights and train from step 0.
            self.load_ckpt_folder = self.load_model_only_folder
    else:
        self.load_ckpt_folder = ckpt_config.load_ckpt_folder

    if gpc.is_rank_for_log():
        logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
        if self.stop_file_path is None:
            logger.warning("no set stop_file_path, quit_signal_handler is disable")
| 324 | |
| 325 | def quit_signal_handler(self, train_state) -> bool: |
| 326 | """ |
| 327 | Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, |