(obj)
| 193 | ) |
| 194 | |
def process(obj):
    """Generate ``args.micro_batch_size`` samples for a single prompt record.

    The prompt is repeated ``args.micro_batch_size`` times, tokenized as one
    batch, moved to ``model.device`` and passed to ``model.generate`` with
    the sampling settings read from ``args``.

    Args:
        obj: Record providing a ``"prompt"`` entry. It is duplicated with
            ``obj.copy()`` once per returned sequence and is not modified.

    Returns:
        A list with one copy of ``obj`` per sequence returned by
        ``model.generate``. Each copy gains a ``"generation"`` entry:

        * ``"instruction"`` mode: the decoded text after the first
          ``len(prompt)`` characters, stripped of surrounding whitespace.
        * ``"completion"`` mode: the same slice passed through
          ``cleanup_code``; the full decoded text is also stored under
          ``"generation_raw"``.

        The list is empty when ``args.generation_mode`` is neither of the
        two modes above.
    """
    prompt = obj["prompt"]
    mode = args.generation_mode
    if mode not in ("instruction", "completion"):
        # Unknown modes produce no samples rather than raising.
        return []

    # One identical prompt per sample so a single generate() call yields the
    # whole micro-batch.
    batch = tokenizer([prompt] * args.micro_batch_size, return_tensors="pt")
    batch = batch.to(model.device)

    # Settings shared by both generation modes.
    generate_kwargs = dict(
        max_length=args.max_length,
        do_sample=not args.greedy,
        use_cache=True,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    if mode == "completion":
        # Only completion mode installs the custom stopping criterion.
        generate_kwargs["stopping_criteria"] = [
            CodeStoppingCriteria(
                max_length=args.max_length,
                micro_batch_size=args.micro_batch_size,
                tokenizer=tokenizer,
                dataset_type=args.dataset_type,
                language_type=args.language_type,
                prompt=prompt,
            )
        ]

    sequences = model.generate(**batch, **generate_kwargs)

    samples = []
    for sequence in sequences:
        decoded = tokenizer.decode(sequence)
        # NOTE(review): dropping len(prompt) characters assumes the decoded
        # text starts with the prompt verbatim -- confirm for the tokenizer
        # in use (e.g. no prepended special-token text).
        continuation = decoded[len(prompt):]
        record = obj.copy()
        if mode == "instruction":
            record["generation"] = continuation.strip()
        else:
            record["generation_raw"] = decoded
            record["generation"] = cleanup_code(
                continuation,
                dataset_type=args.dataset_type,
                language_type=args.language_type,
            )
        samples.append(record)

    return samples
| 244 | |
| 245 | fout = open(output_path, "w", encoding="utf-8") |
| 246 | while True: |
no test coverage detected