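The pairing rule for `recv` says that if rank i receives from rank j, rank j must post a matching `send` to rank i. That rule can be illustrated without GPUs, NCCL, or TensorRT. The sketch below is a hypothetical in-process simulation. `Mailbox`, `run_pipeline`, and the one-thread-per-rank model are assumptions made for illustration and are not TensorRT-LLM or NCCL APIs. Each (src, dst) pair gets its own FIFO queue, so a receive on rank `dst` completes only after rank `src` has sent to exactly that rank. The pipeline layout, in which each rank receives from the previous rank and sends to the next, is one common way such send/recv pairs are arranged.

```python
import queue
import threading


class Mailbox:
    """Hypothetical stand-in for a communicator (not an NCCL API).

    Each (src, dst) pair has its own FIFO channel. A recv on rank `dst`
    from rank `src` only completes after rank `src` has posted a send
    to rank `dst`, which mirrors the send/recv pairing rule.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._channels = {}

    def _channel(self, src, dst):
        # Create the channel lazily. The lock keeps creation race-free.
        with self._lock:
            if (src, dst) not in self._channels:
                self._channels[(src, dst)] = queue.Queue()
            return self._channels[(src, dst)]

    def send(self, payload, src, dst):
        # Buffered send: returns immediately (a simplification).
        self._channel(src, dst).put(payload)

    def recv(self, src, dst, timeout=5.0):
        # Blocks until the matching send arrives. Raises queue.Empty
        # if no send from `src` to `dst` shows up within `timeout`.
        return self._channel(src, dst).get(timeout=timeout)


def run_pipeline(num_ranks, initial):
    """Simulate a pipeline in which each rank adds its own id, then forwards."""
    mailbox = Mailbox()
    results = {}

    def worker(rank):
        if rank == 0:
            data = list(initial)
        else:
            # Receive from the previous rank. This only completes because
            # rank - 1 sends to `rank` below.
            data = mailbox.recv(src=rank - 1, dst=rank)
        data = [x + rank for x in data]
        if rank < num_ranks - 1:
            # Matching send for the next rank's recv.
            mailbox.send(data, src=rank, dst=rank + 1)
        else:
            results["final"] = data

    threads = [threading.Thread(target=worker, args=(r,))
               for r in range(num_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results["final"]


if __name__ == "__main__":
    print(run_pipeline(4, [1.0, 2.0]))  # [7.0, 8.0]
```

With four ranks, ranks 0 through 3 add 0, 1, 2, and 3 in turn, so `[1.0, 2.0]` arrives at the last rank as `[7.0, 8.0]`. Unlike NCCL, where an unmatched receive typically leaves the stream waiting indefinitely, this sketch buffers sends and turns an unmatched receive into a `queue.Empty` timeout. That makes pairing mistakes easy to spot in a toy setting.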
def recv(tensor: Tensor, src: int) -> Tensor:
    '''
    Add an operation that receives a tensor on this rank from another rank.

    If a rank 'i' receives a tensor from a rank 'j', the rank 'j' must have a
    corresponding 'send' operation to rank 'i'. See 'send'.

    This operation is implemented using a plugin that wraps the NCCL recv
    point-to-point operation. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv
    for details.

    Parameters:
        tensor : Tensor
            The input tensor. It is cast to the NCCL plugin data type before
            the plugin is added, and the received tensor is cast back to the
            data type of this tensor.

        src : int
            The rank that sends the tensor.

    Returns:
        The received tensor, with the same data type as 'tensor'.
    '''
    recv_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Recv', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert recv_plg_creator is not None

    # Rank that the matching 'send' originates from.
    src_field = trt.PluginField("src_rank", np.array(src, dtype=np.int32),
                                trt.PluginFieldType.INT32)
    # Data type the NCCL plugin operates in, taken from the plugin config.
    p_dtype = default_net().plugin_config.nccl_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([src_field, pf_type])
    recv_plug = recv_plg_creator.create_plugin("recv", pfc)
    # Cast the input to the plugin data type before adding the layer.
    plug_inputs = [tensor.cast(p_dtype).trt_tensor]

    layer = default_trtnet().add_plugin_v2(plug_inputs, recv_plug)
    _add_plugin_info(layer, recv_plg_creator, "recv", pfc)
    # Cast the received tensor back to the caller's data type.
    return _create_tensor(layer.get_output(0), layer).cast(tensor.dtype)


def gemm_allreduce(a: Tensor,