| 10 | |
| 11 | |
def test_model_inference(world_size, weight_dir, quant_type=None, batch_size=1, image_size=448):
    """Run ``tppart_model_infer`` in ``world_size`` processes and wait for all of them.

    One ``multiprocessing.Process`` is started per rank. Each worker receives
    its own kvargs dict in which ``tp_rank_id`` and ``vit_rank_id`` are the
    rank index, ``vit_tp`` is ``world_size`` and ``visual_gpu_ids`` is
    ``[0, ..., world_size - 1]``.

    Args:
        world_size: Number of worker processes (ranks) to launch.
        weight_dir: Passed through to the workers as ``kvargs["weight_dir"]``.
        quant_type: Passed through to the workers as ``kvargs["quant_type"]``.
        batch_size: Forwarded positionally to ``tppart_model_infer``.
        image_size: Forwarded positionally to ``tppart_model_infer``.

    Raises:
        RuntimeError: If any worker process ends with a non-zero exit code.
    """
    workers = []
    for rank_id in range(world_size):
        kvargs = {
            "vit_tp": world_size,
            "tp_rank_id": rank_id,
            "vit_rank_id": rank_id,
            "visual_gpu_ids": list(range(world_size)),
            # Every rank is handed the same fixed port value.
            "visual_nccl_port": 28766,
            "weight_dir": weight_dir,
            "data_type": "bf16",
            "quant_type": quant_type,
            "quant_cfg": None,
        }

        proc = multiprocessing.Process(target=tppart_model_infer, args=(kvargs, batch_size, image_size))
        proc.start()
        workers.append(proc)

    # Wait for every worker before inspecting results so none is left running.
    for proc in workers:
        proc.join()

    # A child that raised or was killed only reports it through its exit code;
    # without this check a crashed worker would make the test pass silently.
    failed = [(rank_id, proc.exitcode) for rank_id, proc in enumerate(workers) if proc.exitcode != 0]
    if failed:
        details = ", ".join(f"rank {rank_id} (exitcode {code})" for rank_id, code in failed)
        raise RuntimeError(f"model inference worker(s) failed: {details}")
| 34 | |
| 35 | |
| 36 | def tppart_model_infer(model_kvargs, batch_size, image_size): |