MCPcopy
hub / github.com/PaddlePaddle/ERNIE / forward

Method forward

ernie/longcontext_ops.py:145–303  ·  view source on GitHub ↗

Balance the input tensor between groups based on token_type_ids

(
        ctx,
        tensor: paddle.Tensor,
        token_type_ids: paddle.Tensor,
        group=None,
        is_tensor_sharded=True,
        axis=0,
        is_token_type_ids_sharded=False,
        unique_tokens_type=None,
    )

Source from the content-addressed store, hash-verified (excerpt: lines 143–202 shown; the method spans lines 145–303, so the listing below is truncated)

143
144 @staticmethod
145 def forward(
146 ctx,
147 tensor: paddle.Tensor,
148 token_type_ids: paddle.Tensor,
149 group=None,
150 is_tensor_sharded=True,
151 axis=0,
152 is_token_type_ids_sharded=False,
153 unique_tokens_type=None,
154 ):
155 """Balance the input tensor between groups based on token_type_ids"""
156 ctx.is_tensor_sharded = is_tensor_sharded
157 ctx.axis = axis
158 ctx.tensor_shape = tensor.shape
159 ctx.tensor_dtype = tensor.dtype
160 ctx.token_type_ids_shape = token_type_ids.shape
161 ctx.token_type_ids_dtype = token_type_ids.dtype
162 ctx.group = fleet.get_hybrid_communicate_group().get_model_parallel_group() if group is None else group
163 ctx.rank = ctx.group.rank
164 ctx.world_size = ctx.group.nranks
165
166 if len(ctx.tensor_shape) == 1:
167 if not ctx.is_tensor_sharded:
168 tensor = tensor.split(num_or_sections=ctx.world_size, axis=0)[ctx.rank]
169 tensor = tensor.reshape([-1, 1])
170 else:
171 if (
172 len(ctx.tensor_shape) == 2
173 and not ctx.is_tensor_sharded
174 and (axis == -1 or axis == len(ctx.tensor_shape) - 1)
175 ):
176 raise ValueError(
177 "Do not support len(ctx.tensor_shape) == 2 and not ctx.is_tensor_sharded"
178 + " and (axis == -1 or axis == len(ctx.tensor_shape) -1)"
179 )
180 assert len(ctx.tensor_shape) <= 3, f"len(tensor.shape) must <= 3, but got {len(tensor.shape)}"
181 if len(ctx.tensor_shape) == 3:
182 assert ctx.tensor_shape[0] == 1, "only support tensor.shape[0] == 1"
183
184 if not ctx.is_tensor_sharded:
185 tensor = tensor.split(num_or_sections=ctx.world_size, axis=ctx.axis)[ctx.rank]
186 tensor = tensor.reshape([-1, tensor.shape[-1]])
187
188 ctx.tensor_grad_shape = tensor.shape
189
190 if is_token_type_ids_sharded:
191 assert (
192 unique_tokens_type is not None and len(unique_tokens_type) > 0
193 ), "require len(unique_tokens_type) > 0 when is_token_type_ids_sharded=True"
194 ctx.unique_tokens_type = unique_tokens_type
195 token_type_ids_per_rank = token_type_ids.flatten()
196 else:
197 ctx.unique_tokens_type = token_type_ids.unique().tolist()
198 token_type_ids_per_rank = token_type_ids.flatten().split(ctx.world_size)
199 assert tensor.shape[0] == token_type_ids_per_rank[ctx.rank].shape[0], (
200 f"tensor.shape[0]:{tensor.shape[0]} != "
201 + f"token_type_ids_per_rank[ctx.rank].shape[0]:{token_type_ids_per_rank[ctx.rank].shape[0]}"
202 )

Callers

nothing calls this directly

Calls 2

redistribute_tokens — Function · 0.85
flatten — Method · 0.80

Tested by

no test coverage detected