| 762 | return model |
| 763 | |
def load(self, weights, from_pruned=False):
    """Assign the tensors in ``weights`` to this module's parameters.

    Every parameter that is not already initialised must have an entry in
    ``weights``. The only exemption is embedding sharing: an
    ``lm_head.weight`` / ``lm_head.per_channel_scale`` parameter may be
    absent when ``weights`` carries a key ending in the matching
    ``vocab_embedding`` suffix.

    Tensors whose names are not required (unknown names, or parameters
    that are already initialised) only trigger a warning; a provided
    tensor whose name matches a parameter is still loaded.

    Args:
        weights: Mapping from parameter name to the value to load.
        from_pruned: When true, each value is passed to
            ``param.set_value_or_dummy`` instead of being assigned to
            ``param.value``.

    Raises:
        RuntimeError: If a required tensor is missing from ``weights``,
            or if assigning to ``param.value`` fails. In the latter case
            the original exception is chained as ``__cause__``.
    """
    # (lm_head suffix, vocab_embedding suffix) pairs that may share storage.
    shared_suffixes = (
        ('lm_head.weight', 'vocab_embedding.weight'),
        ('lm_head.per_channel_scale', 'vocab_embedding.per_channel_scale'),
    )

    def shares_embedding(param_name):
        # True when the lm_head tensor can be omitted because the checkpoint
        # provides the corresponding vocab_embedding tensor.
        return any(
            param_name.endswith(head_suffix) and any(
                key.endswith(embedding_suffix) for key in weights.keys())
            for head_suffix, embedding_suffix in shared_suffixes)

    required_names = set()
    for name, param in self.named_parameters():
        if param.is_inited():
            continue
        if name not in weights and shares_embedding(name):
            # Exemption for embedding sharing
            continue
        required_names.add(name)

    provided_names = set(weights.keys())

    if not required_names.issubset(provided_names):
        raise RuntimeError(
            f"Required but not provided tensors:{required_names.difference(provided_names)}"
        )
    if not provided_names.issubset(required_names):
        logger.warning(
            f"Provided but not required tensors: {provided_names.difference(required_names)}"
        )

    for name, param in self.named_parameters():
        if name not in provided_names:
            continue
        if from_pruned:
            param.set_value_or_dummy(weights[name])
            continue
        try:
            param.value = weights[name]
        except Exception as e:
            # Chain explicitly so the underlying failure stays attached as
            # the direct cause of the reported error.
            raise RuntimeError(
                f"Encounter error '{e}' for parameter '{name}'") from e
| 802 | |
| 803 | def save_checkpoint(self, output_dir, save_config=True): |
| 804 | # multiple ranks could share same config.json, so adding a save_config parameter to let user avoiding writing config.json in all ranks |