| 4229 | |
| 4230 | |
def reduce_scatter(tensor: Tensor, group: List[int]) -> Tensor:
    '''
    Add a layer that runs the 'ReduceScatter' plugin on a tensor.

    The plugin creator is looked up in the TensorRT plugin registry under
    the TRT_LLM_PLUGIN_NAMESPACE namespace (version '1'). The input is cast
    to the dtype configured in the network's plugin_config.nccl_plugin
    before it is fed to the plugin, and the plugin output is cast back to
    the dtype of the original input.

    Parameters:
        tensor : Tensor
            The input tensor.

        group : List[int]
            Forwarded to the plugin as an int32 array in its "group" field.
            NOTE(review): presumably the ranks taking part in the
            collective -- confirm against the plugin implementation.

    Returns:
        The first output of the plugin layer, cast to tensor.dtype.
    '''
    creator = trt.get_plugin_registry().get_plugin_creator(
        'ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert creator is not None

    # The plugin operates in the dtype selected for the NCCL plugins.
    nccl_dtype = default_net().plugin_config.nccl_plugin
    type_id_field = trt.PluginField(
        "type_id",
        np.array([int(str_dtype_to_trt(nccl_dtype))], np.int32),
        trt.PluginFieldType.INT32,
    )
    group_field = trt.PluginField(
        "group",
        np.array(group, dtype=np.int32),
        trt.PluginFieldType.INT32,
    )
    fields = trt.PluginFieldCollection([group_field, type_id_field])

    plugin = creator.create_plugin("reduce_scatter", fields)

    # Cast on the way in; the matching cast back happens on the output.
    casted_input = tensor.cast(nccl_dtype)
    layer = default_trtnet().add_plugin_v2([casted_input.trt_tensor], plugin)
    _add_plugin_info(layer, creator, "reduce_scatter", fields)

    output = _create_tensor(layer.get_output(0), layer)
    return output.cast(tensor.dtype)
| 4252 | |
| 4253 | |
| 4254 | def send(tensor: Tensor, tgt: int) -> Tensor: |