MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / from_hugging_face

Method from_hugging_face

tensorrt_llm/models/bert/model.py:150–201  ·  view source on GitHub ↗

Create a BertModel object from the given parameters

(
            cls,
            hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
            dtype: str = 'float16',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs)

Source from the content-addressed store, hash-verified

148
@classmethod
def from_hugging_face(
        cls,
        hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
        dtype: str = 'float16',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        **kwargs):
    """
    Build an instance of this class from a Hugging Face model.

    Args:
        hf_model_or_dir: Either an already instantiated
            ``transformers.PreTrainedModel`` (its ``config`` is used and no
            model is loaded here), or a value that is handed both to
            ``BERTConfig.from_hugging_face`` and to ``cls.load_hf_bert`` as
            the model directory. Must not be None.
        dtype: Forwarded to ``BERTConfig.from_hugging_face``. It also picks
            the torch dtype used when the model has to be loaded:
            ``'float16'`` selects ``torch.float16``, any other value selects
            ``torch.float32``.
        mapping: Forwarded to ``BERTConfig.from_hugging_face``.
        quant_config: Forwarded to ``BERTConfig.from_hugging_face``.
        **kwargs: ``load_model_on_cpu`` (default ``False``) is consumed here
            and passed to ``cls.load_hf_bert``; every remaining keyword is
            forwarded to ``BERTConfig.from_hugging_face``.

    Returns:
        An instance of ``cls`` on which ``load`` has been called with the
        weights produced by ``load_weights_from_hf_model``.
    """
    import transformers

    assert hf_model_or_dir is not None
    preloaded = isinstance(hf_model_or_dir, transformers.PreTrainedModel)
    config_source = hf_model_or_dir.config if preloaded else hf_model_or_dir

    # Strip the loader-only option before the rest of kwargs reaches the
    # config builder.
    on_cpu = kwargs.pop('load_model_on_cpu', False)
    config = BERTConfig.from_hugging_face(hf_config_or_dir=config_source,
                                          dtype=dtype,
                                          mapping=mapping,
                                          quant_config=quant_config,
                                          **kwargs)

    # NOTE: the architecture recorded in the config is overridden with this
    # class's name, or with its Roberta counterpart for Roberta configs.
    roberta_names = {
        "BertModel": "RobertaModel",
        "BertForQuestionAnswering": "RobertaForQuestionAnswering",
        "BertForSequenceClassification": "RobertaForSequenceClassification",
    }
    config.architecture = (roberta_names[cls.__name__]
                           if config.is_roberta else cls.__name__)

    if dtype == 'float16':
        load_dtype = torch.float16
    else:
        load_dtype = torch.float32

    if preloaded:
        hf_model = hf_model_or_dir
    else:
        hf_model = cls.load_hf_bert(model_dir=hf_model_or_dir,
                                    load_model_on_cpu=on_cpu,
                                    dtype=load_dtype)

    weights = load_weights_from_hf_model(hf_model=hf_model, config=config)
    model = cls(config)
    model.load(weights)

    return model
202
203 # Override the PretrainedModel's method; this can be unified in the future.
204 def prepare_inputs(self, max_batch_size, max_input_len, **kwargs):

Callers 15

build_from_hfFunction · 0.45
engine_from_checkpointFunction · 0.45
_from_hf_modelMethod · 0.45
test_weights_loaderMethod · 0.45
test_save_loadFunction · 0.45
test_async_ioFunction · 0.45
build_and_run_tp2Function · 0.45

Calls 4

popMethod · 0.80
load_hf_bertMethod · 0.80
loadMethod · 0.45

Tested by 11

engine_from_checkpointFunction · 0.36
_from_hf_modelMethod · 0.36
test_weights_loaderMethod · 0.36
test_save_loadFunction · 0.36
test_async_ioFunction · 0.36
build_and_run_tp2Function · 0.36