github.com/ddbourgin/numpy-ml

Method forward

numpy_ml/tests/nn_torch_models.py:352–374
(self, Xs)


        self.act_fn = act_fn

    def forward(self, Xs):
        self.Xs = []
        x = Xs[0].copy()
        if not isinstance(x, torch.Tensor):
            x = torchify(x)

        self.prod = x.clone()
        x.retain_grad()
        self.Xs.append(x)

        for i in range(1, len(Xs)):
            x = Xs[i]
            if not isinstance(x, torch.Tensor):
                x = torchify(x)

            x.retain_grad()
            self.Xs.append(x)
            self.prod *= x

        self.prod.retain_grad()
        self.Y = self.act_fn(self.prod)
        self.Y.retain_grad()
        return self.Y

    def extract_grads(self, X):
        self.forward(X)
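
To make the computation concrete, here is a NumPy sketch (not part of numpy-ml) of what forward computes and of the gradients that its retain_grad() calls keep available on the .grad attributes after a backward pass. For Y = act_fn(P) with P = X_0 * X_1 * ... * X_(n-1) taken elementwise, the gradient with respect to input j is dL/dX_j = (dL/dY * act_fn'(P)) * (product of X_i over all i != j). The helper names multiply_forward and multiply_backward are hypothetical, tanh stands in for act_fn, and all inputs are assumed to share one shape (no broadcasting).

```python
import numpy as np


def multiply_forward(Xs, act_fn):
    """Elementwise product of all inputs followed by act_fn.

    Mirrors the structure of forward(): copy the first input, then
    multiply the remaining inputs into that copy in place.
    """
    # The explicit copy plays the role of x.clone(): without it, the
    # in-place `*=` below would overwrite the caller's Xs[0].
    prod = np.array(Xs[0], dtype=float, copy=True)
    for x in Xs[1:]:
        prod *= x  # assumes every input has the same shape as Xs[0]
    return prod, act_fn(prod)


def multiply_backward(Xs, prod, dY, act_grad):
    """Analytic gradients for loss = sum(dY * act_fn(prod)).

    Returns dL/dprod and a list holding dL/dX_j for each input, i.e. the
    values that Y.backward(dY) would leave in prod.grad and X_j.grad in
    the PyTorch version.
    """
    dprod = dY * act_grad(prod)
    grads = []
    for j in range(len(Xs)):
        # Product of every input except X_j.
        others = np.ones_like(prod)
        for i, x in enumerate(Xs):
            if i != j:
                others *= x
        grads.append(dprod * others)
    return dprod, grads


# Example usage, with tanh standing in for act_fn (an assumption).
rng = np.random.default_rng(0)
Xs = [rng.standard_normal((2, 3)) for _ in range(3)]
prod, Y = multiply_forward(Xs, np.tanh)
dprod, grads = multiply_backward(
    Xs, prod, np.ones_like(Y), lambda z: 1.0 - np.tanh(z) ** 2
)
print(Y.shape, [g.shape for g in grads])
```

The copy matters in the PyTorch listing too: self.prod *= x is an in-place update, and assuming torchify returns a leaf tensor with requires_grad=True, PyTorch would reject that in-place operation on the leaf itself, so forward multiplies into a clone instead.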

Callers (1)

extract_grads · method · 0.95

Calls (3)

torchify · function · 0.85
act_fn · method · 0.80
copy · method · 0.45

Tested by

no test coverage detected