MCPcopy
hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:248–298  ·  view source on GitHub ↗
(self, X, Y_true=None)

Source from the content-addressed store, hash-verified

246 self.Y.retain_grad()
247
def extract_grads(self, X, Y_true=None):
    """Run a forward and backward pass and return the results as NumPy arrays.

    Calls ``self.forward(X)``, builds a scalar loss from ``self.Y``,
    back-propagates it, and gathers the input, output, ``self.layer1``
    parameters and all of their gradients into a dictionary.

    Parameters
    ----------
    X
        Passed unchanged to ``self.forward``, which is expected to set
        ``self.X`` and ``self.Y`` (with gradients retained on both).
    Y_true : numpy.ndarray or None
        Optional regression target. When it is an ndarray the loss is
        ``0.5 * sum((self.Y - Y_true) ** 2)``; otherwise the loss is
        ``self.Y.sum()``.

    Returns
    -------
    dict
        Keys ``"loss"``, ``"X"``, ``"epsilon"``, ``"intercept"``,
        ``"scaler"``, ``"y"``, ``"dLdy"``, ``"dLdIntercept"``,
        ``"dLdScaler"``, ``"dLdX"``, plus ``"Y_true"`` when a target array
        was supplied.

    Notes
    -----
    When ``self.X`` is 4-D, every returned array has its non-batch axes
    permuted (axis 1 moved to the end), which undoes the permutation applied
    to ``Y_true`` here -- presumably converting the channels-first layout
    used by torch back to the caller's channels-last layout; confirm against
    ``forward``.
    """
    self.forward(X)

    has_target = isinstance(Y_true, np.ndarray)
    if has_target:
        # The permutation below names four axes, so it is only valid for a
        # 4-D target; lower-rank targets (the 1-D layer case) are used as-is
        # instead of raising inside np.moveaxis.
        if Y_true.ndim == 4:
            Y_true = np.moveaxis(Y_true, [0, 1, 2, 3], [0, -2, -1, -3])
        # ``reduction="sum"`` replaces the deprecated ``size_average=False``
        # and already yields a scalar, so no further ``.sum()`` is needed.
        self.loss1 = 0.5 * F.mse_loss(self.Y, torchify(Y_true), reduction="sum")
    else:
        self.loss1 = self.Y.sum()

    self.loss1.backward()

    X_np = self.X.detach().numpy()
    Y_np = self.Y.detach().numpy()
    dX_np = self.X.grad.numpy()
    dY_np = self.Y.grad.numpy()
    intercept_np = self.layer1.bias.detach().numpy()
    scaler_np = self.layer1.weight.detach().numpy()
    dIntercept_np = self.layer1.bias.grad.numpy()
    dScaler_np = self.layer1.weight.grad.numpy()

    if self.X.dim() == 4:
        # Activations carry a leading batch axis; the layer parameters do
        # not, hence the separate 3-axis permutation for them.
        orig, X_swap = [0, 1, 2, 3], [0, -1, -3, -2]
        orig_p, p_swap = [0, 1, 2], [-1, -3, -2]

        if has_target:
            Y_true = np.moveaxis(Y_true, orig, X_swap)

        X_np = np.moveaxis(X_np, orig, X_swap)
        Y_np = np.moveaxis(Y_np, orig, X_swap)
        dX_np = np.moveaxis(dX_np, orig, X_swap)
        dY_np = np.moveaxis(dY_np, orig, X_swap)
        scaler_np = np.moveaxis(scaler_np, orig_p, p_swap)
        intercept_np = np.moveaxis(intercept_np, orig_p, p_swap)
        dScaler_np = np.moveaxis(dScaler_np, orig_p, p_swap)
        dIntercept_np = np.moveaxis(dIntercept_np, orig_p, p_swap)

    grads = {
        "loss": self.loss1.detach().numpy(),
        "X": X_np,
        "epsilon": self.layer1.eps,
        "intercept": intercept_np,
        "scaler": scaler_np,
        "y": Y_np,
        "dLdy": dY_np,
        "dLdIntercept": dIntercept_np,
        "dLdScaler": dScaler_np,
        "dLdX": dX_np,
    }
    if has_target:
        grads["Y_true"] = Y_true
    return grads
299
300
301class TorchAddLayer(nn.Module):

Callers 2

test_LayerNorm1D (Function) · 0.95
test_LayerNorm2D (Function) · 0.95

Calls 3

forward (Method) · 0.95
torchify (Function) · 0.85
backward (Method) · 0.45

Tested by 2

test_LayerNorm1D (Function) · 0.76
test_LayerNorm2D (Function) · 0.76