hub / github.com/ddbourgin/numpy-ml / forward

Method forward

numpy_ml/tests/nn_torch_models.py:618–648
(self, X_main, X_skip)

Source from the content-addressed store, hash-verified

        assert self.conv_1x1.bias.shape == b.flatten().shape

    def forward(self, X_main, X_skip):
        # (N, W, C) -> (N, C, W)
        self.X_main = np.moveaxis(X_main, [0, 1, 2], [0, -1, -2])
        self.X_main = torchify(self.X_main)
        self.X_main.retain_grad()

        self.conv_dilation_out = self.conv_dilation(self.X_main)
        self.conv_dilation_out.retain_grad()

        self.tanh_out = torch.tanh(self.conv_dilation_out)
        self.sigm_out = torch.sigmoid(self.conv_dilation_out)

        self.tanh_out.retain_grad()
        self.sigm_out.retain_grad()

        self.multiply_gate_out = self.tanh_out * self.sigm_out
        self.multiply_gate_out.retain_grad()

        self.conv_1x1_out = self.conv_1x1(self.multiply_gate_out)
        self.conv_1x1_out.retain_grad()

        self.X_skip = torch.zeros_like(self.conv_1x1_out)
        if X_skip is not None:
            self.X_skip = torchify(np.moveaxis(X_skip, [0, 1, 2], [0, -1, -2]))
            self.X_skip.retain_grad()

        self.Y_skip = self.X_skip + self.conv_1x1_out
        self.Y_main = self.X_main + self.conv_1x1_out

        self.Y_skip.retain_grad()
        self.Y_main.retain_grad()

    def extract_grads(self, X_main, X_skip):
        self.forward(X_main, X_skip)
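
The layout change and the gated residual/skip arithmetic in forward can be illustrated without PyTorch. What follows is a minimal NumPy sketch, not the repository's code: to_channels_first, gated_residual, and the conv_1x1 callable are hypothetical names introduced for illustration, and the dilated convolution's output is passed in as a plain array rather than computed.

```python
import numpy as np


def to_channels_first(X):
    """Reorder (N, W, C) -> (N, C, W) with the same axis move forward() uses."""
    return np.moveaxis(X, [0, 1, 2], [0, -1, -2])


def gated_residual(X_main, conv_dilation_out, conv_1x1, X_skip=None):
    """Sketch of the gate and the two output paths.

    conv_dilation_out stands in for the dilated convolution's output;
    conv_1x1 is a hypothetical callable standing in for the 1x1 convolution.
    """
    sigm_out = 1.0 / (1.0 + np.exp(-conv_dilation_out))
    tanh_out = np.tanh(conv_dilation_out)
    gate = tanh_out * sigm_out            # element-wise gated activation
    out = conv_1x1(gate)
    if X_skip is None:                    # no incoming skip: start from zeros
        X_skip = np.zeros_like(out)
    Y_main = X_main + out                 # residual path
    Y_skip = X_skip + out                 # skip path
    return Y_main, Y_skip


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 5, 3))        # (N, W, C)
X_cf = to_channels_first(X)               # (N, C, W)

# Identity stand-ins for both convolutions, purely for illustration.
Y_main, Y_skip = gated_residual(X_cf, X_cf, lambda g: g)
```

With X_skip left as None, Y_skip equals the 1x1 stand-in's output, and Y_main differs from it only by the added input.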

Callers 1

extract_grads (Method · 0.95)
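
The caller listed above runs forward before doing anything else. By default PyTorch populates .grad only on leaf tensors after backward(), so the retain_grad() calls in forward are what keep the gradients of intermediate tensors readable afterward. The following is a minimal sketch of that behavior, assuming PyTorch is installed; the tensors x and h are illustrative and not taken from the repository.

```python
import torch

x = torch.ones(3, requires_grad=True)   # leaf tensor
h = torch.tanh(x) * torch.sigmoid(x)    # non-leaf intermediate, like multiply_gate_out
h.retain_grad()                         # ask autograd to keep h.grad after backward()

h.sum().backward()

grad_h = h.grad                         # d(sum h)/dh, available because of retain_grad()
grad_x = x.grad                         # leaf gradient, populated by default
```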

Calls 1

torchify (Function · 0.85)

Tested by

no test coverage detected