github.com/ddbourgin/numpy-ml / forward

Method forward

numpy_ml/tests/nn_torch_models.py:306–328
(self, Xs)

```python
        self.act_fn = act_fn

    def forward(self, Xs):
        self.Xs = []
        # .copy() is a NumPy method, so Xs[0] is expected to be an ndarray here
        x = Xs[0].copy()
        if not isinstance(x, torch.Tensor):
            x = torchify(x)

        # Running elementwise sum, seeded with a clone of the first input
        self.sum = x.clone()
        x.retain_grad()
        self.Xs.append(x)

        for i in range(1, len(Xs)):
            x = Xs[i]
            if not isinstance(x, torch.Tensor):
                x = torchify(x)

            x.retain_grad()
            self.Xs.append(x)
            self.sum += x

        # Keep .grad populated on the intermediate sum and the output after backward()
        self.sum.retain_grad()
        self.Y = self.act_fn(self.sum)
        self.Y.retain_grad()
        return self.Y

    def extract_grads(self, X):
        self.forward(X)
```
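
The method computes Y = act_fn(X_1 + X_2 + ... + X_n) and calls retain_grad() on the inputs, the running sum and the output so that .grad is available on each after backward(). Because the Jacobian of the sum with respect to every input is the identity, each input receives the same gradient, dL/dX_i = dL/dS = dL/dY * f'(S). The following is a minimal NumPy sketch of that forward and backward pass, not numpy-ml's code: the function names add_layer_forward and add_layer_backward, the ReLU activation, and the upstream gradient of ones (as if L = sum(Y)) are all assumptions chosen for illustration.

```python
import numpy as np


def add_layer_forward(Xs, act_fn):
    """Hypothetical helper: elementwise sum of the inputs, then the activation."""
    # Copy the first input so the in-place += below does not mutate the caller's array
    S = np.array(Xs[0], dtype=float, copy=True)
    for X in Xs[1:]:
        S += X  # elementwise sum; all inputs must share one shape
    Y = act_fn(S)
    return S, Y


def add_layer_backward(dLdY, S, act_fn_grad, n_inputs):
    """Hypothetical helper: backprop through Y = f(S) with S = X_1 + ... + X_n."""
    # Chain rule through the activation: dL/dS = dL/dY * f'(S)
    dLdS = dLdY * act_fn_grad(S)
    # dS/dX_i is the identity, so every input receives the same gradient
    return dLdS, [dLdS.copy() for _ in range(n_inputs)]


def relu(z):
    return np.maximum(z, 0.0)


def relu_grad(z):
    return (z > 0).astype(float)


if __name__ == "__main__":
    Xs = [np.array([1.0, -2.0, 3.0]), np.array([0.5, 0.5, -4.0])]
    S, Y = add_layer_forward(Xs, relu)
    # S = [1.5, -1.5, -1.0], Y = [1.5, 0.0, 0.0]
    # Upstream gradient of ones, i.e. as if the loss were L = sum(Y)
    dLdS, dLdXs = add_layer_backward(np.ones_like(Y), S, relu_grad, len(Xs))
    print(S, Y, dLdXs)
```

The identical per-input gradients are what the retained .grad values on self.Xs would be compared against when this PyTorch module is used as a reference.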

Callers (1)

extract_grads · method · 0.95

Calls (3)

torchify · function · 0.85
act_fn · method · 0.80
copy · method · 0.45

Tested by

no test coverage detected