| 214 | |
| 215 | |
| 216 | class TorchLayerNormLayer(nn.Module): |
| 217 | def __init__(self, feat_dims, params, mode, epsilon=1e-5): |
| 218 | super(TorchLayerNormLayer, self).__init__() |
| 219 | |
| 220 | self.layer1 = nn.LayerNorm( |
| 221 | normalized_shape=feat_dims, eps=epsilon, elementwise_affine=True |
| 222 | ) |
| 223 | |
| 224 | scaler = params["scaler"] |
| 225 | intercept = params["intercept"] |
| 226 | |
| 227 | if mode == "2D": |
| 228 | scaler = np.moveaxis(scaler, [0, 1, 2], [-2, -1, -3]) |
| 229 | intercept = np.moveaxis(intercept, [0, 1, 2], [-2, -1, -3]) |
| 230 | |
| 231 | assert scaler.shape == self.layer1.weight.shape |
| 232 | assert intercept.shape == self.layer1.bias.shape |
| 233 | self.layer1.weight = nn.Parameter(torch.FloatTensor(scaler)) |
| 234 | self.layer1.bias = nn.Parameter(torch.FloatTensor(intercept)) |
| 235 | |
def forward(self, X):
    """
    Run layer normalization on `X`, caching the input and output.

    Parameters
    ----------
    X
        Input batch. A 4-D input has its axes re-ordered with
        ``np.moveaxis`` from ``(0, 1, 2, 3)`` to ``(0, -2, -1, -3)``
        (last axis moved to position 1) before normalization. Anything
        that is not already a ``torch.Tensor`` is converted with
        ``torchify``.

    Returns
    -------
    torch.Tensor
        The normalized output, which is also stored on ``self.Y``. The
        (possibly re-ordered and converted) input is stored on ``self.X``.
    """
    # (N, H, W, C) -> (N, C, H, W)
    # NOTE(review): this runs before the tensor check, so a 4-D
    # torch.Tensor also goes through np.moveaxis -- confirm callers only
    # pass NumPy arrays for 4-D inputs.
    if X.ndim == 4:
        X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])

    if not isinstance(X, torch.Tensor):
        X = torchify(X)

    self.X = X
    self.Y = self.layer1(self.X)
    # Y is not a leaf tensor; keep its gradient so it can be read after
    # a backward pass.
    self.Y.retain_grad()

    # Return the output so that `module(X)` yields a value instead of None.
    return self.Y
| 247 | |
| 248 | def extract_grads(self, X, Y_true=None): |
| 249 | self.forward(X) |
| 250 | |
| 251 | if isinstance(Y_true, np.ndarray): |
| 252 | Y_true = np.moveaxis(Y_true, [0, 1, 2, 3], [0, -2, -1, -3]) |
| 253 | self.loss1 = ( |
| 254 | 0.5 * F.mse_loss(self.Y, torchify(Y_true), size_average=False).sum() |
| 255 | ) |
| 256 | else: |
| 257 | self.loss1 = self.Y.sum() |
| 258 | |
| 259 | self.loss1.backward() |
| 260 | |
| 261 | X_np = self.X.detach().numpy() |
| 262 | Y_np = self.Y.detach().numpy() |
| 263 | dX_np = self.X.grad.numpy() |
| 264 | dY_np = self.Y.grad.numpy() |
| 265 | intercept_np = self.layer1.bias.detach().numpy() |
| 266 | scaler_np = self.layer1.weight.detach().numpy() |
| 267 | dIntercept_np = self.layer1.bias.grad.numpy() |
| 268 | dScaler_np = self.layer1.weight.grad.numpy() |
| 269 | |
| 270 | if self.X.dim() == 4: |
| 271 | orig, X_swap = [0, 1, 2, 3], [0, -1, -3, -2] |
| 272 | orig_p, p_swap = [0, 1, 2], [-1, -3, -2] |
| 273 | if isinstance(Y_true, np.ndarray): |
no outgoing calls