MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX2 / run_generation_distributed

Function run_generation_distributed

evaluation/generation.py:183–280  ·  view source on GitHub ↗
(args, model, tokenizer)

Source from the content-addressed store, hash-verified

181
182
183def run_generation_distributed(args, model, tokenizer):
184 logger.info(f"Connecting to tcp://{args.channel_ip}:{args.channel_port}")
185 context = zmq.Context()
186 socket = context.socket(zmq.REQ)
187 socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}")
188
189 os.makedirs(args.output_path, exist_ok=True)
190 output_path = os.path.join(
191 args.output_path,
192 f"{args.task_name}-t{args.temperature}-topp{args.top_p}-ns{args.samples_per_problem}-rank{args.rank}.jsonl",
193 )
194
195 def process(obj):
196 results = []
197 prompt = obj["prompt"]
198 if args.generation_mode == "instruction":
199 inputs = tokenizer([prompt] * args.micro_batch_size, return_tensors="pt")
200 inputs = inputs.to(model.device)
201 outputs = model.generate(**inputs,
202 max_length=args.max_length,
203 do_sample=True if not args.greedy else False,
204 use_cache=True,
205 top_p=args.top_p,
206 top_k=args.top_k,
207 temperature=args.temperature,
208 pad_token_id=tokenizer.eos_token_id)
209 for i, output in enumerate(outputs):
210 response = tokenizer.decode(output)
211 res = obj.copy()
212 res["generation"] = response[len(prompt):].strip()
213 results.append(res)
214 elif args.generation_mode == "completion":
215 inputs = tokenizer([prompt for _ in range(args.micro_batch_size)], return_tensors="pt")
216 inputs = inputs.to(model.device)
217 stop_criteria = CodeStoppingCriteria(
218 max_length=args.max_length,
219 micro_batch_size=args.micro_batch_size,
220 tokenizer=tokenizer,
221 dataset_type=args.dataset_type,
222 language_type=args.language_type,
223 prompt=prompt)
224 outputs = model.generate(**inputs,
225 max_length=args.max_length,
226 do_sample=True if not args.greedy else False,
227 use_cache=True,
228 stopping_criteria=[stop_criteria],
229 top_p=args.top_p,
230 top_k=args.top_k,
231 temperature=args.temperature,
232 pad_token_id=tokenizer.eos_token_id)
233 for i, output in enumerate(outputs):
234 response = tokenizer.decode(output)
235 res = obj.copy()
236 res["generation_raw"] = response
237 res["generation"] = cleanup_code(
238 response[len(prompt):],
239 dataset_type=args.dataset_type,
240 language_type=args.language_type)

Callers 1

main (Function) · 0.85

Calls 3

process (Function) · 0.85
info (Method) · 0.80
error (Method) · 0.80

Tested by

no test coverage detected