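Split the input tensor across the tensor-parallel dimension. Args: x (paddle.Tensor): Input tensor to be split along its first (token) dimension. Returns: paddle.Tensor: This rank's slice of `x`, zero-padded to ceil(token_num / tp_size) rows so that every rank returns a tensor of the same shape (required because AllGather hangs when shapes differ across ranks).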
(self, x)
| 160 | self.weight.set_value(weight_tensor.astype(self._norm_weight_dtype)) |
| 161 | |
| 162 | def split(self, x): |
| 163 | """ |
| 164 | Split the input tensor across tensor parallel dimension. |
| 165 | |
| 166 | Args: |
| 167 | x (paddle.Tensor): Input tensor to be split. |
| 168 | |
| 169 | Returns: |
| 170 | paddle.Tensor: Splitted tensor. |
| 171 | """ |
| 172 | token_num = x.shape[0] |
| 173 | token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size |
| 174 | # AllGather will hang when the data shapes on multi-ranks are different! |
| 175 | start_offset = self.tp_rank * token_num_per_rank |
| 176 | end_offset = (self.tp_rank + 1) * token_num_per_rank |
| 177 | if start_offset >= token_num: |
| 178 | start_offset = token_num |
| 179 | if end_offset > token_num: |
| 180 | end_offset = token_num |
| 181 | part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype) |
| 182 | part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :] |
| 183 | return part_x |
| 184 | |
| 185 | def allgather(self, out, token_num): |
| 186 | """ |
no outgoing calls