| 1129 | return self.Y |
| 1130 | |
def extract_grads(self, X):
    """Run one forward/backward pass and return channels-last arrays.

    The scalar loss is the sum of every output entry (stored on
    ``self.loss``), so the returned ``"dLdY"`` is an array of ones shaped
    like the output.

    Relies on ``self.forward`` leaving ``self.X`` and ``self.Y`` as torch
    tensors whose ``.grad`` is populated by ``backward()``. Presumably
    ``forward`` calls ``retain_grad`` on them; TODO confirm against
    ``forward``.

    Parameters
    ----------
    X : array-like
        Input batch, handed unchanged to ``self.forward``.

    Returns
    -------
    dict
        ``"X"``, ``"y"``, ``"dLdY"`` and ``"dLdX"`` as numpy arrays, each
        permuted from the torch layout (N, C, H, W) to (N, H, W, C).
    """
    self.forward(X)
    self.loss = self.Y.sum()
    self.loss.backward()

    def to_channels_last(arr):
        # (N, C, H, W) -> (N, H, W, C): channel axis moves last, H/W shift left.
        return np.moveaxis(arr, [0, 1, 2, 3], [0, -1, -3, -2])

    return {
        "X": to_channels_last(self.X.detach().numpy()),
        "y": to_channels_last(self.Y.detach().numpy()),
        "dLdY": to_channels_last(self.Y.grad.numpy()),
        "dLdX": to_channels_last(self.X.grad.numpy()),
    }
| 1147 | |
| 1148 | |
| 1149 | class TorchConv2DLayer(nn.Module): |