MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / save

Method save

tensorrt_llm/builder.py:662–720  ·  view source on GitHub ↗
(self, engine_dir: str)

Source from the content-addressed store, hash-verified

660 self.managed_weights[name] = np.ascontiguousarray(value)
661
662 def save(self, engine_dir: str):
663 os.makedirs(engine_dir, exist_ok=True)
664 lora_config = self.config.build_config.lora_config
665 lora_dirs = lora_config.lora_dir
666 root_lora_dir = os.path.join(engine_dir, 'lora')
667 if len(lora_dirs) > 0:
668 os.makedirs(root_lora_dir, exist_ok=True)
669 for index, lora_dir in enumerate(lora_dirs):
670 if lora_config.lora_ckpt_source == 'hf':
671 target_lora_dir = f"{root_lora_dir}/{index}"
672 os.makedirs(target_lora_dir, exist_ok=True)
673 shutil.copy2(os.path.join(lora_dir, 'adapter_config.json'),
674 target_lora_dir)
675 weight_file = os.path.join(lora_dir, 'adapter_model.bin')
676 if os.path.exists(weight_file):
677 shutil.copy2(weight_file, target_lora_dir)
678 weight_file = os.path.join(lora_dir,
679 'adapter_model.safetensors')
680 if os.path.exists(weight_file):
681 shutil.copy2(weight_file, target_lora_dir)
682 lora_config.lora_dir[index] = f"lora/{index}"
683 elif lora_config.lora_ckpt_source == 'nemo':
684 target_lora_file = f"{root_lora_dir}/{index}.nemo"
685 shutil.copyfile(lora_dir, target_lora_file)
686 lora_config.lora_dir[index] = f"lora/{index}.nemo"
687 else:
688 if os.path.exists(root_lora_dir) and os.path.isdir(root_lora_dir):
689 shutil.rmtree(root_lora_dir)
690 if self.config.pretrained_config.mapping.rank == 0:
691 config_dict = self.config.to_dict()
692 if self.config.pretrained_config.quant_algo == QuantAlgo.MIXED_PRECISION:
693 quant_dict = {
694 'version': self.config.version,
695 }
696 quant_dict.update(
697 config_dict['pretrained_config']['quantization'])
698 config_dict['pretrained_config']['quantization'].pop(
699 'quantized_layers', None)
700 with open(os.path.join(engine_dir, 'quant_cfg.json'),
701 "w",
702 encoding="utf-8") as f:
703 json.dump(quant_dict, f, indent=4, cls=ConfigEncoder)
704
705 with open(os.path.join(engine_dir, 'config.json'),
706 "w",
707 encoding="utf-8") as f:
708 json.dump(config_dict, f, indent=4, cls=ConfigEncoder)
709 if self.engine is not None:
710 serialize_engine(
711 self.engine,
712 os.path.join(
713 engine_dir,
714 f'rank{self.config.pretrained_config.mapping.rank}.engine'))
715 if self.managed_weights is not None and len(self.managed_weights) > 0:
716 fn = os.path.join(
717 engine_dir,
718 f'rank{self.config.pretrained_config.mapping.rank}_managed_weights.safetensors'
719 )

Callers 15

main · Function · 0.45
generate_output · Function · 0.45
csv_to_npy · Function · 0.45
main · Function · 0.45
build_gpt · Function · 0.45
enc_dec_build_helper · Function · 0.45
test_llm_build_config · Function · 0.45
main · Function · 0.45
llama_7b_path · Function · 0.45
llama_7b_bs2_path · Function · 0.45
llama_7b_tp2_path · Function · 0.45

Calls 5

pop · Method · 0.80
serialize_engine · Function · 0.70
to_dict · Method · 0.45
update · Method · 0.45

Tested by 15

main · Function · 0.36
test_llm_build_config · Function · 0.36
llama_7b_path · Function · 0.36
llama_7b_bs2_path · Function · 0.36
llama_7b_tp2_path · Function · 0.36
engine_from_checkpoint · Function · 0.36