| 724 | |
@classmethod
def from_checkpoint(
        cls,
        ckpt_dir: str,
        rank: Optional[int] = None,
        config: Optional[PretrainedConfig] = None,
        *,
        preprocess_weights_hook: Optional[Callable[[Dict[str, Tensor]],
                                                   Dict[str, Tensor]]] = None):
    """Build a model instance and load its weights from a checkpoint directory.

    Args:
        ckpt_dir: Directory holding the per-rank ``rank{N}.safetensors``
            files and, when ``config`` is not supplied, ``config.json``.
        rank: Optional rank override. When given, it is applied to the
            config through ``config.set_rank(rank)`` before the weights
            file is selected. Note that this is also called on a
            caller-supplied ``config`` object.
        config: Optional ready-made config. When ``None``, the config is
            read from ``config.json`` inside ``ckpt_dir``.
        preprocess_weights_hook: Optional callable applied to the raw
            weights dict right after it is read from disk and before the
            built-in ``preprocess_weights`` step.

    Returns:
        An instance of ``cls`` constructed from the config, after
        ``load()`` has been called on it with the processed weights.

    Raises:
        FileNotFoundError: If the weights file selected for this rank does
            not exist in ``ckpt_dir``.
    """
    if config is None:
        config = PretrainedConfig.from_json_file(
            os.path.join(ckpt_dir, 'config.json'))

    if rank is not None:
        config.set_rank(rank)

    rank = config.mapping.rank
    cp_size = config.mapping.cp_size
    if cp_size > 1:
        # cp_tp_pp rank -> tp_pp rank: different cp ranks share the same
        # checkpoint file.
        #   rank = pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
        # so rank // cp_size == pp_rank * tp_size + tp_rank.
        rank = rank // cp_size
    weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')

    # Explicit check rather than `assert`: asserts are removed under
    # `python -O`, and a bare assert does not say which file is missing.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(
            f"Checkpoint weights file not found: {weights_path}")

    weights = safetensors.torch.load_file(weights_path)
    is_checkpoint_pruned = getattr(config, 'is_pruned', False)

    # The user hook runs first so the built-in preprocessing sees its output.
    if preprocess_weights_hook is not None:
        weights = preprocess_weights_hook(weights)

    weights = preprocess_weights(weights,
                                 config,
                                 from_pruned=is_checkpoint_pruned)
    model = cls(config)
    model.load(weights, from_pruned=is_checkpoint_pruned)
    return model
| 763 | |
| 764 | def load(self, weights, from_pruned=False): |
| 765 | required_names = set() |