(args, node_rank: int, local_rank: int, master_port: int, num_devices: int)
| 281 | |
| 282 | |
def main(args, node_rank: int, local_rank: int, master_port: int, num_devices: int):
    """Worker entry point: load the tokenizer and model for this rank, then generate.

    Args:
        args: Run configuration. ``gen_node_world_size``, ``model_name`` and
            ``model_path`` are read from it; ``rank`` and ``world_size`` are
            written onto it before generation starts.
        node_rank: Index of this node; combined with ``local_rank`` to derive
            the global rank.
        local_rank: Index of this worker on its node. Also selects the CUDA
            device (modulo the number of visible devices).
        master_port: Not used inside this function.
        num_devices: Number of workers per node.

    Raises:
        NotImplementedError: If ``args.model_name`` has no dedicated branch and
            the generic ``AutoModel`` fallback fails to load it.
        Exception: Any other tokenizer/model loading error, re-raised after
            being logged.
        SystemExit: Always raised with code 0 once generation has finished.
    """
    args.world_size = args.gen_node_world_size * num_devices
    args.rank = num_devices * node_rank + local_rank
    logger.info(f"Generating on rank {args.rank} of {args.world_size}")

    is_codegeex = args.model_name in ["codegeex2-6b"]
    causal_lm_models = (
        "starcoder",
        "replit-code-v1-3b",
        "codegen25-7b-multi",
        "codegen25-7b-mono",
        "codegen-16B-multi",
    )

    try:
        # codegeex2-6b is loaded without the clean_up_tokenization_spaces
        # override; every other model gets it set to False.
        tokenizer_kwargs = {"trust_remote_code": True}
        if not is_codegeex:
            tokenizer_kwargs["clean_up_tokenization_spaces"] = False
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, **tokenizer_kwargs)

        # Resolve the target device once, before loading any weights, so a
        # missing GPU is reported before the (expensive) model load.
        device = "cuda:{}".format(local_rank % torch.cuda.device_count())

        if is_codegeex:
            model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).to(device)
        elif args.model_name in causal_lm_models:
            model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True).to(device)
        else:
            # Unknown model name: try the generic loader and report the model
            # as unsupported if that fails, keeping the underlying cause.
            try:
                model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True).to(device)
            except Exception as err:
                logger.error(f"Model {args.model_name} not supported.")
                raise NotImplementedError(f"Model {args.model_name} not supported.") from err
    except Exception as e:
        # Log, then propagate: continuing without a model would only fail
        # later with an unrelated UnboundLocalError that hides the real error.
        logger.error(e)
        raise

    model = model.eval()
    # Generate samples.
    run_generation_distributed(args, model, tokenizer)

    logger.info(f"rank={args.rank} worker finished, waiting ...")
    # Same effect as exit(0) without depending on the site-provided builtin.
    raise SystemExit(0)
| 313 | |
| 314 | |
| 315 | def server(args): |
nothing calls this directly
no test coverage detected