| 202 | if MLX_DEBUG >= 4: print(f"mlx5 {self.devfmt}: HCA initialized with gen_caps={gen_cur} roce_caps={roce_cur}") |
| 203 | |
| 204 | class MLXQP: |
def __init__(self, dev:MLXDev, log_sq_size=4, log_rq_size=4, log_eq_size=7, log_cq_size=7):
    """Create a queue pair on `dev` with its own EQ and CQ, then move it to INIT.

    The device sees these steps in this order:
      1. two 8-byte doorbell records are allocated (CQ first, then QP);
      2. CREATE_EQ;
      3. CREATE_CQ, bound to that EQ;
      4. CREATE_QP, whose send and receive sides both use that CQ;
      5. RST2INIT.
    Finally the owner byte of every CQE slot is set to 1 so that CQ polling only picks up
    entries the hardware has actually written.

    Queue depths are given as log2 entry counts.
    """
    self.dev = dev
    self.cq_size = 1 << log_cq_size
    self.log_sq_size = log_sq_size
    self.log_rq_size = log_rq_size
    self.head = 0

    # Doorbell records; the CQ record is allocated before the QP record.
    self.cq_dbr = dev.dbr_alloc.alloc(8, alignment=8)
    self.qp_dbr = dev.dbr_alloc.alloc(8, alignment=8)

    # Event queue.
    eq_ctx = dict(log_eq_size=log_eq_size, uar_page=dev.uar, log_page_size=0)
    self.eq_mem, self.eq_paddrs, self.eq_info = self.create_queue(
        mlx5.MLX5_CMD_OP_CREATE_EQ, log_eq_size, entry_sz=64, owner_off=31, eq_context_entry=eq_ctx)

    # Completion queue, reporting into the EQ above.
    cqe_sz, cqe_owner_off = 64, 63
    cq_ctx = dict(log_cq_size=log_cq_size, uar_page=dev.uar, c_eqn_or_apu_element=self.eq_info['eq_number'],
                  dbr_addr=dev.dbr_paddrs[0] + self.cq_dbr, log_page_size=0)
    self.cq_mem, self.cq_paddrs, self.cq_info = self.create_queue(
        mlx5.MLX5_CMD_OP_CREATE_CQ, log_cq_size, entry_sz=cqe_sz, owner_off=cqe_owner_off, cq_context=cq_ctx)

    # Queue pair. The buffer holds the RQ first (16-byte stride), then the SQ (64-byte stride).
    self.sq_offset = (1 << log_rq_size) << 4
    cqn = self.cq_info['cqn']
    # NOTE(review): log_msg_max is fixed at 30 here, while connect() programs
    # dev.caps['log_max_msg'] during INIT2RTR. Confirm this difference is intended.
    qp_ctx = dict(st=0, pm_state=3, pd=dev.pd, cqn_snd=cqn, cqn_rcv=cqn, log_msg_max=30, log_rq_size=log_rq_size,
                  log_rq_stride=0, log_sq_size=log_sq_size, rlky=1, uar_page=dev.uar, log_page_size=0,
                  dbr_addr=dev.dbr_paddrs[0] + self.qp_dbr)
    self.qp_buf, self.qp_paddrs, self.qp_info = self.create_queue(
        mlx5.MLX5_CMD_OP_CREATE_QP, log_sq_size, entry_sz=64, owner_off=0, extra_sz=self.sq_offset, qpc=qp_ctx)

    # RESET -> INIT
    self.qp_op(mlx5.MLX5_CMD_OP_RST2INIT_QP, qpc_args=dict(log_ack_req_freq=8),
               addr_args=dict(pkey_index=0, vhca_port_num=1))

    # Set every CQE owner byte to 1 so the poller waits for entries the hardware has really written.
    for slot in range(self.cq_size):
        self.cq_mem[slot * cqe_sz + cqe_owner_off] = 0x01

    if MLX_DEBUG >= 1:
        print(f"mlx5: QP 0x{self.qp_info['qpn']:x} (EQ={self.eq_info['eq_number']} CQ=0x{self.cq_info['cqn']:x})")
