| 246 | self.Y.retain_grad() |
| 247 | |
def extract_grads(self, X, Y_true=None):
    """Run one forward/backward pass and collect values and gradients as NumPy arrays.

    Parameters
    ----------
    X : array-like
        Input passed straight to ``self.forward``.
    Y_true : :py:class:`ndarray <numpy.ndarray>` or None
        Optional target. When given, the loss is
        ``0.5 * sum((Y - Y_true) ** 2)``; otherwise the loss is ``sum(Y)``.
        A 4-D target has its last axis moved to position 1 before
        comparison with ``self.Y``. Targets of any other rank are used as-is.

    Returns
    -------
    grads : dict
        Keys ``loss``, ``X``, ``epsilon``, ``intercept``, ``scaler``, ``y``,
        ``dLdy``, ``dLdIntercept``, ``dLdScaler``, ``dLdX``, plus ``Y_true``
        (the caller's array, unchanged) when a target was supplied. When
        ``self.X`` is 4-D, activations have axis 1 moved to the end and
        parameters have axis 0 moved to the end.
    """
    self.forward(X)

    # NOTE(review): assumes forward() stores self.X / self.Y as tensors with
    # gradients retained (self.X.grad and self.Y.grad are read below) -- confirm.
    has_target = isinstance(Y_true, np.ndarray)

    if has_target:
        target = Y_true
        if target.ndim == 4:
            # Move the last axis to position 1, e.g. (N, H, W, C) -> (N, C, H, W).
            # Only 4-D targets are permuted; lower-rank targets are used as-is.
            target = np.moveaxis(target, [0, 1, 2, 3], [0, -2, -1, -3])
        # reduction="sum" is the supported spelling of the deprecated
        # size_average=False; the result is already a scalar.
        self.loss1 = 0.5 * F.mse_loss(self.Y, torchify(target), reduction="sum")
    else:
        self.loss1 = self.Y.sum()

    self.loss1.backward()

    X_np = self.X.detach().numpy()
    Y_np = self.Y.detach().numpy()
    dX_np = self.X.grad.numpy()
    dY_np = self.Y.grad.numpy()
    intercept_np = self.layer1.bias.detach().numpy()
    scaler_np = self.layer1.weight.detach().numpy()
    dIntercept_np = self.layer1.bias.grad.numpy()
    dScaler_np = self.layer1.weight.grad.numpy()

    if self.X.dim() == 4:
        # Activations: move axis 1 to the end, e.g. (N, C, H, W) -> (N, H, W, C).
        X_np, Y_np, dX_np, dY_np = (
            np.moveaxis(arr, [0, 1, 2, 3], [0, -1, -3, -2])
            for arr in (X_np, Y_np, dX_np, dY_np)
        )
        # Parameters (3-D here): move axis 0 to the end, e.g. (C, H, W) -> (H, W, C).
        scaler_np, intercept_np, dScaler_np, dIntercept_np = (
            np.moveaxis(arr, [0, 1, 2], [-1, -3, -2])
            for arr in (scaler_np, intercept_np, dScaler_np, dIntercept_np)
        )

    grads = {
        "loss": self.loss1.detach().numpy(),
        "X": X_np,
        "epsilon": self.layer1.eps,
        "intercept": intercept_np,
        "scaler": scaler_np,
        "y": Y_np,
        "dLdy": dY_np,
        "dLdIntercept": dIntercept_np,
        "dLdScaler": dScaler_np,
        "dLdX": dX_np,
    }
    if has_target:
        # Return the caller's array rather than a permuted round-trip view of it.
        grads["Y_true"] = Y_true
    return grads
| 299 | |
| 300 | |
| 301 | class TorchAddLayer(nn.Module): |