MCPcopy
hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:650–712  ·  view source on GitHub ↗
(self, X_main, X_skip)

Source from the content-addressed store, hash-verified

648 self.Y_main.retain_grad()
649
650 def extract_grads(self, X_main, X_skip):
651 self.forward(X_main, X_skip)
652 self.loss = (self.Y_skip + self.Y_main).sum()
653 self.loss.backward()
654
655 # W (theirs): (n_out, n_in, f[0]) -> W (mine): (f[0], n_in, n_out)
656 # X (theirs): (N, C, W) -> X (mine): (N, W, C)
657 # Y (theirs): (N, C, W) -> Y (mine): (N, W, C)
658 orig, X_swap, W_swap = [0, 1, 2], [0, -1, -2], [-1, -2, -3]
659 grads = {
660 "X_main": np.moveaxis(self.X_main.detach().numpy(), orig, X_swap),
661 "X_skip": np.moveaxis(self.X_skip.detach().numpy(), orig, X_swap),
662 "conv_dilation_W": np.moveaxis(
663 self.conv_dilation.weight.detach().numpy(), orig, W_swap
664 ),
665 "conv_dilation_b": self.conv_dilation.bias.detach()
666 .numpy()
667 .reshape(1, 1, -1),
668 "conv_1x1_W": np.moveaxis(
669 self.conv_1x1.weight.detach().numpy(), orig, W_swap
670 ),
671 "conv_1x1_b": self.conv_1x1.bias.detach().numpy().reshape(1, 1, -1),
672 "conv_dilation_out": np.moveaxis(
673 self.conv_dilation_out.detach().numpy(), orig, X_swap
674 ),
675 "tanh_out": np.moveaxis(self.tanh_out.detach().numpy(), orig, X_swap),
676 "sigm_out": np.moveaxis(self.sigm_out.detach().numpy(), orig, X_swap),
677 "multiply_gate_out": np.moveaxis(
678 self.multiply_gate_out.detach().numpy(), orig, X_swap
679 ),
680 "conv_1x1_out": np.moveaxis(
681 self.conv_1x1_out.detach().numpy(), orig, X_swap
682 ),
683 "Y_main": np.moveaxis(self.Y_main.detach().numpy(), orig, X_swap),
684 "Y_skip": np.moveaxis(self.Y_skip.detach().numpy(), orig, X_swap),
685 "dLdY_skip": np.moveaxis(self.Y_skip.grad.numpy(), orig, X_swap),
686 "dLdY_main": np.moveaxis(self.Y_main.grad.numpy(), orig, X_swap),
687 "dLdConv_1x1_out": np.moveaxis(
688 self.conv_1x1_out.grad.numpy(), orig, X_swap
689 ),
690 "dLdConv_1x1_W": np.moveaxis(
691 self.conv_1x1.weight.grad.numpy(), orig, W_swap
692 ),
693 "dLdConv_1x1_b": self.conv_1x1.bias.grad.numpy().reshape(1, 1, -1),
694 "dLdMultiply_out": np.moveaxis(
695 self.multiply_gate_out.grad.numpy(), orig, X_swap
696 ),
697 "dLdTanh_out": np.moveaxis(self.tanh_out.grad.numpy(), orig, X_swap),
698 "dLdSigm_out": np.moveaxis(self.sigm_out.grad.numpy(), orig, X_swap),
699 "dLdConv_dilation_out": np.moveaxis(
700 self.conv_dilation_out.grad.numpy(), orig, X_swap
701 ),
702 "dLdConv_dilation_W": np.moveaxis(
703 self.conv_dilation.weight.grad.numpy(), orig, W_swap
704 ),
705 "dLdConv_dilation_b": self.conv_dilation.bias.grad.numpy().reshape(
706 1, 1, -1
707 ),

Callers 1

test_WaveNetModule (Function) · 0.95

Calls 2

forward (Method) · 0.95
backward (Method) · 0.45

Tested by 1

test_WaveNetModule (Function) · 0.76