MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / load_from_model_dir

Method load_from_model_dir

tensorrt_llm/lora_manager.py:1040–1173  ·  view source on GitHub ↗
(uid, model_dir, hf_config)

Source from the content-addressed store, hash-verified

1038 return t_out
1039
1040 def load_from_model_dir(uid, model_dir, hf_config):
1041 if uid not in self._cpp_lora_weights:
1042 self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
1043 if uid not in self._cpp_lora_config:
1044 self._cpp_lora_config[uid] = [] # Will be converted to tensor later
1045
1046 lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
1047 if lora_model is None:
1048 raise ValueError(f"Failed to load adapter_model from {model_dir}")
1049 lora_model = preprocess_lora_weights(lora_model, model_config)
1050 all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
1051 rank = int(hf_config["r"])
1052 rs_lora = bool(hf_config.get("use_rslora", False))
1053
1054 self._lora_uid_to_low_ranks[uid] = {}
1055 self._lora_weights_pointers_list[uid] = {}
1056 for layer_idx in sorted(all_weights.keys()):
1057 layer_weights = all_weights[layer_idx]
1058 self._lora_uid_to_low_ranks[uid][layer_idx] = {}
1059 self._lora_weights_pointers_list[uid][layer_idx] = {}
1060
1061 for lora_module in self.missing_qkv_modules:
1062 hf_module = model_config.trtllm_modules_to_hf_modules[lora_module]
1063 if isinstance(hf_module, list):
1064 hf_module = hf_module[0]
1065 layer_weights[hf_module] = {
1066 "in": torch.zeros(rank, model_config.hidden_size),
1067 "out": torch.zeros(model_config.hidden_size, rank),
1068 }
1069
1070 for hf_module, module_weights in layer_weights.items():
1071 lora_module = hf_modules_to_trtllm_modules[hf_module]
1072 if lora_module not in self.lora_target_modules:
1073 warnings.warn(
1074 f"LoRA module '{lora_module}' not in target modules {self.lora_target_modules}, skipping."
1075 )
1076 self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
1077 continue
1078
1079 has_expert_indices = _is_moe_module_weights(module_weights)
1080
1081 if has_expert_indices: # MoE
1082 # Validate and extract matrices in one pass
1083 expert_indices = sorted(module_weights.keys())
1084 t_in_list, t_out_list = [], []
1085 for expert_idx in expert_indices:
1086 expert_weights = module_weights[expert_idx]
1087 _check_lora_in_out(
1088 layer_idx=layer_idx,
1089 lora_module=f"{lora_module}_expert_{expert_idx}",
1090 available_matrices=expert_weights,
1091 source_identifier=f"directory {model_dir}",
1092 )
1093 t_in_list.append(expert_weights["in"])
1094 t_out_list.append(expert_weights["out"])
1095
1096 t_in = torch.stack(t_in_list)
1097 t_out = torch.stack(t_out_list)

Callers

nothing calls this directly

Calls 15

load_state_dict — Function · 0.85
preprocess_lora_weights — Function · 0.85
get_all_hf_lora_weights — Function · 0.85
_is_moe_module_weights — Function · 0.85
_check_lora_in_out — Function · 0.85
str_dtype_to_torch — Function · 0.85
max — Function · 0.85
sqrt — Method · 0.80
flatten — Method · 0.80
get_model_path — Function · 0.50
get — Method · 0.45
keys — Method · 0.45

Tested by

no test coverage detected