MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX2 / main

Function main

evaluation/generation.py:283–312  ·  view source on GitHub ↗
(args, node_rank: int, local_rank: int, master_port: int, num_devices: int)

Source from the content-addressed store, hash-verified

281
282
def main(args, node_rank: int, local_rank: int, master_port: int, num_devices: int):
    """Load the tokenizer and model for one worker, then run generation.

    Args:
        args: Run configuration. ``gen_node_world_size``, ``model_name`` and
            ``model_path`` are read; ``rank`` and ``world_size`` are written.
        node_rank: Combined with ``num_devices`` and ``local_rank`` to compute
            the global rank (``num_devices * node_rank + local_rank``).
        local_rank: Worker index within the node; also selects the CUDA
            device (modulo the number of visible devices).
        master_port: Not used inside this function.
        num_devices: Number of workers per node.

    Raises:
        NotImplementedError: If a model name outside the known lists cannot
            be loaded through ``AutoModel``.
        Exception: Any other tokenizer/model loading error, after logging.
        SystemExit: Always, with code 0, once generation has finished.
    """
    args.world_size = args.gen_node_world_size * num_devices
    args.rank = num_devices * node_rank + local_rank
    logger.info(f"Generating on rank {args.rank} of {args.world_size}")

    codegeex2_models = ("codegeex2-6b",)
    causal_lm_models = (
        "starcoder",
        "replit-code-v1-3b",
        "codegen25-7b-multi",
        "codegen25-7b-mono",
        "codegen-16B-multi",
    )

    try:
        # Wrap around the visible devices so local_rank may exceed their count.
        device = "cuda:{}".format(local_rank % torch.cuda.device_count())

        if args.model_name in codegeex2_models:
            tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                args.model_path,
                clean_up_tokenization_spaces=False,
                trust_remote_code=True,
            )

        if args.model_name in codegeex2_models:
            model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).to(device)
        elif args.model_name in causal_lm_models:
            model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True).to(device)
        else:
            # Unknown model name: make a best-effort attempt with AutoModel.
            try:
                model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).to(device)
            except Exception as err:
                # Keep the underlying failure attached as the cause.
                raise NotImplementedError(f"Model {args.model_name} not supported.") from err
    except Exception as e:
        # Log, then propagate: continuing would only fail later on an
        # unbound `model`/`tokenizer` and hide the real error.
        logger.error(e)
        raise

    model = model.eval()
    # Generate samples.
    run_generation_distributed(args, model, tokenizer)

    logger.info(f"rank={args.rank} worker finished, waiting ...")
    # Equivalent to exit(0) but does not depend on the `site` module helper.
    raise SystemExit(0)
313
314
315def server(args):

Callers

nothing calls this directly

Calls 3

infoMethod · 0.80
errorMethod · 0.80

Tested by

no test coverage detected