| 181 | |
| 182 | |
| 183 | def run_generation_distributed(args, model, tokenizer): |
| 184 | logger.info(f"Connecting to tcp://{args.channel_ip}:{args.channel_port}") |
| 185 | context = zmq.Context() |
| 186 | socket = context.socket(zmq.REQ) |
| 187 | socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}") |
| 188 | |
| 189 | os.makedirs(args.output_path, exist_ok=True) |
| 190 | output_path = os.path.join( |
| 191 | args.output_path, |
| 192 | f"{args.task_name}-t{args.temperature}-topp{args.top_p}-ns{args.samples_per_problem}-rank{args.rank}.jsonl", |
| 193 | ) |
| 194 | |
| 195 | def process(obj): |
| 196 | results = [] |
| 197 | prompt = obj["prompt"] |
| 198 | if args.generation_mode == "instruction": |
| 199 | inputs = tokenizer([prompt] * args.micro_batch_size, return_tensors="pt") |
| 200 | inputs = inputs.to(model.device) |
| 201 | outputs = model.generate(**inputs, |
| 202 | max_length=args.max_length, |
| 203 | do_sample=True if not args.greedy else False, |
| 204 | use_cache=True, |
| 205 | top_p=args.top_p, |
| 206 | top_k=args.top_k, |
| 207 | temperature=args.temperature, |
| 208 | pad_token_id=tokenizer.eos_token_id) |
| 209 | for i, output in enumerate(outputs): |
| 210 | response = tokenizer.decode(output) |
| 211 | res = obj.copy() |
| 212 | res["generation"] = response[len(prompt):].strip() |
| 213 | results.append(res) |
| 214 | elif args.generation_mode == "completion": |
| 215 | inputs = tokenizer([prompt for _ in range(args.micro_batch_size)], return_tensors="pt") |
| 216 | inputs = inputs.to(model.device) |
| 217 | stop_criteria = CodeStoppingCriteria( |
| 218 | max_length=args.max_length, |
| 219 | micro_batch_size=args.micro_batch_size, |
| 220 | tokenizer=tokenizer, |
| 221 | dataset_type=args.dataset_type, |
| 222 | language_type=args.language_type, |
| 223 | prompt=prompt) |
| 224 | outputs = model.generate(**inputs, |
| 225 | max_length=args.max_length, |
| 226 | do_sample=True if not args.greedy else False, |
| 227 | use_cache=True, |
| 228 | stopping_criteria=[stop_criteria], |
| 229 | top_p=args.top_p, |
| 230 | top_k=args.top_k, |
| 231 | temperature=args.temperature, |
| 232 | pad_token_id=tokenizer.eos_token_id) |
| 233 | for i, output in enumerate(outputs): |
| 234 | response = tokenizer.decode(output) |
| 235 | res = obj.copy() |
| 236 | res["generation_raw"] = response |
| 237 | res["generation"] = cleanup_code( |
| 238 | response[len(prompt):], |
| 239 | dataset_type=args.dataset_type, |
| 240 | language_type=args.language_type) |