MCPcopy
hub / github.com/InternLM/InternLM / CheckpointManager

Class CheckpointManager

internlm/utils/model_checkpoint.py:270–633  ·  view source on GitHub ↗

StorageManagerContext

Source from the content-addressed store, hash-verified

268
269
270class CheckpointManager:
271 """StorageManagerContext"""
272
273 def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
274 """
275 CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
276 upload mode, you must call wait_async_upload_finish at the end of the program to wait
277 for the asynchronous ckpt upload to complete.
278
279 Args:
280 ckpt_config (dict): model checkpoint config.
281 model (nn.module): model obj
282 optimizer (object): optimzier obj.
283 lr_scheduler (object): lr_scheduler obj.
284 model_config (dict): model config.
285 """
286 self.enable_save_ckpt = ckpt_config.enable_save_ckpt
287 self.checkpoint_every = ckpt_config.checkpoint_every
288 self.save_ckpt_folder = ckpt_config.save_ckpt_folder
289 self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
290 self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
291 self.stop_file_path = ckpt_config.stop_file_path
292 self.load_model_only_folder = ckpt_config.load_model_only_folder
293 self.feishu_address = feishu_address
294 self.storage_manager = get_storage_manager()
295 self.snapshot_counter = 0
296 self.load_optimizer = gpc.config.ckpt.load_optimizer
297
298 self.model = model
299 self.model_config = model_config
300 self.model_config_file = model_config_file
301
302 if self.stop_file_path and gpc.get_global_rank() == 0:
303 dir_path = os.path.dirname(self.stop_file_path)
304 if dir_path != "" and not os.path.exists(dir_path):
305 os.makedirs(dir_path)
306 with open(self.stop_file_path, "w", encoding="utf-8") as f:
307 f.write("0")
308
309 if ckpt_config.load_given_ckpt is False:
310 # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
311 latest_ckpt_path = self.query_lastest_ckpt()
312 if latest_ckpt_path:
313 self.load_ckpt_folder = latest_ckpt_path
314 else:
315 # At this time, we have to load model init weights and train from step 0.
316 self.load_ckpt_folder = self.load_model_only_folder
317 else:
318 self.load_ckpt_folder = ckpt_config.load_ckpt_folder
319
320 if gpc.is_rank_for_log():
321 logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
322 if self.stop_file_path is None:
323 logger.warning("no set stop_file_path, quit_signal_handler is disable")
324
325 def quit_signal_handler(self, train_state) -> bool:
326 """
327 Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file,

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected