MCPcopy
hub / github.com/ddbourgin/numpy-ml / forward

Method forward

numpy_ml/tests/nn_torch_models.py:813–847  ·  view source on GitHub ↗
(self, X)

Source from the content-addressed store, hash-verified

811 self.batchnorm_skip.bias = nn.Parameter(torch.FloatTensor(intercept))
812
def forward(self, X):
    """
    Run the skip-connection conv block forward, recording every intermediate.

    Main path:  ``conv1 -> act_fn -> batchnorm1 -> conv2 -> batchnorm2``.
    Skip path:  ``conv_skip -> batchnorm_skip``.
    The two paths are summed (``layer3_in``) and passed through ``act_fn``
    once more to give ``self.Y``.

    Every intermediate tensor is stored on ``self`` and has ``retain_grad()``
    called on it. After a backward pass, its ``.grad`` attribute is then
    populated, presumably so that ``extract_grads`` can read it.

    Parameters
    ----------
    X : :py:class:`torch.Tensor` or array-like
        Input volume. Non-tensor input is treated as channels-last
        ``(N, H, W, C)``. It is transposed to ``(N, C, H, W)`` and converted
        with ``torchify`` (defined elsewhere; presumably it returns a tensor
        with ``requires_grad=True`` -- confirm). A tensor is used as-is.
        It is therefore assumed to already be ``(N, C, H, W)`` and to have
        ``requires_grad=True``, because ``retain_grad`` raises otherwise.

    Returns
    -------
    Y : :py:class:`torch.Tensor`
        The block output, also stored as ``self.Y``. Returning it lets the
        module be used like any other ``nn.Module`` via ``module(X)``.
    """
    if not isinstance(X, torch.Tensor):
        # (N, H, W, C) -> (N, C, H, W): H->2, W->3, C->1 (torch conv layout)
        X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
        X = torchify(X)

    self.X = X
    self.X.retain_grad()

    # Main path
    self.conv1_out = self.conv1(self.X)
    self.conv1_out.retain_grad()

    self.act_fn1_out = self.act_fn(self.conv1_out)
    self.act_fn1_out.retain_grad()

    self.batchnorm1_out = self.batchnorm1(self.act_fn1_out)
    self.batchnorm1_out.retain_grad()

    self.conv2_out = self.conv2(self.batchnorm1_out)
    self.conv2_out.retain_grad()

    self.batchnorm2_out = self.batchnorm2(self.conv2_out)
    self.batchnorm2_out.retain_grad()

    # Skip path (a conv + batchnorm on the block input)
    self.c_skip_out = self.conv_skip(self.X)
    self.c_skip_out.retain_grad()

    self.bn_skip_out = self.batchnorm_skip(self.c_skip_out)
    self.bn_skip_out.retain_grad()

    # Merge the two paths, then apply the final activation
    self.layer3_in = self.batchnorm2_out + self.bn_skip_out
    self.layer3_in.retain_grad()

    self.Y = self.act_fn(self.layer3_in)
    self.Y.retain_grad()
    return self.Y
848
849 def extract_grads(self, X):
850 self.forward(X)

Callers 1

extract_grads (Method) · 0.95

Calls 2

torchify (Function) · 0.85
act_fn (Method) · 0.80

Tested by

no test coverage detected