MCPcopy
hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:849–917  ·  view source on GitHub ↗
(self, X)

Source from the content-addressed store, hash-verified

847 self.Y.retain_grad()
848
849 def extract_grads(self, X):
850 self.forward(X)
851 self.loss = self.Y.sum()
852 self.loss.backward()
853
854 orig, X_swap, W_swap = [0, 1, 2, 3], [0, -1, -3, -2], [-1, -2, -4, -3]
855 grads = {
856 # layer parameters
857 "conv1_W": np.moveaxis(self.conv1.weight.detach().numpy(), orig, W_swap),
858 "conv1_b": self.conv1.bias.detach().numpy().reshape(1, 1, 1, -1),
859 "bn1_intercept": self.batchnorm1.bias.detach().numpy(),
860 "bn1_scaler": self.batchnorm1.weight.detach().numpy(),
861 "bn1_running_mean": self.batchnorm1.running_mean.detach().numpy(),
862 "bn1_running_var": self.batchnorm1.running_var.detach().numpy(),
863 "conv2_W": np.moveaxis(self.conv2.weight.detach().numpy(), orig, W_swap),
864 "conv2_b": self.conv2.bias.detach().numpy().reshape(1, 1, 1, -1),
865 "bn2_intercept": self.batchnorm2.bias.detach().numpy(),
866 "bn2_scaler": self.batchnorm2.weight.detach().numpy(),
867 "bn2_running_mean": self.batchnorm2.running_mean.detach().numpy(),
868 "bn2_running_var": self.batchnorm2.running_var.detach().numpy(),
869 "conv_skip_W": np.moveaxis(
870 self.conv_skip.weight.detach().numpy(), orig, W_swap
871 ),
872 "conv_skip_b": self.conv_skip.bias.detach().numpy().reshape(1, 1, 1, -1),
873 "bn_skip_intercept": self.batchnorm_skip.bias.detach().numpy(),
874 "bn_skip_scaler": self.batchnorm_skip.weight.detach().numpy(),
875 "bn_skip_running_mean": self.batchnorm_skip.running_mean.detach().numpy(),
876 "bn_skip_running_var": self.batchnorm_skip.running_var.detach().numpy(),
877 # layer inputs/outputs (forward step)
878 "X": np.moveaxis(self.X.detach().numpy(), orig, X_swap),
879 "conv1_out": np.moveaxis(self.conv1_out.detach().numpy(), orig, X_swap),
880 "act1_out": np.moveaxis(self.act_fn1_out.detach().numpy(), orig, X_swap),
881 "bn1_out": np.moveaxis(self.batchnorm1_out.detach().numpy(), orig, X_swap),
882 "conv2_out": np.moveaxis(self.conv2_out.detach().numpy(), orig, X_swap),
883 "bn2_out": np.moveaxis(self.batchnorm2_out.detach().numpy(), orig, X_swap),
884 "conv_skip_out": np.moveaxis(
885 self.c_skip_out.detach().numpy(), orig, X_swap
886 ),
887 "bn_skip_out": np.moveaxis(self.bn_skip_out.detach().numpy(), orig, X_swap),
888 "add_out": np.moveaxis(self.layer3_in.detach().numpy(), orig, X_swap),
889 "Y": np.moveaxis(self.Y.detach().numpy(), orig, X_swap),
890 # layer gradients (backward step)
891 "dLdY": np.moveaxis(self.Y.grad.numpy(), orig, X_swap),
892 "dLdAdd": np.moveaxis(self.layer3_in.grad.numpy(), orig, X_swap),
893 "dLdBnSkip_out": np.moveaxis(self.bn_skip_out.grad.numpy(), orig, X_swap),
894 "dLdConvSkip_out": np.moveaxis(self.c_skip_out.grad.numpy(), orig, X_swap),
895 "dLdBn2_out": np.moveaxis(self.batchnorm2_out.grad.numpy(), orig, X_swap),
896 "dLdConv2_out": np.moveaxis(self.conv2_out.grad.numpy(), orig, X_swap),
897 "dLdBn1_out": np.moveaxis(self.batchnorm1_out.grad.numpy(), orig, X_swap),
898 "dLdActFn1_out": np.moveaxis(self.act_fn1_out.grad.numpy(), orig, X_swap),
899 "dLdConv1_out": np.moveaxis(self.act_fn1_out.grad.numpy(), orig, X_swap),
900 "dLdX": np.moveaxis(self.X.grad.numpy(), orig, X_swap),
901 # layer parameter gradients (backward step)
902 "dLdBnSkip_intercept": self.batchnorm_skip.bias.grad.numpy(),
903 "dLdBnSkip_scaler": self.batchnorm_skip.weight.grad.numpy(),
904 "dLdConvSkip_W": np.moveaxis(
905 self.conv_skip.weight.grad.numpy(), orig, W_swap
906 ),

Callers 1

Calls 2

forward · Method · 0.95
backward · Method · 0.45

Tested by 1