github.com/NVIDIA/TensorRT-LLM

Function recv

tensorrt_llm/functional.py:4298–4338

Add an operation that receives a tensor on one rank from another rank. If rank 'i' receives a tensor from rank 'j', rank 'j' must have a corresponding 'send' operation to rank 'i'. See 'send'. The operation is implemented using a plugin that wraps the NCCL recv point-to-point operation.
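The pairing rule lends itself to a mechanical check. The sketch below is a hypothetical helper, not part of TensorRT-LLM (`P2POp` and `unmatched_ops` are illustrative names). It takes the point-to-point calls issued on all ranks and reports every recv without a matching send and every send without a matching recv. It checks pairing only and does not model the order in which the calls are issued.

```python
from collections import Counter
from typing import List, NamedTuple


class P2POp(NamedTuple):
    """One point-to-point call made on `rank` (hypothetical record type)."""
    rank: int  # rank that issues the call
    kind: str  # "send" or "recv"
    peer: int  # destination rank for "send", source rank for "recv"


def unmatched_ops(ops: List[P2POp]) -> List[P2POp]:
    """Return the calls that have no counterpart on the peer rank.

    A send on rank j to rank i pairs with a recv on rank i from rank j, so
    both are keyed by the (source, destination) pair of the transfer.
    Repeated transfers between the same pair must balance one-for-one.
    """
    sends = Counter((op.rank, op.peer) for op in ops if op.kind == "send")
    recvs = Counter((op.peer, op.rank) for op in ops if op.kind == "recv")
    missing: List[P2POp] = []
    # Counter subtraction keeps only positive counts: sends with no recv.
    for (src, dst), n in (sends - recvs).items():
        missing += [P2POp(src, "send", dst)] * n
    # Recvs with no matching send.
    for (src, dst), n in (recvs - sends).items():
        missing += [P2POp(dst, "recv", src)] * n
    return missing


ops = [
    P2POp(rank=1, kind="recv", peer=0),  # rank 1 receives from rank 0 ...
    P2POp(rank=0, kind="send", peer=1),  # ... so rank 0 must send to rank 1
    P2POp(rank=2, kind="recv", peer=1),  # rank 1 never sends to rank 2
]
print(unmatched_ops(ops))  # [P2POp(rank=2, kind='recv', peer=1)]
```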

(tensor: Tensor, src: int)

```python
def recv(tensor: Tensor, src: int) -> Tensor:
    '''
    Add an operation that performs a recv on a rank from another rank.

    The recv operation receives a tensor on one rank from another rank. If a
    rank 'i' receives a tensor from a rank 'j', the rank 'j' must have a
    corresponding 'send' operation to rank 'i'. See 'send'.

    That operation is implemented using a plugin that wraps the NCCL recv
    point-to-point operation. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv
    for details.

    Parameters:
        tensor : Tensor
            The input tensor.

        src : int
            The rank that sends the tensor.

    Returns:
        The tensor produced by that layer.
    '''
    # trt, np, default_net, default_trtnet, str_dtype_to_trt, _add_plugin_info,
    # _create_tensor and TRT_LLM_PLUGIN_NAMESPACE are module-level names in
    # tensorrt_llm/functional.py (defined there or imported by it).

    # Look up the 'Recv' plugin (version '1') in the TensorRT-LLM plugin namespace.
    recv_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Recv', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert recv_plg_creator is not None

    # Plugin field: the rank to receive from.
    src = trt.PluginField("src_rank", np.array(src, dtype=np.int32),
                          trt.PluginFieldType.INT32)
    # Plugin field: the dtype configured for the NCCL plugins (plugin_config.nccl_plugin).
    p_dtype = default_net().plugin_config.nccl_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([src, pf_type])
    recv_plug = recv_plg_creator.create_plugin("recv", pfc)
    # The plugin works in p_dtype, so cast the input to it first.
    plug_inputs = [tensor.cast(p_dtype).trt_tensor]

    layer = default_trtnet().add_plugin_v2(plug_inputs, recv_plug)
    _add_plugin_info(layer, recv_plg_creator, "recv", pfc)
    # Cast the plugin output back to the input tensor's original dtype.
    return _create_tensor(layer.get_output(0), layer).cast(tensor.dtype)
```
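`recv` casts its input to the dtype named by `default_net().plugin_config.nccl_plugin`, configures the plugin with that dtype as `type_id`, and casts the plugin output back to `tensor.dtype`. If the plugin dtype is narrower than the tensor's dtype, received values presumably carry at most the plugin dtype's precision. The standard-library sketch below illustrates that rounding under an assumed configuration (plugin dtype float16, tensor dtype float32). It only reproduces the two casts with Python's `struct` module and does not call TensorRT; the function names are hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # struct's 'e' format code packs a half-precision (binary16) float.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def simulate_received_value(sent: float) -> float:
    """Model a float32 value carried in a float16 plugin and cast back to float32.

    Assumption: plugin_config.nccl_plugin is float16 and tensor.dtype is float32.
    """
    in_plugin_dtype = to_float16(to_float32(sent))  # value as held in the plugin dtype
    return to_float32(in_plugin_dtype)              # recv's final .cast(tensor.dtype)


original = to_float32(0.1)
received = simulate_received_value(0.1)
print(original, received)  # 0.10000000149011612 0.0999755859375
```

Values exactly representable in the plugin dtype, such as 0.5, pass through unchanged.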

Callers: 15 (all of the listed callers are forward methods)
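The callers are forward methods. Assuming they use recv the way pipeline-parallel models usually do (every stage except the first receives activations from the previous rank, and every stage except the last sends its output to the next), the matched send/recv chain can be simulated on one machine. In the sketch below, `Channels`, `stage_forward` and `run_pipeline` are hypothetical stand-ins: Python queues play the role of NCCL point-to-point links and a toy arithmetic update plays the role of a stage's layers. It is not TensorRT-LLM code.

```python
import queue
import threading
from typing import Dict, List, Tuple


class Channels:
    """Hypothetical stand-in for NCCL point-to-point links: one FIFO per (src, dst) pair."""

    def __init__(self) -> None:
        self._queues: Dict[Tuple[int, int], queue.Queue] = {}
        self._lock = threading.Lock()

    def _link(self, src: int, dst: int) -> queue.Queue:
        # Create the (src, dst) queue on first use; the lock keeps creation thread-safe.
        with self._lock:
            return self._queues.setdefault((src, dst), queue.Queue())

    def send(self, data: List[float], src: int, dst: int) -> None:
        self._link(src, dst).put(list(data))

    def recv(self, src: int, dst: int) -> List[float]:
        # Blocks until the matching send arrives. A real unmatched recv would wait
        # forever; here the timeout raises queue.Empty instead.
        return self._link(src, dst).get(timeout=5)


def stage_forward(rank: int, world_size: int, chans: Channels,
                  inputs: List[float], results: Dict[int, List[float]]) -> None:
    """One pipeline stage: recv from rank - 1 unless first, send to rank + 1 unless last."""
    if rank == 0:
        hidden = inputs                          # first stage reads the real input
    else:
        hidden = chans.recv(src=rank - 1, dst=rank)
    hidden = [h * 2 + rank for h in hidden]      # toy placeholder for this stage's layers
    if rank < world_size - 1:
        chans.send(hidden, src=rank, dst=rank + 1)
    else:
        results[rank] = hidden                   # last stage keeps the output


def run_pipeline(inputs: List[float], world_size: int) -> List[float]:
    """Run every stage in its own thread and return the last stage's output."""
    chans = Channels()
    results: Dict[int, List[float]] = {}
    threads = [threading.Thread(target=stage_forward,
                                args=(rank, world_size, chans, inputs, results))
               for rank in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results[world_size - 1]


print(run_pipeline([1.0, 2.0], world_size=3))  # [12.0, 20.0]
```

Each recv in the chain is satisfied by exactly one send from the preceding rank, which is the pairing the recv docstring requires.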

Calls: 8 (default_net, str_dtype_to_trt, default_trtnet, _add_plugin_info, _create_tensor, create_plugin, cast, get_output)

Tested by: 1