Add an operation that performs a collective all-gather. Let's define 'group_size' as the length of the 'group' list. This function creates a layer to gather 'group_size' tensors distributed among the 'group_size' participating ranks (one GPU per rank). The list 'group' contains the identifiers of the ranks participating in the collective operation.
(tensor: Tensor, group: List[int], gather_dim: int = 0)
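For intuition, the data movement of an all-gather can be modeled on the host with NumPy. This is only an illustrative sketch and not the plugin's implementation: `simulate_allgather` is a hypothetical helper, the rank identifiers and shapes are made up, and the rank-ordered concatenation follows the NCCL AllGather description rather than anything specific to this function.

```python
import numpy as np


def simulate_allgather(per_rank_tensors, gather_dim=0):
    """Host-side model of all-gather (hypothetical helper, not library code).

    per_rank_tensors holds one array per participating rank, in rank order.
    Every rank receives the same rank-ordered concatenation.
    """
    gathered = np.concatenate(per_rank_tensors, axis=gather_dim)
    # Each participating rank ends up with an identical copy of the result.
    return [gathered.copy() for _ in per_rank_tensors]


# Hypothetical group of four ranks; each rank holds a (2, 3) tensor
# filled with its own rank id so the ordering is easy to inspect.
group = [0, 1, 2, 3]
inputs = [np.full((2, 3), rank, dtype=np.float32) for rank in group]

outputs = simulate_allgather(inputs, gather_dim=0)
print(outputs[0].shape)  # (8, 3)
```

With `gather_dim=1` the same inputs would be concatenated along the second axis instead, giving a `(2, 12)` result on every rank.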
def allgather(tensor: Tensor, group: List[int], gather_dim: int = 0) -> Tensor:
    '''
    Add an operation that performs a collective all-gather.

    Let's define 'group_size' as the length of the 'group' list. This function
    creates a layer to gather 'group_size' tensors distributed
    among the 'group_size' participating ranks (one GPU per rank).

    The list 'group' contains the identifiers of the ranks participating in
    the collective operation.

    Note that 'group' here can be either a TP group or a PP group, because
    all-gather communication is not limited to a specific split pattern.
    Therefore 'group_size' does not need to equal the MPI 'world_size'.

    The tensors in the different ranks must be 1D tensors (or views) and the
    output tensor will have that same shape.

    Given 'section_size = input.shape[0] / group_size', each rank
    contributes the section of its input tensor that corresponds to
    'rank*section_size:(rank+1)*section_size'.

    This operation is implemented using a plugin that wraps the NCCL all-gather
    collective operation. See
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather
    for details.

    Parameters:
        tensor : Tensor
            The input tensor.

        group : List[int]
            The ranks participating in the all-gather operation.

        gather_dim : int = 0
            The dimension to gather along. Defaults to 0, i.e. the input is
            treated as a 1D tensor.

    Returns:
        The tensor produced by that layer.
    '''
    # Look up the AllGather plugin creator in the TensorRT plugin registry.
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'AllGather', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert allgather_plg_creator is not None

    # Pack the participating ranks into an INT32 plugin field.
    group_size = len(group)
    group = trt.PluginField("group", np.array(group, dtype=np.int32),
                            trt.PluginFieldType.INT32)

    # The plugin data type comes from the NCCL plugin setting in the plugin config.
    p_dtype = default_net().plugin_config.nccl_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    # Create the plugin and cast the input tensor to the plugin data type.
    pfc = trt.PluginFieldCollection([group, pf_type])
    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
    plug_inputs = [tensor.cast(p_dtype).trt_tensor]

    # Add the plugin layer to the network and record its plugin metadata.
    layer = default_trtnet().add_plugin_v2(plug_inputs, allgather)
    _add_plugin_info(layer, allgather_plg_creator, "allgather", pfc)
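The 'section_size' arithmetic quoted in the docstring can be sketched in plain Python. This is a minimal illustration under stated assumptions: `section_bounds` is a hypothetical helper, `rank` is taken to be the position of a rank within `group`, and the input length is assumed to divide evenly by 'group_size' (the docstring does not say how a remainder would be handled).

```python
def section_bounds(length, group_size, rank):
    """Return the (start, stop) slice bounds for one rank's section.

    Hypothetical helper mirroring the docstring formula
    'rank*section_size:(rank+1)*section_size'.
    """
    if length % group_size != 0:
        # Assumption: only evenly divisible lengths are considered here.
        raise ValueError("length must be divisible by group_size")
    section_size = length // group_size
    return rank * section_size, (rank + 1) * section_size


group = [0, 1, 2, 3]  # hypothetical rank identifiers
length = 8            # stands in for input.shape[0]

bounds = [section_bounds(length, len(group), rank) for rank in range(len(group))]
print(bounds)  # [(0, 2), (2, 4), (4, 6), (6, 8)]
```

The sections are contiguous and non-overlapping, and together they cover the whole range `0:length`.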