| 660 | self.managed_weights[name] = np.ascontiguousarray(value) |
| 661 | |
| 662 | def save(self, engine_dir: str): |
| 663 | os.makedirs(engine_dir, exist_ok=True) |
| 664 | lora_config = self.config.build_config.lora_config |
| 665 | lora_dirs = lora_config.lora_dir |
| 666 | root_lora_dir = os.path.join(engine_dir, 'lora') |
| 667 | if len(lora_dirs) > 0: |
| 668 | os.makedirs(root_lora_dir, exist_ok=True) |
| 669 | for index, lora_dir in enumerate(lora_dirs): |
| 670 | if lora_config.lora_ckpt_source == 'hf': |
| 671 | target_lora_dir = f"{root_lora_dir}/{index}" |
| 672 | os.makedirs(target_lora_dir, exist_ok=True) |
| 673 | shutil.copy2(os.path.join(lora_dir, 'adapter_config.json'), |
| 674 | target_lora_dir) |
| 675 | weight_file = os.path.join(lora_dir, 'adapter_model.bin') |
| 676 | if os.path.exists(weight_file): |
| 677 | shutil.copy2(weight_file, target_lora_dir) |
| 678 | weight_file = os.path.join(lora_dir, |
| 679 | 'adapter_model.safetensors') |
| 680 | if os.path.exists(weight_file): |
| 681 | shutil.copy2(weight_file, target_lora_dir) |
| 682 | lora_config.lora_dir[index] = f"lora/{index}" |
| 683 | elif lora_config.lora_ckpt_source == 'nemo': |
| 684 | target_lora_file = f"{root_lora_dir}/{index}.nemo" |
| 685 | shutil.copyfile(lora_dir, target_lora_file) |
| 686 | lora_config.lora_dir[index] = f"lora/{index}.nemo" |
| 687 | else: |
| 688 | if os.path.exists(root_lora_dir) and os.path.isdir(root_lora_dir): |
| 689 | shutil.rmtree(root_lora_dir) |
| 690 | if self.config.pretrained_config.mapping.rank == 0: |
| 691 | config_dict = self.config.to_dict() |
| 692 | if self.config.pretrained_config.quant_algo == QuantAlgo.MIXED_PRECISION: |
| 693 | quant_dict = { |
| 694 | 'version': self.config.version, |
| 695 | } |
| 696 | quant_dict.update( |
| 697 | config_dict['pretrained_config']['quantization']) |
| 698 | config_dict['pretrained_config']['quantization'].pop( |
| 699 | 'quantized_layers', None) |
| 700 | with open(os.path.join(engine_dir, 'quant_cfg.json'), |
| 701 | "w", |
| 702 | encoding="utf-8") as f: |
| 703 | json.dump(quant_dict, f, indent=4, cls=ConfigEncoder) |
| 704 | |
| 705 | with open(os.path.join(engine_dir, 'config.json'), |
| 706 | "w", |
| 707 | encoding="utf-8") as f: |
| 708 | json.dump(config_dict, f, indent=4, cls=ConfigEncoder) |
| 709 | if self.engine is not None: |
| 710 | serialize_engine( |
| 711 | self.engine, |
| 712 | os.path.join( |
| 713 | engine_dir, |
| 714 | f'rank{self.config.pretrained_config.mapping.rank}.engine')) |
| 715 | if self.managed_weights is not None and len(self.managed_weights) > 0: |
| 716 | fn = os.path.join( |
| 717 | engine_dir, |
| 718 | f'rank{self.config.pretrained_config.mapping.rank}_managed_weights.safetensors' |
| 719 | ) |