_pd_handle_task 主要负责向 pd master 注册并建立连接,接收 pd master 发来的请求,并将推理结果转发给 pd master 进行处理。
(manager: HttpServerManager, pd_master_obj: PD_Master_Obj)
| 60 | |
| 61 | |
| 62 | async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_Obj): |
| 63 | """ |
| 64 | pd_handle_loop 主要负责与 pd master 进行注册连接,然后接收pd master发来的请求,然后 |
| 65 | 将推理结果转发给 pd master进行处理。 |
| 66 | """ |
| 67 | # 创建转发队列 |
| 68 | forwarding_queue = AsyncQueue() |
| 69 | |
| 70 | while True: |
| 71 | forwarding_tokens_task = None |
| 72 | try: |
| 73 | uri = f"ws://{pd_master_obj.host_ip_port}/pd_register" |
| 74 | async with websockets.connect( |
| 75 | uri, max_size=get_lightllm_websocket_max_message_size(), max_queue=(2048 * 1024, 2048 * 1023) # 关键修改 |
| 76 | ) as websocket: |
| 77 | |
| 78 | sock = websocket.transport.get_extra_info("socket") |
| 79 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
| 80 | |
| 81 | args_dict = vars(manager.args) |
| 82 | args_dict["host"] = manager.host_ip |
| 83 | # 发送注册信息 |
| 84 | regist_json = { |
| 85 | "node_id": manager.args.pd_node_id, |
| 86 | "client_ip_port": f"{manager.host_ip}:{manager.args.port}", |
| 87 | "mode": manager.pd_mode.value, |
| 88 | "start_args": args_dict, |
| 89 | } |
| 90 | |
| 91 | await websocket.send(json.dumps(regist_json)) |
| 92 | logger.info(f"Sent registration JSON: {regist_json}") |
| 93 | |
| 94 | # 转发任务 |
| 95 | forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) |
| 96 | |
| 97 | # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 |
| 98 | while True: |
| 99 | recv_bytes = await websocket.recv() |
| 100 | obj = pickle.loads(recv_bytes) |
| 101 | if obj[0] == ObjType.REQ: |
| 102 | prompt, sampling_params, multimodal_params = obj[1] |
| 103 | asyncio.create_task( |
| 104 | _pd_process_generate(manager, prompt, sampling_params, multimodal_params, forwarding_queue) |
| 105 | ) |
| 106 | elif obj[0] == ObjType.ABORT: |
| 107 | group_req_id = obj[1] |
| 108 | await manager.abort(group_req_id) |
| 109 | else: |
| 110 | logger.error(f"recevie error obj {str(obj)}") |
| 111 | |
| 112 | except asyncio.CancelledError: |
| 113 | # 如果任务被取消,则退出循环 |
| 114 | logger.warning(f"forwarding_tokens_task {pd_master_obj} cancelled") |
| 115 | if forwarding_tokens_task is not None: |
| 116 | forwarding_tokens_task.cancel() |
| 117 | return |
| 118 | |
| 119 | except Exception as e: |
no test coverage detected