| 215 | |
| 216 | class TorchLayerNormLayer(nn.Module): |
def __init__(self, feat_dims, params, mode, epsilon=1e-5):
    """Build an ``nn.LayerNorm`` and load pre-computed affine parameters into it.

    Args:
        feat_dims: ``normalized_shape`` passed to ``nn.LayerNorm``. After the
            optional ``mode`` reordering, ``params["scaler"]`` and
            ``params["intercept"]`` must have exactly this shape.
        params: mapping with a ``"scaler"`` entry (loaded as the LayerNorm
            weight) and an ``"intercept"`` entry (loaded as the LayerNorm
            bias). Values may be NumPy arrays or array-likes.
        mode: when ``"2D"``, the first three axes of both parameter arrays
            are reordered with ``np.moveaxis`` before loading. Any other
            value loads them unchanged.
        epsilon: ``eps`` passed to ``nn.LayerNorm``.

    Raises:
        ValueError: if the (possibly reordered) scaler or intercept shape does
            not match the LayerNorm weight or bias shape.
    """
    super(TorchLayerNormLayer, self).__init__()

    self.layer1 = nn.LayerNorm(
        normalized_shape=feat_dims, eps=epsilon, elementwise_affine=True
    )

    # np.asarray lets array-likes (e.g. nested lists) through and is a no-op
    # for inputs that are already ndarrays.
    scaler = np.asarray(params["scaler"])
    intercept = np.asarray(params["intercept"])

    if mode == "2D":
        # Axes 0, 1, 2 move to positions -2, -1, -3. For a 3-D array this is
        # (A, B, C) -> (C, A, B). Presumably (H, W, C) -> (C, H, W), matching
        # the channels-first layout that forward permutes its input into --
        # TODO confirm against the parameter producer.
        scaler = np.moveaxis(scaler, [0, 1, 2], [-2, -1, -3])
        intercept = np.moveaxis(intercept, [0, 1, 2], [-2, -1, -3])

    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, which would let wrongly shaped arrays silently replace the
    # LayerNorm parameters below.
    weight_shape = tuple(self.layer1.weight.shape)
    if tuple(scaler.shape) != weight_shape:
        raise ValueError(
            f"params['scaler'] has shape {tuple(scaler.shape)} "
            f"(mode={mode!r}), but the LayerNorm weight expects {weight_shape}"
        )
    bias_shape = tuple(self.layer1.bias.shape)
    if tuple(intercept.shape) != bias_shape:
        raise ValueError(
            f"params['intercept'] has shape {tuple(intercept.shape)} "
            f"(mode={mode!r}), but the LayerNorm bias expects {bias_shape}"
        )

    # torch.tensor always copies its input, so the parameters never share
    # memory with the caller's arrays and training updates cannot leak back
    # into `params`.
    self.layer1.weight = nn.Parameter(
        torch.tensor(scaler, dtype=torch.float32)
    )
    self.layer1.bias = nn.Parameter(
        torch.tensor(intercept, dtype=torch.float32)
    )
| 236 | def forward(self, X): |
| 237 | # (N, H, W, C) -> (N, C, H, W) |