hub / github.com/ddbourgin/numpy-ml / forward

Method forward

numpy_ml/tests/nn_torch_models.py:460–488
(self, X)

Source from the content-addressed store, hash-verified

        self.batchnorm2.bias = nn.Parameter(torch.FloatTensor(intercept))

    def forward(self, X):
        # Accept NumPy input in channels-last layout and convert it for PyTorch.
        if not isinstance(X, torch.Tensor):
            # (N, H, W, C) -> (N, C, H, W)
            X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
            X = torchify(X)

        # Store each intermediate tensor as an attribute and call retain_grad()
        # so that non-leaf tensors keep their .grad after backward().
        self.X = X
        self.X.retain_grad()

        self.conv1_out = self.conv1(self.X)
        self.conv1_out.retain_grad()

        self.act_fn1_out = self.act_fn(self.conv1_out)
        self.act_fn1_out.retain_grad()

        self.batchnorm1_out = self.batchnorm1(self.act_fn1_out)
        self.batchnorm1_out.retain_grad()

        self.conv2_out = self.conv2(self.batchnorm1_out)
        self.conv2_out.retain_grad()

        self.batchnorm2_out = self.batchnorm2(self.conv2_out)
        self.batchnorm2_out.retain_grad()

        # Skip connection: add the block input to the second batchnorm output.
        self.layer3_in = self.batchnorm2_out + self.X
        self.layer3_in.retain_grad()

        # The result is stored on self.Y; forward() itself returns nothing.
        self.Y = self.act_fn(self.layer3_in)
        self.Y.retain_grad()

    def extract_grads(self, X):
        self.forward(X)
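
The axis reordering at the top of forward can be checked in isolation. The sketch below is not part of the repository. It uses a toy array with made-up dimensions (N=2, H=3, W=4, C=5) to show that the np.moveaxis call turns an (N, H, W, C) array into (N, C, H, W), and that it matches the more common np.transpose spelling.

```python
import numpy as np

# Toy batch in channels-last layout: N=2, H=3, W=4, C=5 (illustrative sizes only).
X_nhwc = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)

# Same call as in forward(): source axes 0,1,2,3 land at positions 0,-2,-1,-3.
X_nchw = np.moveaxis(X_nhwc, [0, 1, 2, 3], [0, -2, -1, -3])

# An equivalent spelling of the same permutation.
X_alt = np.transpose(X_nhwc, (0, 3, 1, 2))

print(X_nchw.shape)  # (2, 5, 3, 4)
print(np.array_equal(X_nchw, X_alt))  # True
```

Element X_nchw[n, c, h, w] is the same value as X_nhwc[n, h, w, c]; only the axis order changes, not the data.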

Callers 1

extract_grads · Method · 0.95
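
Only the first line of extract_grads is visible on this page, so how it consumes the retained gradients is not shown. As a rough sketch under that caveat, a caller might use tensors prepared with retain_grad() as follows. The single conv layer, the ReLU activation, the tensor sizes, and the sum-based loss are all assumptions chosen for illustration, not the module's actual layers or the repository's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the block: one conv layer plus a skip connection.
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)

# Leaf input in (N, C, H, W) layout; requires_grad so X.grad gets populated.
X = torch.randn(2, 3, 4, 4, requires_grad=True)

conv_out = conv(X)
conv_out.retain_grad()  # non-leaf tensor: without this, conv_out.grad stays None

Y = torch.relu(conv_out + X)  # skip connection, then activation
Y.retain_grad()

# Reduce to a scalar so backward() needs no explicit gradient argument.
loss = Y.sum()
loss.backward()

print(conv_out.grad.shape)  # torch.Size([2, 3, 4, 4])
```

With a sum loss, Y.grad is a tensor of ones, which makes the downstream gradients easy to compare against a reference implementation.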

Calls 2

torchify · Function · 0.85
act_fn · Method · 0.80

Tested by

no test coverage detected