| 801 | param.set_value_or_dummy(weights[name]) |
| 802 | |
def save_checkpoint(self, output_dir, save_config=True):
    """Write this rank's parameters, and optionally the config, to ``output_dir``.

    The weights are stored as ``rank{rank}.safetensors``, where ``rank`` is
    read from ``self.config.mapping.rank``. The config is stored as
    ``config.json`` in the same directory.

    Args:
        output_dir: Directory that receives the checkpoint files. It is
            created (including missing parents) if it does not exist yet.
        save_config: When true, also write ``config.json``. Multiple ranks
            may share the same ``config.json``, so callers can pass
            ``False`` to avoid writing it from every rank.
    """
    rank = self.config.mapping.rank

    # NOTE(review): numpy_to_torch is defined outside this block; presumably
    # it converts each parameter's raw value to a torch tensor -- confirm.
    weights = {
        name: numpy_to_torch(param.raw_value)
        for name, param in self.named_parameters()
    }

    # Tensors that share memory make "save_file" raise an error, so any
    # tensor whose data pointer was already seen is replaced by a clone.
    # Only values of existing keys are reassigned, so the dict does not
    # change size while it is being iterated.
    data_ptrs = set()
    for name, param in weights.items():
        if param.data_ptr() in data_ptrs:
            weights[name] = param.clone()
        data_ptrs.add(weights[name].data_ptr())

    # Ensure the target directory exists before writing. exist_ok=True keeps
    # this safe when several ranks save into the same directory.
    os.makedirs(output_dir, exist_ok=True)

    safetensors.torch.save_file(
        weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
    if save_config:
        self.config.to_json_file(os.path.join(output_dir, 'config.json'))
| 821 | |
| 822 | def prepare_inputs( |
| 823 | self, |