Defines the forward computation of the layer. Args: x (paddle.Tensor): Input tensor to be normalized. residual_input (paddle.Tensor, optional): Residual input tensor for residual connection. Defaults to None. If provided, the normalization layer will also return the residual output for further computation.
(
self,
x,
residual_input: Optional[paddle.Tensor] = None,
forward_meta: Optional[ForwardMeta] = None,
proxy_rmsnorm: Optional[Callable] = None,
)
| 200 | return multi_outs[:token_num, :] |
| 201 | |
| 202 | def forward( |
| 203 | self, |
| 204 | x, |
| 205 | residual_input: Optional[paddle.Tensor] = None, |
| 206 | forward_meta: Optional[ForwardMeta] = None, |
| 207 | proxy_rmsnorm: Optional[Callable] = None, |
| 208 | ) -> paddle.Tensor: |
| 209 | """ |
| 210 | Defines the forward computation of the layer. |
| 211 | |
| 212 | Args: |
| 213 | x (paddle.Tensor): Input tensor to be normalized. |
| 214 | residual_input (paddle.Tensor, optional): Residual input tensor for residual connection. |
| 215 | Defaults to None. If provided, the normalization layer will also return the residual |
| 216 | output for further computation. |
| 217 | |
| 218 | Returns: |
| 219 | paddle.Tensor or tuple of paddle.Tensor: |
| 220 | - If `residual_input` is None, returns the normalized output tensor. |
| 221 | - If `residual_input` is provided, returns a tuple of (normalized_output, residual_output). |
| 222 | The `residual_output` is the result of applying the normalization and possibly other |
| 223 | operations (like linear transformation) on the `residual_input`. |
| 224 | """ |
| 225 | x_dtype = x.dtype |
| 226 | x = x.astype(self.weight.dtype) |
| 227 | if residual_input is not None: |
| 228 | residual_input_dtype = residual_input.dtype |
| 229 | residual_input = residual_input.astype(self.weight.dtype) |
| 230 | |
| 231 | if residual_input is None: |
| 232 | residual_out = x |
| 233 | if proxy_rmsnorm is None: |
| 234 | if current_platform.is_gcu(): |
| 235 | if residual_input is None: |
| 236 | norm_out = rms_norm(x, self.weight, self.eps) |
| 237 | return norm_out.astype(x_dtype), residual_out |
| 238 | norm_out = self.norm_func(x, residual_input, self.weight, self.eps) |
| 239 | else: |
| 240 | norm_out = self.norm_func( |
| 241 | x, |
| 242 | norm_weight=self.weight, |
| 243 | norm_bias=None, |
| 244 | epsilon=self.eps, |
| 245 | begin_norm_axis=self.begin_norm_axis, |
| 246 | bias=self.bias, |
| 247 | residual=residual_input, |
| 248 | quant_scale=(-1 if self.quant_scale is None else self.quant_scale), |
| 249 | quant_round_type=self.quant_round_type, |
| 250 | quant_max_bound=self.quant_max_bound, |
| 251 | quant_min_bound=self.quant_min_bound, |
| 252 | ) |
| 253 | else: |
| 254 | if residual_input is not None: |
| 255 | x = x + residual_input |
| 256 | norm_out = proxy_rmsnorm(x, self.weight, self.eps), x |
| 257 | |
| 258 | out = norm_out[0].astype(x_dtype) |
| 259 | if residual_input is not None: |