(self, dim, channel_first=True, images=True, bias=False)
| 39 | class RMS_norm(nn.Module): |
| 40 | |
| 41 | def __init__(self, dim, channel_first=True, images=True, bias=False):  # Set up RMS-norm state: fixed sqrt(dim) scale, learnable gain gamma (ones), optional learnable bias. |
| 42 | super().__init__()  # nn.Module must be initialized before Parameters are assigned as attributes |
| 43 | broadcastable_dims = (1, 1, 1) if not images else (1, 1)  # singleton axes after the feature axis: 3 when images=False, 2 when images=True; NOTE(review): presumably 5-D video vs 4-D image inputs, confirm with callers |
| 44 | shape = (dim, *broadcastable_dims) if channel_first else (dim,)  # channel_first: (dim, 1, ..., 1) broadcasts over the trailing axes; otherwise a flat (dim,) aligns with the last axis |
| 45 | |
| 46 | self.channel_first = channel_first  # presumably selects the normalization axis in forward (truncated in this view), verify |
| 47 | self.scale = dim**0.5  # sqrt(dim): an L2-normalized vector of dim entries times sqrt(dim) has unit RMS; assumes forward multiplies the F.normalize output by this, confirm |
| 48 | self.gamma = nn.Parameter(torch.ones(shape))  # learnable per-feature gain, initialized to 1 |
| 49 | self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.  # learnable per-feature shift (init 0) when bias=True; otherwise a plain float 0. with no Parameter registered |
| 50 | |
| 51 | def forward(self, x): |
| 52 | return F.normalize( |