| 648 | self.Y_main.retain_grad() |
| 649 | |
| 650 | def extract_grads(self, X_main, X_skip): |
| 651 | self.forward(X_main, X_skip) |
| 652 | self.loss = (self.Y_skip + self.Y_main).sum() |
| 653 | self.loss.backward() |
| 654 | |
| 655 | # W (theirs): (n_out, n_in, f[0]) -> W (mine): (f[0], n_in, n_out) |
| 656 | # X (theirs): (N, C, W) -> X (mine): (N, W, C) |
| 657 | # Y (theirs): (N, C, W) -> Y (mine): (N, W, C) |
| 658 | orig, X_swap, W_swap = [0, 1, 2], [0, -1, -2], [-1, -2, -3] |
| 659 | grads = { |
| 660 | "X_main": np.moveaxis(self.X_main.detach().numpy(), orig, X_swap), |
| 661 | "X_skip": np.moveaxis(self.X_skip.detach().numpy(), orig, X_swap), |
| 662 | "conv_dilation_W": np.moveaxis( |
| 663 | self.conv_dilation.weight.detach().numpy(), orig, W_swap |
| 664 | ), |
| 665 | "conv_dilation_b": self.conv_dilation.bias.detach() |
| 666 | .numpy() |
| 667 | .reshape(1, 1, -1), |
| 668 | "conv_1x1_W": np.moveaxis( |
| 669 | self.conv_1x1.weight.detach().numpy(), orig, W_swap |
| 670 | ), |
| 671 | "conv_1x1_b": self.conv_1x1.bias.detach().numpy().reshape(1, 1, -1), |
| 672 | "conv_dilation_out": np.moveaxis( |
| 673 | self.conv_dilation_out.detach().numpy(), orig, X_swap |
| 674 | ), |
| 675 | "tanh_out": np.moveaxis(self.tanh_out.detach().numpy(), orig, X_swap), |
| 676 | "sigm_out": np.moveaxis(self.sigm_out.detach().numpy(), orig, X_swap), |
| 677 | "multiply_gate_out": np.moveaxis( |
| 678 | self.multiply_gate_out.detach().numpy(), orig, X_swap |
| 679 | ), |
| 680 | "conv_1x1_out": np.moveaxis( |
| 681 | self.conv_1x1_out.detach().numpy(), orig, X_swap |
| 682 | ), |
| 683 | "Y_main": np.moveaxis(self.Y_main.detach().numpy(), orig, X_swap), |
| 684 | "Y_skip": np.moveaxis(self.Y_skip.detach().numpy(), orig, X_swap), |
| 685 | "dLdY_skip": np.moveaxis(self.Y_skip.grad.numpy(), orig, X_swap), |
| 686 | "dLdY_main": np.moveaxis(self.Y_main.grad.numpy(), orig, X_swap), |
| 687 | "dLdConv_1x1_out": np.moveaxis( |
| 688 | self.conv_1x1_out.grad.numpy(), orig, X_swap |
| 689 | ), |
| 690 | "dLdConv_1x1_W": np.moveaxis( |
| 691 | self.conv_1x1.weight.grad.numpy(), orig, W_swap |
| 692 | ), |
| 693 | "dLdConv_1x1_b": self.conv_1x1.bias.grad.numpy().reshape(1, 1, -1), |
| 694 | "dLdMultiply_out": np.moveaxis( |
| 695 | self.multiply_gate_out.grad.numpy(), orig, X_swap |
| 696 | ), |
| 697 | "dLdTanh_out": np.moveaxis(self.tanh_out.grad.numpy(), orig, X_swap), |
| 698 | "dLdSigm_out": np.moveaxis(self.sigm_out.grad.numpy(), orig, X_swap), |
| 699 | "dLdConv_dilation_out": np.moveaxis( |
| 700 | self.conv_dilation_out.grad.numpy(), orig, X_swap |
| 701 | ), |
| 702 | "dLdConv_dilation_W": np.moveaxis( |
| 703 | self.conv_dilation.weight.grad.numpy(), orig, W_swap |
| 704 | ), |
| 705 | "dLdConv_dilation_b": self.conv_dilation.bias.grad.numpy().reshape( |
| 706 | 1, 1, -1 |
| 707 | ), |