| 344 | |
| 345 | |
def load(
    booster: Booster,
    load_dir: str,
    model: nn.Module = None,
    ema: nn.Module = None,
    optimizer: Optimizer = None,
    lr_scheduler: _LRScheduler = None,
    sampler=None,
    is_lora_train: bool = False,
    is_load_both_lora_and_main: bool = False,
) -> Tuple[int, int]:
    """Restore training state from the checkpoint directory ``load_dir``.

    Every component is optional: a component passed as ``None`` is skipped.

    Args:
        booster: Used to load the model, optimizer and LR scheduler.
        load_dir: Checkpoint directory; must contain ``running_states.json``.
        model: If given, weights are loaded into it (non-strict).
        ema: If given, its state dict is loaded from ``ema.pt`` (non-strict).
        optimizer: If given, loaded from ``<load_dir>/optimizer``.
        lr_scheduler: If given, loaded from ``<load_dir>/lr_scheduler``.
        sampler: If given, its state dict is loaded from ``<load_dir>/sampler``.
        is_lora_train: When true, load ``lora/adapter_model.bin`` into ``model``
            instead of ``<load_dir>/model``.
        is_load_both_lora_and_main: Only consulted when ``is_lora_train`` is
            true; additionally load ``<load_dir>/model`` after the adapter.

    Returns:
        ``(epoch, step)`` as stored in ``running_states.json``.

    Raises:
        AssertionError: If ``load_dir`` or ``running_states.json`` is missing.
    """
    assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist"
    running_states_path = os.path.join(load_dir, "running_states.json")
    assert os.path.exists(running_states_path), "running_states.json does not exist"
    running_states = load_json(running_states_path)

    if model is not None:
        # Order matters when both are requested: adapter first, then main weights.
        if is_lora_train:
            booster.load_model(model, os.path.join(load_dir, "lora", "adapter_model.bin"), strict=False)
        if not is_lora_train or is_load_both_lora_and_main:
            booster.load_model(model, os.path.join(load_dir, "model"), strict=False)

    if ema is not None:
        # ema is not boosted, so we don't use booster.load_model
        ema_state = torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))
        ema.load_state_dict(ema_state, strict=False)

    if optimizer is not None:
        booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))

    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))

    if sampler is not None:
        sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))

    # Wait for all ranks before returning.
    dist.barrier()

    return (
        running_states["epoch"],
        running_states["step"],
    )
| 387 | |
| 388 | |
| 389 | def rm_checkpoints( |