Balance the input tensor between groups based on token_type_ids
(
ctx,
tensor: paddle.Tensor,
token_type_ids: paddle.Tensor,
group=None,
is_tensor_sharded=True,
axis=0,
is_token_type_ids_sharded=False,
unique_tokens_type=None,
)
| 143 | |
| 144 | @staticmethod |
| 145 | def forward( |
| 146 | ctx, |
| 147 | tensor: paddle.Tensor, |
| 148 | token_type_ids: paddle.Tensor, |
| 149 | group=None, |
| 150 | is_tensor_sharded=True, |
| 151 | axis=0, |
| 152 | is_token_type_ids_sharded=False, |
| 153 | unique_tokens_type=None, |
| 154 | ): |
| 155 | """Balance the input tensor between groups based on token_type_ids""" |
| 156 | ctx.is_tensor_sharded = is_tensor_sharded |
| 157 | ctx.axis = axis |
| 158 | ctx.tensor_shape = tensor.shape |
| 159 | ctx.tensor_dtype = tensor.dtype |
| 160 | ctx.token_type_ids_shape = token_type_ids.shape |
| 161 | ctx.token_type_ids_dtype = token_type_ids.dtype |
| 162 | ctx.group = fleet.get_hybrid_communicate_group().get_model_parallel_group() if group is None else group |
| 163 | ctx.rank = ctx.group.rank |
| 164 | ctx.world_size = ctx.group.nranks |
| 165 | |
| 166 | if len(ctx.tensor_shape) == 1: |
| 167 | if not ctx.is_tensor_sharded: |
| 168 | tensor = tensor.split(num_or_sections=ctx.world_size, axis=0)[ctx.rank] |
| 169 | tensor = tensor.reshape([-1, 1]) |
| 170 | else: |
| 171 | if ( |
| 172 | len(ctx.tensor_shape) == 2 |
| 173 | and not ctx.is_tensor_sharded |
| 174 | and (axis == -1 or axis == len(ctx.tensor_shape) - 1) |
| 175 | ): |
| 176 | raise ValueError( |
| 177 | "Do not support len(ctx.tensor_shape) == 2 and not ctx.is_tensor_sharded" |
| 178 | + " and (axis == -1 or axis == len(ctx.tensor_shape) -1)" |
| 179 | ) |
| 180 | assert len(ctx.tensor_shape) <= 3, f"len(tensor.shape) must <= 3, but got {len(tensor.shape)}" |
| 181 | if len(ctx.tensor_shape) == 3: |
| 182 | assert ctx.tensor_shape[0] == 1, "only support tensor.shape[0] == 1" |
| 183 | |
| 184 | if not ctx.is_tensor_sharded: |
| 185 | tensor = tensor.split(num_or_sections=ctx.world_size, axis=ctx.axis)[ctx.rank] |
| 186 | tensor = tensor.reshape([-1, tensor.shape[-1]]) |
| 187 | |
| 188 | ctx.tensor_grad_shape = tensor.shape |
| 189 | |
| 190 | if is_token_type_ids_sharded: |
| 191 | assert ( |
| 192 | unique_tokens_type is not None and len(unique_tokens_type) > 0 |
| 193 | ), "require len(unique_tokens_type) > 0 when is_token_type_ids_sharded=True" |
| 194 | ctx.unique_tokens_type = unique_tokens_type |
| 195 | token_type_ids_per_rank = token_type_ids.flatten() |
| 196 | else: |
| 197 | ctx.unique_tokens_type = token_type_ids.unique().tolist() |
| 198 | token_type_ids_per_rank = token_type_ids.flatten().split(ctx.world_size) |
| 199 | assert tensor.shape[0] == token_type_ids_per_rank[ctx.rank].shape[0], ( |
| 200 | f"tensor.shape[0]:{tensor.shape[0]} != " |
| 201 | + f"token_type_ids_per_rank[ctx.rank].shape[0]:{token_type_ids_per_rank[ctx.rank].shape[0]}" |
| 202 | ) |
Nothing calls this function directly.
No test coverage was detected for it.