(batch_input_texts,
eval_task='summarize',
add_special_tokens=True,
min_input_length=0)
| 240 | ppls_hf = [[] for _ in range(num_sequences)] |
| 241 | |
def _prepare_inputs(batch_input_texts,
                    eval_task='summarize',
                    add_special_tokens=True,
                    min_input_length=0):
    """Turn a batch of raw texts into a list of token-id tensors.

    Each text gets the ' TL;DR: ' suffix when ``eval_task`` is 'summarize',
    is stripped, and has " n't" collapsed to "n't" before tokenization. The
    tokenization path is picked from ``model_name`` / ``model_version``,
    which (like ``tokenizer``, ``test_token_num`` and ``make_context``) are
    read from the surrounding scope rather than passed in.

    Args:
        batch_input_texts: Sequence of input strings.
        eval_task: Task name; only 'summarize' changes the prompt suffix.
        add_special_tokens: Forwarded to ``tokenizer.encode`` on the default
            path only; the ChatGLM and original-Qwen paths do not use it.
        min_input_length: Samples whose token count is not strictly greater
            than this value are dropped.

    Returns:
        A list with one tensor of token ids per retained sample. It may be
        shorter than ``batch_input_texts`` because short samples are
        filtered out.
    """
    suffix = ' TL;DR: ' if eval_task == 'summarize' else ''
    kept_ids = []

    for raw_text in batch_input_texts:
        prompt = (raw_text + suffix).strip().replace(" n't", "n't")

        # TODO: The below lines are used to be compatible with the original code; may need fix
        if 'GLM' in model_name and model_version in ('chatglm2', 'chatglm3'):
            # ChatGLM2/3: encode with tokenizer defaults, then cut the
            # sequence down to the token budget by slicing.
            encoded = tokenizer.encode(prompt, return_tensors='pt')
            token_ids = encoded.squeeze(0)[:test_token_num]
        elif 'qwen' in model_name.lower() and model_version == 'qwen':
            # Original Qwen: make_context builds the chat prompt and is
            # handed the token budget via max_input_length.
            system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
            _, context_ids = make_context(
                tokenizer=tokenizer,
                query=prompt,
                history=[],
                system=system_prompt,
                max_input_length=test_token_num,
            )
            token_ids = torch.tensor(context_ids)
        else:
            if 'qwen' in model_name.lower() and 'qwen2' in model_version:
                # Qwen2 variants: wrap the text in the tokenizer's chat
                # template first, then encode the rendered string below.
                chat = [{
                    "role":
                    "system",
                    "content":
                    "You are a helpful assistant, please summarize the article entered by the user with one or two sentences."
                }, {
                    "role": "user",
                    "content": prompt
                }]
                prompt = tokenizer.apply_chat_template(
                    chat, tokenize=False, add_generation_prompt=True)
            encoded = tokenizer.encode(prompt,
                                       return_tensors='pt',
                                       add_special_tokens=add_special_tokens,
                                       truncation=True,
                                       max_length=test_token_num)
            token_ids = encoded.squeeze(0)

        # Skip samples that are not strictly longer than the minimum.
        if token_ids.numel() <= min_input_length:
            continue
        kept_ids.append(token_ids)

    return kept_ids
| 293 | |
| 294 | def eval_trt_llm(datapoint, |
| 295 | eval_task='summarize', |
no test coverage detected