| 126 | |
| 127 | |
class LayerNorm(nn.Module):
    """Normalize a tensor along dimension 1 and apply a learned affine map.

    Statistics (mean and biased variance) are taken over dimension 1 only;
    every other dimension is treated independently. The learned scale
    ``gamma`` (initialised to ones) and shift ``beta`` (initialised to
    zeros) each hold ``channels`` values and are broadcast along
    dimension 1 of the input.

    Args:
        channels: Number of entries in ``gamma`` and ``beta``.
        eps: Constant added to the variance before the reciprocal square
            root is taken.
    """

    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Affine parameters start as the identity transform.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Return ``x`` normalized over dimension 1, then scaled and shifted.

        Args:
            x: Input tensor with at least two dimensions; its dimension 1
                must be broadcast-compatible with ``channels``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        centered = x - x.mean(dim=1, keepdim=True)
        # Biased variance: mean of squared deviations along dim 1.
        var = (centered ** 2).mean(dim=1, keepdim=True)
        normalized = centered * torch.rsqrt(var + self.eps)

        # View the 1-D parameters as (1, C, 1, ..., 1) so they broadcast
        # over the leading dimension and any trailing dimensions.
        param_shape = (1, -1) + (1,) * (x.dim() - 2)
        scale = self.gamma.view(*param_shape)
        shift = self.beta.view(*param_shape)
        return normalized * scale + shift
| 147 | |
| 148 | |
| 149 | class ConvReluNorm(nn.Module): |