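Forward pass through the RMSNorm layer: the input is normalized in float32, cast back to the input's own tensor type, and then multiplied by the layer's weight. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after applying RMSNorm.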
(self, x)
| 62 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
| 63 | |
def forward(self, x):
    """
    Apply RMSNorm to the input and scale it by the layer weight.

    The normalization step (``self._norm``) is evaluated on ``x.float()``,
    i.e. in float32. Its result is cast back to the tensor type of ``x``
    before being multiplied by ``self.weight``.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The normalized input multiplied by ``self.weight``.

    """
    # Normalize in float32 regardless of the dtype the caller passed in.
    upcast = x.float()
    normed = self._norm(upcast)
    # Cast back to the caller's tensor type *before* applying the weight,
    # so the multiplication sees the same operand types as the input had.
    return normed.type_as(x) * self.weight
| 77 | |
| 78 | |
| 79 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): |