| 847 | self.Y.retain_grad() |
| 848 | |
| 849 | def extract_grads(self, X): |
| 850 | self.forward(X) |
| 851 | self.loss = self.Y.sum() |
| 852 | self.loss.backward() |
| 853 | |
| 854 | orig, X_swap, W_swap = [0, 1, 2, 3], [0, -1, -3, -2], [-1, -2, -4, -3] |
| 855 | grads = { |
| 856 | # layer parameters |
| 857 | "conv1_W": np.moveaxis(self.conv1.weight.detach().numpy(), orig, W_swap), |
| 858 | "conv1_b": self.conv1.bias.detach().numpy().reshape(1, 1, 1, -1), |
| 859 | "bn1_intercept": self.batchnorm1.bias.detach().numpy(), |
| 860 | "bn1_scaler": self.batchnorm1.weight.detach().numpy(), |
| 861 | "bn1_running_mean": self.batchnorm1.running_mean.detach().numpy(), |
| 862 | "bn1_running_var": self.batchnorm1.running_var.detach().numpy(), |
| 863 | "conv2_W": np.moveaxis(self.conv2.weight.detach().numpy(), orig, W_swap), |
| 864 | "conv2_b": self.conv2.bias.detach().numpy().reshape(1, 1, 1, -1), |
| 865 | "bn2_intercept": self.batchnorm2.bias.detach().numpy(), |
| 866 | "bn2_scaler": self.batchnorm2.weight.detach().numpy(), |
| 867 | "bn2_running_mean": self.batchnorm2.running_mean.detach().numpy(), |
| 868 | "bn2_running_var": self.batchnorm2.running_var.detach().numpy(), |
| 869 | "conv_skip_W": np.moveaxis( |
| 870 | self.conv_skip.weight.detach().numpy(), orig, W_swap |
| 871 | ), |
| 872 | "conv_skip_b": self.conv_skip.bias.detach().numpy().reshape(1, 1, 1, -1), |
| 873 | "bn_skip_intercept": self.batchnorm_skip.bias.detach().numpy(), |
| 874 | "bn_skip_scaler": self.batchnorm_skip.weight.detach().numpy(), |
| 875 | "bn_skip_running_mean": self.batchnorm_skip.running_mean.detach().numpy(), |
| 876 | "bn_skip_running_var": self.batchnorm_skip.running_var.detach().numpy(), |
| 877 | # layer inputs/outputs (forward step) |
| 878 | "X": np.moveaxis(self.X.detach().numpy(), orig, X_swap), |
| 879 | "conv1_out": np.moveaxis(self.conv1_out.detach().numpy(), orig, X_swap), |
| 880 | "act1_out": np.moveaxis(self.act_fn1_out.detach().numpy(), orig, X_swap), |
| 881 | "bn1_out": np.moveaxis(self.batchnorm1_out.detach().numpy(), orig, X_swap), |
| 882 | "conv2_out": np.moveaxis(self.conv2_out.detach().numpy(), orig, X_swap), |
| 883 | "bn2_out": np.moveaxis(self.batchnorm2_out.detach().numpy(), orig, X_swap), |
| 884 | "conv_skip_out": np.moveaxis( |
| 885 | self.c_skip_out.detach().numpy(), orig, X_swap |
| 886 | ), |
| 887 | "bn_skip_out": np.moveaxis(self.bn_skip_out.detach().numpy(), orig, X_swap), |
| 888 | "add_out": np.moveaxis(self.layer3_in.detach().numpy(), orig, X_swap), |
| 889 | "Y": np.moveaxis(self.Y.detach().numpy(), orig, X_swap), |
| 890 | # layer gradients (backward step) |
| 891 | "dLdY": np.moveaxis(self.Y.grad.numpy(), orig, X_swap), |
| 892 | "dLdAdd": np.moveaxis(self.layer3_in.grad.numpy(), orig, X_swap), |
| 893 | "dLdBnSkip_out": np.moveaxis(self.bn_skip_out.grad.numpy(), orig, X_swap), |
| 894 | "dLdConvSkip_out": np.moveaxis(self.c_skip_out.grad.numpy(), orig, X_swap), |
| 895 | "dLdBn2_out": np.moveaxis(self.batchnorm2_out.grad.numpy(), orig, X_swap), |
| 896 | "dLdConv2_out": np.moveaxis(self.conv2_out.grad.numpy(), orig, X_swap), |
| 897 | "dLdBn1_out": np.moveaxis(self.batchnorm1_out.grad.numpy(), orig, X_swap), |
| 898 | "dLdActFn1_out": np.moveaxis(self.act_fn1_out.grad.numpy(), orig, X_swap), |
| 899 | "dLdConv1_out": np.moveaxis(self.act_fn1_out.grad.numpy(), orig, X_swap), |
| 900 | "dLdX": np.moveaxis(self.X.grad.numpy(), orig, X_swap), |
| 901 | # layer parameter gradients (backward step) |
| 902 | "dLdBnSkip_intercept": self.batchnorm_skip.bias.grad.numpy(), |
| 903 | "dLdBnSkip_scaler": self.batchnorm_skip.weight.grad.numpy(), |
| 904 | "dLdConvSkip_W": np.moveaxis( |
| 905 | self.conv_skip.weight.grad.numpy(), orig, W_swap |
| 906 | ), |