| 136 | |
| 137 | |
| 138 | class TorchBatchNormLayer(nn.Module): |
def __init__(self, n_in, params, mode, momentum=0.9, epsilon=1e-5):
    """Build a PyTorch batch-norm layer with externally supplied affine params.

    Parameters
    ----------
    n_in : int
        Number of features/channels normalised (``num_features`` of the
        wrapped PyTorch layer).
    params : dict
        Must contain ``"scaler"`` and ``"intercept"``; they become the
        layer's affine ``weight`` and ``bias`` respectively.
    mode : {"1D", "2D"}
        ``"1D"`` wraps ``nn.BatchNorm1d``; ``"2D"`` wraps ``nn.BatchNorm2d``.
    momentum : float
        Momentum in the "weight on the old running statistic" convention.
        It is passed to PyTorch as ``1 - momentum`` (see below).
    epsilon : float
        Numerical-stability constant, passed to PyTorch as ``eps``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"1D"`` nor ``"2D"``.
    """
    super(TorchBatchNormLayer, self).__init__()

    # An unknown mode previously left ``self.layer1`` unset, and the failure
    # surfaced later as an unrelated-looking AttributeError.
    layer_classes = {"1D": nn.BatchNorm1d, "2D": nn.BatchNorm2d}
    if mode not in layer_classes:
        raise ValueError("mode must be '1D' or '2D', got {!r}".format(mode))

    # PyTorch's ``momentum`` weights the *new* batch statistic when updating
    # running_mean / running_var, so the complement of ``momentum`` is passed.
    self.layer1 = layer_classes[mode](
        num_features=n_in, momentum=1 - momentum, eps=epsilon, affine=True
    )

    # ``torch.tensor`` always copies its input, so the parameters never alias
    # the caller's arrays. It replaces the legacy ``torch.FloatTensor``
    # constructor.
    self.layer1.weight = nn.Parameter(
        torch.tensor(params["scaler"], dtype=torch.float32)
    )
    self.layer1.bias = nn.Parameter(
        torch.tensor(params["intercept"], dtype=torch.float32)
    )
| 156 | |
def forward(self, X):
    """Apply the wrapped batch-norm layer to ``X`` and cache input and output.

    Parameters
    ----------
    X : numpy.ndarray or torch.Tensor
        Layer input. A 4-D input is treated as channels-last
        ``(N, H, W, C)`` and reordered to channels-first ``(N, C, H, W)``
        before it is fed to the layer. Non-tensor inputs are converted with
        ``torchify`` after any reordering.

    Returns
    -------
    torch.Tensor
        The layer output.

    Side effects
    ------------
    Sets ``self.X`` (the tensor actually fed to the layer) and ``self.Y``
    (the layer output). ``retain_grad()`` is called on ``self.Y``, so its
    gradient is kept after ``backward()``.
    """
    if isinstance(X, torch.Tensor):
        if X.ndim == 4:
            # (N, H, W, C) -> (N, C, H, W). ``np.moveaxis`` cannot be used
            # here: on a tensor it would call ``Tensor.transpose`` with a
            # permutation list and fail.
            X = X.permute(0, 3, 1, 2)
            # The permuted view is not a leaf, so its gradient would normally
            # be discarded. Keep it so ``self.X.grad`` is populated.
            if X.requires_grad:
                X.retain_grad()
    else:
        if X.ndim == 4:
            # (N, H, W, C) -> (N, C, H, W)
            X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
        X = torchify(X)

    self.X = X
    self.Y = self.layer1(self.X)
    # Keep the output's gradient after backward() so it can be read later.
    self.Y.retain_grad()
    return self.Y
| 168 | |
| 169 | def extract_grads(self, X, Y_true=None): |
| 170 | self.forward(X) |
| 171 | |
| 172 | if isinstance(Y_true, np.ndarray): |
| 173 | Y_true = np.moveaxis(Y_true, [0, 1, 2, 3], [0, -2, -1, -3]) |
| 174 | self.loss1 = ( |
| 175 | 0.5 * F.mse_loss(self.Y, torchify(Y_true), size_average=False).sum() |
| 176 | ) |
| 177 | else: |
| 178 | self.loss1 = self.Y.sum() |
| 179 | |
| 180 | self.loss1.backward() |
| 181 | |
| 182 | X_np = self.X.detach().numpy() |
| 183 | Y_np = self.Y.detach().numpy() |
| 184 | dX_np = self.X.grad.numpy() |
| 185 | dY_np = self.Y.grad.numpy() |
| 186 | |
| 187 | if self.X.dim() == 4: |
| 188 | orig, X_swap = [0, 1, 2, 3], [0, -1, -3, -2] |
| 189 | if isinstance(Y_true, np.ndarray): |
| 190 | Y_true = np.moveaxis(Y_true, orig, X_swap) |
| 191 | X_np = np.moveaxis(X_np, orig, X_swap) |
| 192 | Y_np = np.moveaxis(Y_np, orig, X_swap) |
| 193 | dX_np = np.moveaxis(dX_np, orig, X_swap) |
| 194 | dY_np = np.moveaxis(dY_np, orig, X_swap) |
| 195 |
no outgoing calls