| 230 | |
def create_queue(self, opcode, log_size, entry_sz, owner_off, extra_sz=0, **ctx_kw):
    """Allocate host memory for a hardware queue and issue the given create command for it.

    Buffer size: `2**log_size` entries of `entry_sz` bytes, plus `extra_sz` trailing bytes,
    rounded up to whole 0x1000-byte pages.

    The command payload is `paddrs` packed as one big-endian u64 per page. `alloc_sysmem` is
    therefore expected to return exactly one address per page; otherwise `struct.pack` raises.
    `ctx_kw` is forwarded unchanged to `dev.cmd.exec`.

    NOTE(review): `owner_off` is accepted but not used by this method.

    Returns:
        (mem, paddrs, result of dev.cmd.exec)
    """
    byte_len = (1 << log_size) * entry_sz + extra_sz
    page_count = ceildiv(byte_len, 0x1000)
    mem, paddrs = self.dev.pci_dev.alloc_sysmem(page_count * 0x1000)
    pas_payload = struct.pack(f'>{page_count}Q', *paddrs)
    created = self.dev.cmd.exec(opcode, payload=pas_payload, **ctx_kw)
    return mem, paddrs, created
| 234 | |
def qp_op(self, opcode, qpc_args=None, addr_args=None, **kwargs):
    """Send a QP command (e.g. RST2INIT / INIT2RTR / RTR2RTS) for this QP via `dev.cmd.exec`.

    The qpc always starts from the same base fields that __init__ passes to CREATE_QP:
    st=0, pm_state=3, this device's pd, and this QP's CQ number as both cqn_snd and cqn_rcv.

    Args:
        opcode: command opcode, passed straight to `dev.cmd.exec`.
        qpc_args: extra qpc fields merged over the base fields. A key that repeats a base field
            overrides it. None means no extra fields.
        addr_args: fields for the qpc's `primary_address_path`. None sends an empty dict.
        **kwargs: forwarded unchanged to `dev.cmd.exec` (e.g. `opt_param_mask`).

    Returns:
        None; the result of `dev.cmd.exec` is discarded.
    """
    cqn = self.cq_info['cqn']
    base_qpc = dict(st=0, pm_state=3, pd=self.dev.pd, cqn_snd=cqn, cqn_rcv=cqn)
    # Merge with `|` rather than `dict(..., **qpc_args)`. The latter raised TypeError whenever a
    # caller passed one of the base keys, instead of letting the caller's value win.
    qpc = base_qpc | (qpc_args or {}) | {'primary_address_path': addr_args or {}}
    self.dev.cmd.exec(opcode, qpn=self.qp_info['qpn'], qpc=qpc, **kwargs)
| 238 | |
def connect(self, remote:MLXQP):
    """Move this QP from INIT through RTR to RTS, paired with the QP `remote`.

    INIT2RTR (opt_param_mask=0x1A) programs:
      - the remote QP number and the device's max message size (`dev.caps['log_max_msg']`);
      - remote read/write enables and next_rcv_psn=0;
      - an address path built from the remote device's MAC (split with hi32/lo32), its local GID
        (as a big-endian integer), and a UDP source port derived from both QP numbers.
    RTR2RTS then sets next_send_psn=0, the retry/RNR-retry counts and the ACK timeout.
    """
    local_qpn = self.qp_info['qpn']
    remote_qpn = remote.qp_info['qpn']

    # INIT -> RTR
    rtr_qpc = dict(mtu=5, log_msg_max=self.dev.caps['log_max_msg'], remote_qpn=remote_qpn, log_ack_req_freq=8,
                   log_rra_max=3, rre=1, rwe=1, min_rnr_nak=1, next_rcv_psn=0)
    rtr_path = dict(pkey_index=0, src_addr_index=0, hop_limit=64, udp_sport=udp_sport(local_qpn, remote_qpn),
                    vhca_port_num=1, rmac_47_32=hi32(remote.dev.mac), rmac_31_0=lo32(remote.dev.mac),
                    rgid_rip=int.from_bytes(remote.dev.local_gid, 'big'))
    self.qp_op(mlx5.MLX5_CMD_OP_INIT2RTR_QP, opt_param_mask=0x1A, qpc_args=rtr_qpc, addr_args=rtr_path)

    # RTR -> RTS
    rts_qpc = dict(log_ack_req_freq=8, next_send_psn=0, log_sra_max=3, retry_count=7, rnr_retry=7)
    rts_path = dict(ack_timeout=14, vhca_port_num=1)
    self.qp_op(mlx5.MLX5_CMD_OP_RTR2RTS_QP, qpc_args=rts_qpc, addr_args=rts_path)

    if MLX_DEBUG >= 1:
        print(f"mlx5: QP 0x{self.qp_info['qpn']:x} connected (remote=0x{remote.qp_info['qpn']:x})")
no outgoing calls
no test coverage detected
searching dependent graphs…