MCPcopy
hub / github.com/tinygrad/tinygrad / MLXQP

Class MLXQP

tinygrad/runtime/support/mlx/mlxdev.py:204–248  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

202 if MLX_DEBUG >= 4: print(f"mlx5 {self.devfmt}: HCA initialized with gen_caps={gen_cur} roce_caps={roce_cur}")
203
class MLXQP:
  """A queue pair on an MLXDev, created together with the event queue and completion queue it reports to.

  Construction issues CREATE_EQ, CREATE_CQ and CREATE_QP through `dev.cmd.exec`, then RST2INIT_QP.
  `connect` afterwards issues INIT2RTR_QP and RTR2RTS_QP against a remote MLXQP.
  """

  def __init__(self, dev:MLXDev, log_sq_size=4, log_rq_size=4, log_eq_size=7, log_cq_size=7):
    """Create the EQ, CQ and QP on `dev` and move the QP to INIT.

    Every `log_*_size` argument is log2 of the number of entries in the corresponding queue.
    """
    self.dev = dev
    self.cq_size = 1 << log_cq_size
    self.log_sq_size, self.log_rq_size = log_sq_size, log_rq_size
    self.head = 0

    # two 8-byte, 8-aligned slots from dev.dbr_alloc; their offsets are handed to the device as dbr_addr below (CQ slot first)
    self.cq_dbr = dev.dbr_alloc.alloc(8, alignment=8)
    self.qp_dbr = dev.dbr_alloc.alloc(8, alignment=8)

    # event queue: 64-byte entries
    eq_ctx = dict(log_eq_size=log_eq_size, uar_page=dev.uar, log_page_size=0)
    self.eq_mem, self.eq_paddrs, self.eq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_EQ, log_eq_size, entry_sz=64, owner_off=31,
                                                                  eq_context_entry=eq_ctx)

    # completion queue: 64-byte entries, bound to the EQ created above
    cq_ctx = dict(log_cq_size=log_cq_size, uar_page=dev.uar, c_eqn_or_apu_element=self.eq_info['eq_number'],
                  dbr_addr=dev.dbr_paddrs[0] + self.cq_dbr, log_page_size=0)
    self.cq_mem, self.cq_paddrs, self.cq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_CQ, log_cq_size, entry_sz=64, owner_off=63,
                                                                  cq_context=cq_ctx)

    # queue pair: a single buffer holding the RQ (16B stride) followed by the SQ (64B stride), so the SQ starts right after the RQ
    self.sq_offset = (1 << log_rq_size) * 16
    cqn = self.cq_info['cqn']
    qp_ctx = dict(st=0, pm_state=3, pd=dev.pd, cqn_snd=cqn, cqn_rcv=cqn, log_msg_max=30, log_rq_size=log_rq_size,
                  log_rq_stride=0, log_sq_size=log_sq_size, rlky=1, uar_page=dev.uar, log_page_size=0, dbr_addr=dev.dbr_paddrs[0] + self.qp_dbr)
    self.qp_buf, self.qp_paddrs, self.qp_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_QP, log_sq_size, entry_sz=64,
                                                                  owner_off=0, extra_sz=self.sq_offset, qpc=qp_ctx)

    # RESET -> INIT
    self.qp_op(mlx5.MLX5_CMD_OP_RST2INIT_QP, qpc_args={'log_ack_req_freq': 8}, addr_args={'pkey_index': 0, 'vhca_port_num': 1})

    # set the last byte (offset 63) of every 64-byte CQ entry to 1: init owner bits so poll_cq waits for real CQEs
    for owner_idx in range(63, self.cq_size * 64, 64): self.cq_mem[owner_idx] = 0x01
    if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} (EQ={self.eq_info['eq_number']} CQ=0x{self.cq_info['cqn']:x})")

  def create_queue(self, opcode, log_size, entry_sz, owner_off, extra_sz=0, **ctx_kw):
    """Allocate backing memory for a queue and run the create command `opcode` for it.

    The buffer holds `(1 << log_size) * entry_sz + extra_sz` bytes, rounded up to a multiple of 0x1000. `paddrs` from the
    allocation is packed as big-endian 64-bit values into the command payload, and `ctx_kw` is forwarded to `dev.cmd.exec`.
    `owner_off` is accepted but not read by this method.

    Returns `(mem, paddrs, result)` where `result` is whatever `dev.cmd.exec` returned.
    """
    n_pages = ceildiv((1 << log_size) * entry_sz + extra_sz, 0x1000)
    mem, paddrs = self.dev.pci_dev.alloc_sysmem(n_pages * 0x1000)
    # struct.pack needs exactly n_pages values, i.e. one entry of paddrs per 0x1000 bytes
    result = self.dev.cmd.exec(opcode, payload=struct.pack(f'>{n_pages}Q', *paddrs), **ctx_kw)
    return mem, paddrs, result

  def qp_op(self, opcode, qpc_args=None, addr_args=None, **kwargs):
    """Run the QP command `opcode` on this queue pair.

    The qpc always carries st, pm_state, pd and both CQ numbers; `qpc_args` adds fields on top (repeating one of the base
    keys raises TypeError), and `addr_args` becomes its `primary_address_path`. Extra `kwargs` go straight to `dev.cmd.exec`.
    """
    cqn = self.cq_info['cqn']
    qpc = dict(st=0, pm_state=3, pd=self.dev.pd, cqn_snd=cqn, cqn_rcv=cqn, **(qpc_args or {}))
    qpc['primary_address_path'] = addr_args or {}
    self.dev.cmd.exec(opcode, qpn=self.qp_info['qpn'], qpc=qpc, **kwargs)

  def connect(self, remote:MLXQP):
    """Point this QP at `remote` by issuing INIT2RTR_QP followed by RTR2RTS_QP."""
    local_qpn, remote_qpn = self.qp_info['qpn'], remote.qp_info['qpn']

    # INIT -> RTR: receive-side parameters plus the address path (MAC, GID, UDP source port) towards the remote
    rtr_qpc = dict(mtu=5, log_msg_max=self.dev.caps['log_max_msg'], remote_qpn=remote_qpn, log_ack_req_freq=8,
                   log_rra_max=3, rre=1, rwe=1, min_rnr_nak=1, next_rcv_psn=0)
    rtr_path = dict(pkey_index=0, src_addr_index=0, hop_limit=64, udp_sport=udp_sport(local_qpn, remote_qpn), vhca_port_num=1,
                    rmac_47_32=hi32(remote.dev.mac), rmac_31_0=lo32(remote.dev.mac), rgid_rip=int.from_bytes(remote.dev.local_gid, 'big'))
    self.qp_op(mlx5.MLX5_CMD_OP_INIT2RTR_QP, opt_param_mask=0x1A, qpc_args=rtr_qpc, addr_args=rtr_path)

    # RTR -> RTS: send-side parameters (PSN, retry counts, ack timeout)
    rts_qpc = dict(log_ack_req_freq=8, next_send_psn=0, log_sra_max=3, retry_count=7, rnr_retry=7)
    rts_path = dict(ack_timeout=14, vhca_port_num=1)
    self.qp_op(mlx5.MLX5_CMD_OP_RTR2RTS_QP, qpc_args=rts_qpc, addr_args=rts_path)

    if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} connected (remote=0x{remote.qp_info['qpn']:x})")

Callers 3

connect · Method · 0.90
connect.py · File · 0.90
loopback.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…