(
network: trt.INetworkDefinition,
tensor: trt.ITensor,
workspace: Optional[trt.ITensor],
group: np.array,
dtype: trt.DataType,
all_reduce_params: AllReduceParams,
)
| 3978 | |
| 3979 | |
| 3980 | def create_allreduce_plugin( |
| 3981 | network: trt.INetworkDefinition, |
| 3982 | tensor: trt.ITensor, |
| 3983 | workspace: Optional[trt.ITensor], |
| 3984 | group: np.array, |
| 3985 | dtype: trt.DataType, |
| 3986 | all_reduce_params: AllReduceParams, |
| 3987 | ): |
| 3988 | allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( |
| 3989 | 'AllReduce', '1', TRT_LLM_PLUGIN_NAMESPACE) |
| 3990 | assert allreduce_plg_creator is not None |
| 3991 | |
| 3992 | pf_group = trt.PluginField("group", group, trt.PluginFieldType.INT32) |
| 3993 | pf_dtype = trt.PluginField("type_id", np.array([int(dtype)], np.int32), |
| 3994 | trt.PluginFieldType.INT32) |
| 3995 | pfc = [pf_group, pf_dtype] |
| 3996 | p_strategy = trt.PluginField( |
| 3997 | "strategy", np.array([int(all_reduce_params.strategy)], np.int8), |
| 3998 | trt.PluginFieldType.INT8) |
| 3999 | pfc.append(p_strategy) |
| 4000 | p_fusion_op = trt.PluginField( |
| 4001 | "fusion_op", np.array([int(all_reduce_params.fusion_op)], np.int8), |
| 4002 | trt.PluginFieldType.INT8) |
| 4003 | pfc.append(p_fusion_op) |
| 4004 | p_eps = trt.PluginField( |
| 4005 | "eps", np.array([float(all_reduce_params.eps)], np.float32), |
| 4006 | trt.PluginFieldType.FLOAT32) |
| 4007 | pfc.append(p_eps) |
| 4008 | p_affine = trt.PluginField( |
| 4009 | "affine", np.array([int(all_reduce_params.has_affine())], np.int8), |
| 4010 | trt.PluginFieldType.INT8) |
| 4011 | pfc.append(p_affine) |
| 4012 | p_bias = trt.PluginField( |
| 4013 | "bias", np.array([int(all_reduce_params.has_bias())], np.int8), |
| 4014 | trt.PluginFieldType.INT8) |
| 4015 | pfc.append(p_bias) |
| 4016 | p_scale = trt.PluginField( |
| 4017 | "scale", np.array([int(all_reduce_params.has_scale())], np.int8), |
| 4018 | trt.PluginFieldType.INT8) |
| 4019 | pfc.append(p_scale) |
| 4020 | |
| 4021 | pfc = trt.PluginFieldCollection(pfc) |
| 4022 | ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) |
| 4023 | plug_inputs = [tensor] |
| 4024 | if all_reduce_params.strategy not in { |
| 4025 | AllReduceStrategy.NCCL, AllReduceStrategy.UB, |
| 4026 | AllReduceStrategy.NCCL_SYMMETRIC |
| 4027 | }: |
| 4028 | plug_inputs.append(workspace) |
| 4029 | if all_reduce_params.fusion_op != AllReduceFusionOp.NONE: |
| 4030 | if all_reduce_params.has_bias() == 1: |
| 4031 | plug_inputs.append(all_reduce_params.bias.trt_tensor) |
| 4032 | plug_inputs.append(all_reduce_params.residual.trt_tensor) |
| 4033 | if all_reduce_params.has_affine() == 1: |
| 4034 | plug_inputs.append(all_reduce_params.norm_weight.trt_tensor) |
| 4035 | if all_reduce_params.fusion_op == AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM: |
| 4036 | plug_inputs.append( |
| 4037 | all_reduce_params.norm_pre_residual_weight.trt_tensor) |
no test coverage detected