Forward pass for the row-parallel linear layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Transformed tensor with row-parallel computation.
(self, x: torch.Tensor)
| 247 | super().__init__(self.part_in_features, out_features, bias, dtype) |
| 248 | |
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run the row-parallel linear layer on the given input.

    The input is projected with this module's weight. When more than one
    process participates (``world_size > 1``), the projection is all-reduced
    across processes. The bias, if the layer has one, is added only after
    that reduction step.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: The projected tensor after the optional all-reduce and
            the optional bias add.
    """
    out = linear(x, self.weight)

    # The return value of all_reduce is not used: the call is relied upon to
    # update `out` in place with the cross-process result.
    if world_size > 1:
        dist.all_reduce(out)

    if self.bias is None:
        return out

    # Adding the bias after the reduction applies it exactly once; the add is
    # in place on the same tensor.
    out += self.bias
    return out
| 265 | |
| 266 | |
| 267 | class RMSNorm(nn.Module): |