MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX2 / process

Function process

evaluation/generation.py:195–243  ·  view source on GitHub ↗
(obj)

Source from the content-addressed store, hash-verified

193 )
194
def process(obj):
    """Generate ``args.micro_batch_size`` samples for a single prompt record.

    The prompt in ``obj["prompt"]`` is replicated ``args.micro_batch_size``
    times, passed through ``model.generate`` and every returned sequence is
    decoded into its own result record.

    ``args``, ``tokenizer``, ``model``, ``CodeStoppingCriteria`` and
    ``cleanup_code`` are not parameters; they are resolved from the
    surrounding scope.

    Args:
        obj: Mapping holding at least a ``"prompt"`` entry. It is
            shallow-copied for every sample and never mutated.

    Returns:
        A list with one dict per generated sequence. Each dict is a copy of
        ``obj`` extended with:

        * ``"generation"`` -- in ``"instruction"`` mode, the decoded text
          after the first ``len(prompt)`` characters, stripped of
          surrounding whitespace; in ``"completion"`` mode, the same slice
          passed through ``cleanup_code``.
        * ``"generation_raw"`` -- ``"completion"`` mode only: the full
          decoded text, prompt included.

        The list is empty when ``args.generation_mode`` is neither
        ``"instruction"`` nor ``"completion"``.
    """
    mode = args.generation_mode
    if mode not in ("instruction", "completion"):
        # Unrecognised modes yield no samples rather than raising.
        return []

    prompt = obj["prompt"]

    # One identical copy of the prompt per requested sample.
    inputs = tokenizer([prompt] * args.micro_batch_size, return_tensors="pt")
    inputs = inputs.to(model.device)

    # Sampling settings shared by both modes; kept in one place so the two
    # modes cannot drift apart.
    generate_kwargs = dict(
        max_length=args.max_length,
        do_sample=not args.greedy,
        use_cache=True,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    if mode == "completion":
        # Only completion mode installs the code-specific stopping criterion.
        generate_kwargs["stopping_criteria"] = [
            CodeStoppingCriteria(
                max_length=args.max_length,
                micro_batch_size=args.micro_batch_size,
                tokenizer=tokenizer,
                dataset_type=args.dataset_type,
                language_type=args.language_type,
                prompt=prompt,
            )
        ]

    outputs = model.generate(**inputs, **generate_kwargs)

    results = []
    for output in outputs:
        response = tokenizer.decode(output)
        res = obj.copy()
        # NOTE(review): slicing by len(prompt) assumes the decoded text
        # begins with the prompt verbatim -- confirm for this tokenizer.
        if mode == "instruction":
            res["generation"] = response[len(prompt):].strip()
        else:
            res["generation_raw"] = response
            res["generation"] = cleanup_code(
                response[len(prompt):],
                dataset_type=args.dataset_type,
                language_type=args.language_type,
            )
        results.append(res)

    return results
244
245 fout = open(output_path, "w", encoding="utf-8")
246 while True:

Callers 1

Calls 2

cleanup_code · Function · 0.90

Tested by

no test coverage detected