MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / _prepare_inputs

Function _prepare_inputs

examples/summarize.py:242–292  ·  view source on GitHub ↗
(batch_input_texts,
                        eval_task='summarize',
                        add_special_tokens=True,
                        min_input_length=0)

Source from the content-addressed store, hash-verified

240 ppls_hf = [[] for _ in range(num_sequences)]
241
def _prepare_inputs(batch_input_texts,
                    eval_task='summarize',
                    add_special_tokens=True,
                    min_input_length=0):
    """Tokenize a batch of texts into a list of 1-D input-id tensors.

    Each text gets ' TL;DR: ' appended when ``eval_task`` is 'summarize',
    is stripped of surrounding whitespace, and has " n't" collapsed to
    "n't" before being tokenized. How it is tokenized depends on
    ``model_name`` / ``model_version``, which (together with ``tokenizer``,
    ``test_token_num`` and ``make_context``) are read from the enclosing
    scope rather than passed in.

    Args:
        batch_input_texts: The input strings to tokenize.
        eval_task: Task name; only 'summarize' adds the ' TL;DR: ' suffix.
        add_special_tokens: Forwarded to ``tokenizer.encode`` on the default
            path only; the ChatGLM2/3 and Qwen (v1) paths do not use it.
        min_input_length: Inputs whose token count is not strictly greater
            than this value are dropped.

    Returns:
        A list with one token-id tensor per retained input, in input order.
        It can be shorter than ``batch_input_texts`` because of the
        ``min_input_length`` filter.
    """

    def _tokenize(prompt):
        # TODO: The below lines are used to be compatible with the original code; may need fix
        if 'GLM' in model_name and model_version in ('chatglm2', 'chatglm3'):
            encoded = tokenizer.encode(prompt, return_tensors='pt')
            # Truncate by slicing after encoding on this path.
            return encoded.squeeze(0)[:test_token_num]

        if 'qwen' in model_name.lower() and model_version == 'qwen':
            # Use make_context to generate the prompt; only the token-id
            # list (second element of its result) is needed here.
            system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
            context = make_context(
                tokenizer=tokenizer,
                query=prompt,
                history=[],
                system=system_prompt,
                max_input_length=test_token_num,
            )
            _, token_id_list = context
            return torch.tensor(token_id_list)

        if 'qwen' in model_name.lower() and 'qwen2' in model_version:
            # Qwen2 goes through the tokenizer's chat template first, then
            # falls through to the generic encode call below.
            system_turn = {
                "role":
                "system",
                "content":
                "You are a helpful assistant, please summarize the article entered by the user with one or two sentences."
            }
            user_turn = {"role": "user", "content": prompt}
            prompt = tokenizer.apply_chat_template(
                [system_turn, user_turn],
                tokenize=False,
                add_generation_prompt=True)

        encoded = tokenizer.encode(prompt,
                                   return_tensors='pt',
                                   add_special_tokens=add_special_tokens,
                                   truncation=True,
                                   max_length=test_token_num)
        return encoded.squeeze(0)

    suffix = ' TL;DR: ' if eval_task == 'summarize' else ''
    kept_input_ids = []
    for raw_text in batch_input_texts:
        prompt_text = (raw_text + suffix).strip().replace(" n't", "n't")
        token_ids = _tokenize(prompt_text)
        # Skip inputs that are too short to be worth evaluating.
        if token_ids.numel() > min_input_length:
            kept_input_ids.append(token_ids)
    return kept_input_ids
293
294 def eval_trt_llm(datapoint,
295 eval_task='summarize',

Callers 2

eval_trt_llmFunction · 0.70
eval_hfFunction · 0.70

Calls 7

make_contextFunction · 0.90
replaceMethod · 0.80
squeezeMethod · 0.45
encodeMethod · 0.45
apply_chat_templateMethod · 0.45
numelMethod · 0.45
appendMethod · 0.45

Tested by

no test coverage detected