hub / github.com/NVIDIA/TensorRT-LLM / reduce_scatter

Function reduce_scatter

tensorrt_llm/functional.py:4231–4251
(tensor: Tensor, group: List[int]) -> Tensor

Source from the content-addressed store, hash-verified

def reduce_scatter(tensor: Tensor, group: List[int]) -> Tensor:

    plg_creater = trt.get_plugin_registry().get_plugin_creator(
        'ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creater is not None

    p_dtype = default_net().plugin_config.nccl_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)
    group = trt.PluginField("group", np.array(group, dtype=np.int32),
                            trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection([group, pf_type])

    reduce_scatter_plug = plg_creater.create_plugin("reduce_scatter", pfc)
    plug_inputs = [tensor.cast(p_dtype).trt_tensor]

    layer = default_trtnet().add_plugin_v2(plug_inputs, reduce_scatter_plug)
    _add_plugin_info(layer, plg_creater, "reduce_scatter", pfc)

    return _create_tensor(layer.get_output(0), layer).cast(tensor.dtype)
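The listing above only wires the ReduceScatter plugin into the network; it does not show what the collective computes. The following is a minimal NumPy sketch of the usual reduce-scatter semantics, assuming a sum reduction and an even split along the first dimension (neither is confirmed by the listing). `simulate_reduce_scatter` is a hypothetical helper for illustration, not part of TensorRT-LLM.

```python
import numpy as np


def simulate_reduce_scatter(rank_inputs):
    """Simulate a sum reduce-scatter across len(rank_inputs) ranks.

    rank_inputs[i] is the full input held by rank i. All inputs share one
    shape, and the first dimension must divide evenly by the rank count.
    Returns a list whose i-th entry is the chunk rank i would receive.
    """
    world_size = len(rank_inputs)
    # Reduce: element-wise sum of every rank's input (assumed reduction op).
    total = np.sum(np.stack(rank_inputs), axis=0)
    if total.shape[0] % world_size != 0:
        raise ValueError("first dimension must be divisible by the group size")
    # Scatter: split the reduced tensor into equal chunks, one per rank.
    return np.split(total, world_size, axis=0)


# Two simulated ranks, each holding a 4-element vector.
inputs = [
    np.array([1, 2, 3, 4], dtype=np.float32),
    np.array([10, 20, 30, 40], dtype=np.float32),
]
outputs = simulate_reduce_scatter(inputs)
# Rank 0 receives the first half of the sum, rank 1 the second half.
```

In this sketch each rank ends up with `1 / world_size` of the reduced data, which is why the output is smaller than the input along the split dimension.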

Callers 2

forward_reduce_scatter · Function · 0.90
forward_allreduce · Method · 0.85

Calls 8

default_net · Function · 0.85
str_dtype_to_trt · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
create_plugin · Method · 0.80
cast · Method · 0.80
get_output · Method · 0.45
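The two `cast` calls in the function follow a round-trip pattern: the input is cast to the plugin's configured dtype before the collective, and the result is cast back to the input's original dtype. A rough NumPy sketch of that pattern follows. `collective_with_cast` is a hypothetical name, and an identity function stands in for the real collective, so only the dtype handling is illustrated.

```python
import numpy as np


def collective_with_cast(x, comm_dtype, collective):
    """Run `collective` in comm_dtype and return the result in x's dtype."""
    original_dtype = x.dtype
    # Cast down (or up) to the communication dtype before the collective.
    result = collective(x.astype(comm_dtype))
    # Cast back so callers see the same dtype they passed in.
    return result.astype(original_dtype)


x = np.array([1.0, 2.0, 1e-8], dtype=np.float32)
# Identity stands in for the real collective (an assumption for illustration).
y = collective_with_cast(x, np.float16, lambda t: t)
```

The output dtype matches the input, but values that the communication dtype cannot represent are not recovered by the cast back: here `1e-8` underflows to zero in float16.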

Tested by 1

forward_reduce_scatter · Function · 0.72